1. 程式人生 > >淺談splay(雙旋)

淺談splay(雙旋)

搜索 roo alt 大小 index swa del using 函數

  • 今天剛剛學習完splay,講一下自己的想法吧

  • 首先splay和treap不一樣,treap通過隨機數來調整樹的形態。但splay不一樣,再每插入或操作一次後,你都會把他旋轉到根,再旋轉過程中樹的形態會不斷改變,這樣來達到均攤效果 常數據大

  • 來看看具體實現吧
    首先定義數組,\(size\) 子樹大小(包括自己),\(fa\) 節點的父親,\(key\) 該節點的權值,\(cnt\) 該節點權值出現次數,$ch $表示兒子 0表左二子,1表右兒子

首先看幾個簡單函數

inline void update(int x)
{
    size[x]=cnt[x]+size[ch[x][0]]+size[ch[x][1
]]; }

更新子樹大小

inline int get(int x){return x==ch[fa[x]][1];}

返回該節點是left兒子還是right兒子

inline void clear(int x){ch[x][0]=ch[x][1]=fa[x]=size[x]=cnt[x]=key[x]=0;}

刪除該節點,清空所有信息

接下來是splay的精髓所在

inline void rotate(int x,int &k)
{
    static int old,oldfa,o;
    old=fa[x];oldfa=fa[old];o=get(x);
    if(old==k)k=x;
    else
ch[oldfa][get(old)]=x; fa[x]=oldfa; ch[old][o]=ch[x][o^1];fa[ch[x][o^1]]=old; ch[x][o^1]=old;fa[old]=x; update(x),update(old); } inline void splay(int x,int &k) { while(x!=k) { if(fa[x]!=k)rotate(get(x)^get(fa[x])?x:fa[x],k); rotate(x,k); } }

rotate,splay,是splay核心操作,顯然splay是依賴於rotate的,讓我們看一下rotate是如何實現的吧
技術分享圖片


(手繪圖)
我們考慮從圖上左往右的過程,我們要將y旋上去,因為y本是x的右兒子,所以x放到y的左兒子,將y的原本左兒子設為x的右兒子,這是左旋,還有對稱操作右旋,但我們不必要打兩個函數,用 ^可以實現左右兒子的轉換,用get操作實現,具體實現參考代碼,打代碼時最好畫個圖參照一下。

splay,這個操作完全依靠rotate,目的就是把你要的節點旋轉到k(一般是root),k要傳地址,要修改。在while循環裏加了個小小的優化,但x和他的fa在同一側時可以旋fa,小小的加速

inline void insert(int x)
{
    if(!root){root=++sz;size[sz]=cnt[sz]=1;key[sz]=x;return;}
    int now=root,o;
    while(1)
    {
        if(x==key[now])
        {
            ++cnt[now];
            splay(now,root);
            update(now);
            return;
        }
        o=x>key[now]?1:0;
        if(!ch[now][o])
        {
            ch[now][o]=++sz;
            size[sz]=cnt[sz]=1;
            key[sz]=x;fa[sz]=now;
            ch[now][o]=sz;
            update(now);
            splay(sz,root);
            return;
        }
        else now=ch[now][o];
    }
}

insert,插入一個數,當沒有數時就直接把這個數設為根,else 因為樹滿足二叉排序樹的性質,所以比當前節點的key小就往左走,否則往右走,直到找到一個空節點,更新信息,由於這個點以上所有的點\(size\)都要加一,不好update,所以把這給點旋轉到根,將這個點update就行了

inline int find_pos(int x)
{
    int now=root;
    while(1)
    {
        if(x==key[now]){return now;}
        if(x<key[now])now=ch[now][0];
        else now=ch[now][1];
    }
}

找到該值在樹中的節點編號

inline int pre()
{
    int now=ch[root][0];
    while(ch[now][1])now=ch[now][1];
    return now;
}
inline int nex()
{
    int now=ch[root][1];
    while(ch[now][0])now=ch[now][0];
    return now;
}

求前驅,後繼,前驅從根的左兒子開始一直往右跑,後繼從根的右兒子開始一直往左跑即可

void del(int x)
{
    splay(find_pos(x),root);
    if(cnt[root]>1){--cnt[root];return;}
    if(!ch[root][0]&&!ch[root][1]){clear(root);root=0;return;}
    if(ch[root][0]&&ch[root][1])
    {
        int oldroot=root;
        splay(pre(),root);
        fa[ch[oldroot][1]]=root;
        ch[root][1]=ch[oldroot][1];
        clear(oldroot);
        update(root);
    }
    else
    {
        int o=ch[root][1]>0;
        root=ch[root][o];
        clear(fa[root]);
        fa[root]=0;
    }
}

刪除操作,有點麻煩,先找到x的位置

  • 如果x有多個就\(cnt\)減一
  • 如果一個兒子都沒有就直接刪掉,root設為0
  • 如果 只有一個兒子就把兒子設為根,刪去這個點
  • 剩下兩個兒子情況,找到根的前驅,把前驅旋到根,這是root只有左兒子,再把原來根的右兒子到root上,這樣原來的root就脫離了樹,再刪掉即可。
inline int find_order_of_key(int x)
{
    int res=0,now=root;
    while(1)
    {
        if(x<key[now])now=ch[now][0];
        else
        {
            res+=size[ch[now][0]];
            if(x==key[now]){splay(now,root);return res+1;}
            res+=cnt[now];
            now=ch[now][1];
        }
    }
}
inline int find_by_order(int x)
{
    int now=root;
    while(1)
    {
        if(x<=size[ch[now][0]])now=ch[now][0];
        else
        {
            int temp=size[ch[now][0]]+cnt[now];
            if(x<=temp)return key[now];
            else{x-=temp;now=ch[now][1];}
        }
    }
}

找x的排名,與找排名為x的數,其實大同小異,用二叉搜索樹的性質即可,只是記得答案不一樣罷了

inline void rever(int l,int r)
{
    l=find(l-1);r=find(r+1);
        splay(l,root);splay(r,ch[l][1]);
        rev[ch[r][0]]^=1;
}

找到區間左邊一個和區間的右邊一個點在樹中位置,把左邊的點旋轉到根,再把右邊的點旋到root的右兒子,這時這段區間一定是ch[r][0]的子樹(想一想,為什麽)(根據二叉搜索樹的性質),把這個點打上標記即可;當遇到有翻轉標記的點時,交換其左右子樹,並下傳標記即可。

ok,splay的基本操作就是這些了

下面是完整代碼
洛谷P3369 treap模板

#include<bits/stdc++.h>
using namespace std;
typedef int sign;
typedef long long ll;
#define For(i,a,b) for(register sign i=(sign)a;i<=(sign)b;++i)
#define Fordown(i,a,b) for(register sign i=(sign)a;i>=(sign)b;--i)
const int N=1e5+5;
void cmax(sign &a,sign b){if(a<b)a=b;}
void cmin(sign &a,sign b){if(a>b)a=b;}
template<typename T>T read()
{
    T ans=0,f=1;
    char ch=getchar();
    while(!isdigit(ch)&&ch!=‘-‘)ch=getchar();
    if(ch==‘-‘)f=-1,ch=getchar();
    while(isdigit(ch))ans=(ans<<3)+(ans<<1)+(ch-‘0‘),ch=getchar();
    return ans*f;
}
void file()
{
    #ifndef ONLINE_JUDGE
        freopen("splay.in","r",stdin);
        freopen("splay.out","w",stdout);
    #endif
}
int fa[N],size[N],key[N],cnt[N],ch[N][2],sz,root;
inline void update(int x){size[x]=cnt[x]+size[ch[x][0]]+size[ch[x][1]];}
inline int get(int x){return x==ch[fa[x]][1];}
inline void clear(int x){ch[x][0]=ch[x][1]=fa[x]=size[x]=cnt[x]=key[x]=0;}
inline void rotate(int x,int &k)
{
    static int old,oldfa,o;
    old=fa[x];oldfa=fa[old];o=get(x);
    if(old==k)k=x;
    else ch[oldfa][get(old)]=x;
    fa[x]=oldfa;
    ch[old][o]=ch[x][o^1];fa[ch[x][o^1]]=old;
    ch[x][o^1]=old;fa[old]=x;
    update(x),update(old);
}
inline void splay(int x,int &k)
{
    while(x!=k)
    {
        if(fa[x]!=k)rotate(get(x)^get(fa[x])?x:fa[x],k);
        rotate(x,k);
    }
}
inline void insert(int x)
{
    //puts("");
    if(!root){root=++sz;size[sz]=cnt[sz]=1;key[sz]=x;return;}
    int now=root,o;
    while(1)
    {
        if(x==key[now])
        {
            ++cnt[now];
            splay(now,root);
            update(now);
            return;
        }
        o=x>key[now]?1:0;
        if(!ch[now][o])
        {
            ch[now][o]=++sz;
            size[sz]=cnt[sz]=1;
            key[sz]=x;fa[sz]=now;
            ch[now][o]=sz;
            update(now);
            splay(sz,root);
            return;
        }
        else now=ch[now][o];
        //printf("%d %d %d %d\n",now,fa[now],ch[now][0],ch[now][1]);
    }
}
inline int find_pos(int x)
{
    int now=root;
    while(1)
    {
        if(x==key[now]){return now;}
        if(x<key[now])now=ch[now][0];
        else now=ch[now][1];
    }
}
inline int pre()
{
    int now=ch[root][0];
    while(ch[now][1])now=ch[now][1];
    return now;
}
inline int nex()
{
    int now=ch[root][1];
    while(ch[now][0])now=ch[now][0];
    return now;
}
void del(int x)
{
    splay(find_pos(x),root);
    if(cnt[root]>1){--cnt[root];return;}
    if(!ch[root][0]&&!ch[root][1]){clear(root);root=0;return;}
    if(ch[root][0]&&ch[root][1])
    {
        int oldroot=root;
        splay(pre(),root);
        fa[ch[oldroot][1]]=root;
        ch[root][1]=ch[oldroot][1];
        clear(oldroot);
        update(root);
    }
    else
    {
        int o=ch[root][1]>0;
        root=ch[root][o];
        clear(fa[root]);
        fa[root]=0;
    }
}
inline int find_order_of_key(int x)
{
    int res=0,now=root;
    while(1)
    {
        if(x<key[now])now=ch[now][0];
        else
        {
            res+=size[ch[now][0]];
            if(x==key[now]){splay(now,root);return res+1;}
            res+=cnt[now];
            now=ch[now][1];
        }
    }
}
inline int find_by_order(int x)
{
    int now=root;
    while(1)
    {
        if(x<=size[ch[now][0]])now=ch[now][0];
        else
        {
            int temp=size[ch[now][0]]+cnt[now];
            if(x<=temp)return key[now];
            else{x-=temp;now=ch[now][1];}
        }
    }
}
void input()
{
    int T=read<int>();
    int opt,x;
    while(T--)
    {
        opt=read<int>();x=read<int>();
        if(opt==1)insert(x);
        else if(opt==2)del(x);
        else if(opt==3)printf("%d\n",find_order_of_key(x));
        else if(opt==4)printf("%d\n",find_by_order(x));
        else if(opt==5)
        {
            insert(x);
            printf("%d\n",key[pre()]);
            del(x);
        }
        else if(opt==6)
        {
            insert(x);
            printf("%d\n",key[nex()]);
            del(x);
        }
    }
}
int main()
{
    file();
    input();
    return 0;
}

洛谷P3391 splay模板

#include<bits/stdc++.h>
using namespace std;
typedef int sign;
typedef long long ll;
#define For(i,a,b) for(register sign i=(sign)a;i<=(sign)b;++i)
#define Fordown(i,a,b) for(register sign i=(sign)a;i>=(sign)b;--i)
const int N=1e5+5;
bool cmax(sign &a,sign b){return (a<b)?a=b,1:0;}
bool cmin(sign &a,sign b){return (a>b)?a=b,1:0;}
template<typename T>T read()
{
    T ans=0,f=1;
    char ch=getchar();
    while(!isdigit(ch)&&ch!=‘-‘)ch=getchar();
    if(ch==‘-‘)f=-1,ch=getchar();
    while(isdigit(ch))ans=(ans<<3)+(ans<<1)+(ch-‘0‘),ch=getchar();
    return ans*f;
}
void file()
{
    #ifndef ONLINE_JUDGE
        freopen("splay.in","r",stdin);
        freopen("splay.out","w",stdout);
    #endif
}
int ch[N][2],fa[N],size[N],rev[N],root,sz;
inline int get(int x){return x==ch[fa[x]][1];}
inline void update(int x){size[x]=1+size[ch[x][0]]+size[ch[x][1]];}
inline void rotate(int x,int &k)
{
    int old=fa[x],oldfa=fa[old],o=get(x);
    if(k==old)k=x;
    else ch[oldfa][ch[oldfa][1]==old]=x;
    fa[x]=oldfa;fa[old]=x;fa[ch[x][o^1]]=old;
    ch[old][o]=ch[x][o^1];ch[x][o^1]=old;
    update(x),update(old);
}
inline void splay(int x,int &k)
{
    while(x!=k)
    {
        if(fa[x]!=k)rotate(get(x)^get(fa[x])?x:fa[x],k);
        //printf("%d %d\n",x,k);
        rotate(x,k);
    }
}
#define mid ((l+r)>>1)
inline void build(int l,int r,int pre)
{
    if(l>r)return;
    ch[pre][mid>=pre]=mid;  
    fa[mid]=pre;size[mid]=1;
    if(l==r)return;
    build(l,mid-1,mid);build(mid+1,r,mid);
    update(mid);
}
#undef mid  
int n,m;
void input(){n=read<int>();m=read<int>();}
inline void rever(int x)
{
    swap(ch[x][0],ch[x][1]);
    rev[ch[x][0]]^=1;rev[ch[x][1]]^=1;
    rev[x]=0;
}
int find(int x)
{
    int now=root;
    while(1)
    {
        if(rev[now])rever(now);
        if(size[ch[now][0]]>=x)now=ch[now][0];
        else 
        {
            if(size[ch[now][0]]==x-1)return now;
            x=x-size[ch[now][0]]-1;
            now=ch[now][1];
        }
    }
}
void work()
{
    int l,r;
    root=(n+3)>>1;
    build(1,n+2,root);
    fa[root]=0;
    while(m--)
    {
        l=read<int>();r=read<int>();
        l=find(l);r=find(r+2);
        splay(l,root);splay(r,ch[l][1]);
        rev[ch[r][0]]^=1;
    }
}
void out(int x)
{
    if(rev[x])rever(x);
    if(ch[x][0])out(ch[x][0]);
    if(x>1&&x<n+2)printf("%d ",x-1);
    if(ch[x][1])out(ch[x][1]);
}
int main()
{
    file();
    input();
    work();
    out(root);
    return 0;
}

另外推薦一篇寫得好的博客

淺談splay(雙旋)