【伸展树】[CQBZOJ2803]普通平衡树splay模板

贴代码

#include<cstdio>
#include<algorithm>
using namespace std;
#define MAXN 500000
int n,mi,ans;
struct node{
    int val,cnt,size;
    node *fa,*ch[2];
}splay_tree[MAXN+10],*tcnt=splay_tree,*root;
void Read(int &x){
    char c;
    bool f=0;
    while(c=getchar(),c!=EOF){
        if(c=='-')
            f=1;
        if(c>='0'&&c<='9'){
            x=c-'0';
            while(c=getchar(),c>='0'&&c<='9')
                x=x*10+c-'0';
            ungetc(c,stdin);
            if(f)
                x=-x;
            return;
        }
    }
}
inline int Get_size(node *p){
    return p?p->size:0;
}
inline void update(node *p){
    p->size=p->cnt+Get_size(p->ch[0])+Get_size(p->ch[1]);
}
void Rotate(node *a,int d){
    node *b=a->fa;
    b->ch[!d]=a->ch[d];
    a->fa=b->fa;
    if(a->ch[d])
        a->ch[d]->fa=b;
    if(b->fa)
        b->fa->ch[b==b->fa->ch[1]]=a;
    b->fa=a;
    a->ch[d]=b;
    update(b);
}
void splay(node *x,node *rt){
    node* y,*z;
    while(x->fa!=rt){
        y=x->fa;
        z=y->fa;
        if(z==rt){
            if(x==y->ch[0])
                Rotate(x,1);
            else
                Rotate(x,0);
        }
        else{
            if(y==z->ch[0])
                if(x==y->ch[0]){    //ZIG-ZIG
                    Rotate(y,1);
                    Rotate(x,1);
                }
                else{                    //ZAG-ZIG
                    Rotate(x,0);
                    Rotate(x,1);
                }
            else{
                if(x==y->ch[1]){  //ZAG-ZAG
                    Rotate(y,0);
                    Rotate(x,0);
                }
                else{                    //ZIG-ZAG
                    Rotate(x,1);
                    Rotate(x,0);
                }
            }
        }
    }
    update(x);
    if(!rt)
        root=x;
}
inline void clear(node *p){
    p->ch[0]=p->ch[1]=NULL;
}
void insert(int val){
    node *p=root,*fa=NULL;
    while(p){
        fa=p;
        if(p->val<val)
            p=p->ch[1];
        else if(p->val>val)
            p=p->ch[0];
        else{
            p->cnt++;
            splay(p,NULL);
            return;
        }
    }
    p=++tcnt;
    clear(p);
    p->cnt=1;
    p->fa=fa;
    if(fa){
        if(val<fa->val)
            fa->ch[0]=p;
        else
            fa->ch[1]=p;
    }
    else
        root=p;
    p->val=val;
    splay(p,NULL);
}
node *find(node *x,int d){
    while(x&&x->ch[d])
        x=x->ch[d];
    return x;
}
void del(int val){
    node *p=root;
    while(p){
        if(val<p->val)
            p=p->ch[0];
        else if(val>p->val)
            p=p->ch[1];
        else{
            if(p->cnt>1){
                p->cnt--;
                splay(p,NULL);
                return;
            }
            else{
                splay(p,NULL);
                node *ft=find(p->ch[0],1),*bk=find(p->ch[1],0);
                if(!ft&&!bk)
                    root=NULL;
                else if(!ft)
                    root=root->ch[1],root->fa=NULL;
                else if(!bk)
                    root=root->ch[0],root->fa=NULL;
                else{
                    splay(ft,NULL);
                    splay(bk,root);
                    bk->ch[0]=NULL;
                    bk->size--;
                    ft->size--;
                }
            }
            return;
        }
    }
}
int find_pos(int val){
    node *p=root;
    int ret=0;
    while(p){
        if(val<p->val)
            p=p->ch[0];
        else if(val>p->val){
            ret+=Get_size(p->ch[0])+p->cnt;
            p=p->ch[1];
        }
        else
            return ret+Get_size(p->ch[0])+1;
    }
    splay(p,NULL);
}
int pos_find(int pos){
    node *p=root;
    while(p){
        if(!pos){
            p=find(p,0);
            return p->val;
        }
        if(pos<Get_size(p->ch[0]))
            p=p->ch[0];
        else if(pos>=Get_size(p->ch[0])+p->cnt)
            pos-=Get_size(p->ch[0])+p->cnt,p=p->ch[1];
        else
            return p->val;
    }
    splay(p,NULL);
}
int find_bk(int val){
    insert(val);
    int ret=find(root->ch[1],0)->val;
    del(val);
    return ret;
}
int find_ft(int val){
    insert(val);
    int ret=find(root->ch[0],1)->val;
    del(val);
    return ret;
}
int main()
{
    Read(n);
    int a,b;
    while(n--){
        Read(a),Read(b);
        if(a==1)
            insert(b);
        else if(a==2)
            del(b);
        else if(a==3)
            printf("%d\n",find_pos(b));
        else if(a==4)
            printf("%d\n",pos_find(b-1));
        else if(a==5)
            printf("%d\n",find_ft(b));
        else
            printf("%d\n",find_bk(b));
    }

}
posted @ 2015-11-21 10:21  outer_form  阅读(156)  评论(0编辑  收藏  举报