[学习笔记]普通平衡树Splay

哈哈哈哈哈哈哈终于会打\(splay\)

现在我来发一下\(splay\)的讲解吧

小蒟蒻由于码风与他人不同,所以自己找了上百篇码风诡异的\(splay\)合成的,感谢\(zcysky\)的代码与我码风相近,让我看懂了

首先,\(splay\)其实就是把一棵二叉搜索树变成一棵深度不会超过\(logn\)的二叉搜索树,它在不断旋转至深度至\(logn\)

当然,那么就少不了\(rotate\)\(splay\)操作

具体看代码好了,说的比较麻烦

#include <bits/stdc++.h>
using namespace std;
const int maxn=100000+10;
int ch[maxn][2],fa[maxn],siz[maxn],cnt[maxn],key[maxn];
int sz,rt;

inline void clear(int x){//清空结点
    ch[x][0]=ch[x][1]=fa[x]=siz[x]=cnt[x]=key[x]=0;
}
inline bool get(int x){//判断是左儿子还是右儿子
    return ch[fa[x]][1]==x;
}
inline void update(int x){//更新结点
    if(x){
        siz[x]=cnt[x];
        if(ch[x][0]) siz[x]+=siz[ch[x][0]];
        if(ch[x][1]) siz[x]+=siz[ch[x][1]];
    }
}
inline void rotate(int x){//双旋
    int y=fa[x],z=fa[y],k=get(x);
    ch[y][k]=ch[x][k^1];fa[ch[y][k]]=y;
    ch[x][k^1]=y;fa[y]=x;fa[x]=z;
    if(z) ch[z][ch[z][1]==y]=x;
    update(y);update(x);
}
inline void splay(int x){//伸展
    for(int y;y=fa[x];rotate(x))
        if(fa[y])
            rotate((get(x)==get(y))?y:x);
    rt=x;
}
inline void insert(int val){//插入
    if(rt==0){sz++;ch[sz][0]=ch[sz][1]=fa[sz]=0;rt=sz;siz[sz]=cnt[sz]=1;key[sz]=val;return;}
    int x=rt,y=0;
    while(1){
        if(val==key[x]){cnt[x]++;update(x);update(y);splay(x);return;}
        y=x;x=ch[x][key[x]<val];
        if(x==0){
            sz++;fa[sz]=y;
            ch[y][key[y]<val]=sz;
            ch[sz][0]=ch[sz][1]=0;
            siz[sz]=cnt[sz]=1;
            key[sz]=val;
            update(y);
            splay(sz);
            return;
        }
    }
}
inline int find(int val){//找x的排名
    int x=rt,ans=0;
    while(1){
        if(val<key[x]) x=ch[x][0];
        else {
            ans+=(ch[x][0]?siz[ch[x][0]]:0);
            if(val==key[x]){
                splay(x);return ans+1;
            }
            ans+=cnt[x];
            x=ch[x][1];
        }
    }
}
inline int findkth(int val){//找排名为x的数
    int x=rt,k;
    while(1){
        if(ch[x][0]&&val<=siz[ch[x][0]]) 
            x=ch[x][0];
        else {
            k=(ch[x][0]?siz[ch[x][0]]:0)+cnt[x];
            if(val<=k) return key[x];
            val-=k;x=ch[x][1];
        }
    }
}
inline int pre(){//找前驱
    int x=ch[rt][0];
    while(ch[x][1]) x=ch[x][1];
    return x;
}
inline int nxt(){//找后继
    int x=ch[rt][1];
    while(ch[x][0]) x=ch[x][0];
    return x;
}
inline void del(int val){//删除结点
    find(val);int x=rt;
    if(cnt[rt]>1){cnt[rt]--;update(rt);return;}
    if(!ch[rt][0]&&!ch[rt][1]){clear(rt);rt=0;return;}
    if(!ch[rt][0]){rt=ch[x][1];fa[rt]=0;clear(x);return;}
    else if(!ch[rt][1]){rt=ch[x][0];fa[rt]=0;clear(x);return;}
    splay(pre());
    ch[rt][1]=ch[x][1];
    fa[ch[x][1]]=rt;
    clear(x);update(rt);
}

int main()
{
    int n,opt,x;
    scanf("%d",&n);
    while(n--){
        scanf("%d%d",&opt,&x);
        switch(opt){
            case 1:insert(x);break;
            case 2:del(x);break;
            case 3:printf("%d\n",find(x));break;
            case 4:printf("%d\n",findkth(x));break;
            case 5:insert(x);printf("%d\n",key[pre()]);del(x);break;
            case 6:insert(x);printf("%d\n",key[nxt()]);del(x);break;
        }
    }
    return 0;
}
posted @ 2018-09-28 18:48  Owen_codeisking  阅读(397)  评论(2编辑  收藏  举报