Splay Tree模板

#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>
using namespace std;

// Shorthand macros kept for existing callers.
// RG no longer expands to `register int`: the `register` storage class was
// deprecated in C++11 and removed in C++17 (clang rejects it outright).
#define RG int
#define LL long long

template<typename elemType>
inline void Read(elemType &T){
    elemType X=0,w=0; char ch=0;
    while(!isdigit(ch)) {w|=ch=='-';ch=getchar();}
    while(isdigit(ch)) X=(X<<3)+(X<<1)+(ch^48),ch=getchar();
    T=(w?-X:X);
}

namespace Splay{
    // One node of the splay tree. Nodes live in SplayTree::T and refer to each
    // other by index; index 0 is a null sentinel whose fields are never written.
    template<typename elemType>
    struct SplayTreeNode{
        SplayTreeNode():value(),fa(0),size(0),rev(0){son[0]=son[1]=0;}
        elemType value;  // key; value-initialized so the sentinel reads as elemType()
        int son[2];      // son[0] = left child, son[1] = right child (0 = none)
        int fa;          // parent index (0 = none, i.e. this node is the root)
        int size;        // number of nodes in this subtree
        int rev;         // pending "swap my children" tag, consumed by push_down
        int &operator[](int x){return son[x];}          // T[x][0] / T[x][1]: children
        int operator[](int x) const {return son[x];}
    };

    // Splay tree keyed by elemType (only operator< is required); duplicate keys
    // are allowed. Queries return node indices into T, where 0 means "not found".
    template<typename elemType>
    struct SplayTree{
        std::vector<SplayTreeNode<elemType> > T;  // node pool, grows on insert; T[0] is the sentinel
        int root;          // index of the root node (0 = empty tree)
        int cnt;           // index of the most recently allocated node
        int node_num;      // number of keys currently stored
        int last_visited;  // deepest node touched by the latest descent (0 = none)

        SplayTree():T(1),root(0),cnt(0),node_num(0),last_visited(0) {}
        int size() const {return node_num;}
        bool empty() const {return node_num==0;}

        // Recomputes x's subtree size from its children.
        void push_up(int x){
            if(x) T[x].size=T[T[x][0]].size+T[T[x][1]].size+1;
        }
        // Applies x's pending reversal tag: swaps its children and passes the
        // tag on to them. The sentinel is never tagged.
        void push_down(int x){
            if(!x || !T[x].rev) return;
            int l=T[x][0],r=T[x][1];
            if(l) T[l].rev^=1;
            if(r) T[r].rev^=1;
            std::swap(T[x].son[0],T[x].son[1]);
            T[x].rev=0;
        }
        // Rotates x above its parent y, preserving in-order sequence and sizes.
        void rotate(int x){
            int y=T[x].fa;
            if(y==0) return;             // x is already the root
            int z=T[y].fa;
            push_down(y);push_down(x);   // children must be final before relinking
            int dir=(T[y][1]==x)?1:0;    // side of y that x hangs on
            int inner=T[x][dir^1];       // x's subtree that moves over to y
            T[x].fa=z;
            if(z) T[z][(T[z][1]==y)?1:0]=x;  // guard: never write into the sentinel
            T[y][dir]=inner;
            if(inner) T[inner].fa=y;
            T[x][dir^1]=y;
            T[y].fa=x;
            push_up(y);push_up(x);
        }
        // Splays x until it takes s's place (x becomes a child of s's parent).
        // Callers reach x by a descent that already pushed tags down its path.
        void splay(int x,int s){
            if(x==0) return;
            int sf=T[s].fa;
            while(T[x].fa!=sf){
                int y=T[x].fa,z=T[y].fa;
                if(z!=sf){
                    // zig-zig rotates the parent first; zig-zag rotates x twice
                    bool same_side=((T[y][0]==x)==(T[z][0]==y));
                    rotate(same_side?y:x);
                }
                rotate(x);
            }
            if(s==root) root=x;
        }
        // Attaches the detached node u below subtree x (equal keys go right)
        // and refreshes the sizes of every ancestor of u.
        void insert_node(int x,int u){
            if(x==0) return;  // error: no subtree to insert into
            for(;;){
                push_down(x);
                int dir=(T[u].value<T[x].value)?0:1;
                if(T[x][dir]==0){T[x][dir]=u;T[u].fa=x;break;}
                x=T[x][dir];
            }
            for(;x;x=T[x].fa) push_up(x);
        }
        // Inserts val (duplicates allowed) and splays the new node to the root.
        void insert(elemType val){
            T.push_back(SplayTreeNode<elemType>());
            cnt=static_cast<int>(T.size())-1;
            int u=cnt;
            T[u].value=val;T[u].size=1;++node_num;
            if(root==0){root=u;return;}
            insert_node(root,u);
            splay(u,root);
        }
        // Node with the largest value < val in subtree x, or 0 if none.
        int get_previous(int x,elemType val){
            last_visited=0;
            int best=0;
            while(x){
                push_down(x);
                last_visited=x;
                if(T[x].value<val){best=x;x=T[x][1];}
                else x=T[x][0];
            }
            return best;
        }
        // Node with the smallest value > val in subtree x, or 0 if none.
        int get_succeed(int x,elemType val){
            last_visited=0;
            int best=0;
            while(x){
                push_down(x);
                last_visited=x;
                if(val<T[x].value){best=x;x=T[x][0];}
                else x=T[x][1];
            }
            return best;
        }
        // Kth node (1-based, in-order) of subtree x, or 0 if Kth is out of range.
        int get_kth(int x,int Kth){
            last_visited=0;
            if(x==0||Kth<1||Kth>T[x].size) return 0;
            while(x){
                push_down(x);
                last_visited=x;
                int left=T[T[x][0]].size;
                if(Kth<=left) x=T[x][0];
                else if(Kth==left+1) return x;
                else{Kth-=left+1;x=T[x][1];}
            }
            return 0;
        }
        // 1 + number of values in subtree x that are < val.
        int get_rank(int x,elemType val){
            last_visited=0;
            int result=1;
            while(x){
                push_down(x);
                last_visited=x;
                if(T[x].value<val){result+=T[T[x][0]].size+1;x=T[x][1];}
                else x=T[x][0];
            }
            return result;
        }
        // Some node in subtree x whose value equals val, or 0 if none.
        int find(int x,elemType val){
            last_visited=0;
            while(x){
                push_down(x);
                last_visited=x;
                if(T[x].value<val) x=T[x][1];
                else if(val<T[x].value) x=T[x][0];
                else return x;
            }
            return 0;
        }
        // Prints subtree x in order, values separated by spaces.
        void traversal(int x){
            if(x==0) return;
            push_down(x);
            traversal(T[x][0]);
            std::cout<<T[x].value<<" ";
            traversal(T[x][1]);
        }
        // Removes one occurrence of val; does nothing (but still splays the
        // end of the search path) if val is absent.
        void delete_node(elemType val){
            int u=find(root,val);
            if(u==0){splay(last_visited,root);return;}
            splay(u,root);
            push_down(u);
            int x=T[u][0],y=get_kth(T[u][1],1);  // y = smallest node right of u
            if(y==0){
                if(x) T[x].fa=0;
                root=x;
            }else{
                splay(y,T[u][1]);  // y now has no left child
                T[y][0]=x;T[y].fa=0;root=y;
                if(x) T[x].fa=y;
                push_up(y);
            }
            --node_num;
        }
        // Every public query splays the end of each descent it performs, so
        // unsuccessful searches are amortized O(log n) as well.
        int rank(elemType val){
            int res=get_rank(root,val);
            splay(last_visited,root);
            find(root,val);            // leave val's node (if present) at the root
            splay(last_visited,root);
            return res;
        }
        int kth(int k){int u=get_kth(root,k);splay(u,root);return u;}
        int previous(elemType val){
            int u=get_previous(root,val);
            splay(last_visited,root);
            splay(u,root);
            return u;
        }
        int succeed(elemType val){
            int u=get_succeed(root,val);
            splay(last_visited,root);
            splay(u,root);
            return u;
        }
    };
}  // namespace Splay

Splay::SplayTree<int> Tree;
int Data[100];  // NOTE(review): not referenced by main below; kept in case other code uses it
int N;

// Driver: reads N operations "opt num" and applies them to Tree:
//   1 insert num, 2 delete one copy of num, 3 print the rank of num,
//   4 print the num-th smallest value, 5 print the predecessor of num,
//   6 print the successor of num.
// For 4-6 with no answer the tree returns node 0, so the sentinel's value is printed.
int main(){
    Read(N);
    // Plain int instead of the RG macro: `register` is ill-formed since C++17.
    for(int i=1;i<=N;++i){
        int opt=0,num=0;
        Read(opt);Read(num);
        if(opt==1) Tree.insert(num);
        else if(opt==2) Tree.delete_node(num);
        else if(opt==3) printf("%d\n",Tree.rank(num));
        else if(opt==4) printf("%d\n",Tree.T[Tree.kth(num)].value);
        else if(opt==5) printf("%d\n",Tree.T[Tree.previous(num)].value);
        else if(opt==6) printf("%d\n",Tree.T[Tree.succeed(num)].value);
    }
    return 0;
}
posted @ 2020-05-10 12:09  AE酱  阅读(136)  评论(0编辑  收藏  举报