抓狂

bzoj 3224 

(splay模板 指针版)

这个模板改了两天TAT

说结果吧:错误出在 kth 中。但在找到它之前,我把除了 zig、splay、kth 以外的部分重写了两遍,还一句一句地加了注释……

结论是:珍爱生命,不要想当然。。。

#include<climits>
#include<cstdio>
#include<cstdlib>
#include<iostream>
using namespace std;
typedef long long int ll;        // 64-bit integer alias (only referenced by commented-out code)
const int maxn = 1000000 + 100;  // capacity of the global data[] array
// Read one (optionally signed) decimal integer from `in` (stdin by default).
// Every character that cannot start a number -- spaces, tabs, newlines, '\r',
// etc. -- is skipped first, so the reader never desynchronizes on separators.
// The character right after the number is consumed. Returns 0 if the stream
// ends before any digit is found.
int read(FILE *in = stdin){
    int cc = fgetc(in);                       // keep as int so EOF stays distinguishable
    while (cc != EOF && cc != '+' && cc != '-' && (cc < '0' || cc > '9'))
        cc = fgetc(in);
    int fh = 1;
    if (cc == '+' || cc == '-'){
        if (cc == '-') fh = -1;
        cc = fgetc(in);
    }
    int re = 0;
    while ('0' <= cc && cc <= '9'){
        re = re * 10 + (cc - '0');
        cc = fgetc(in);
    }
    return re * fh;
}
int n;            // number of operations (set in main); also Splay::build()'s default upper bound
int m;            // not referenced anywhere in this file
int data[maxn];   // input array read by Splay::build(); not filled anywhere in this file

int ans;          // not referenced anywhere in this file

// Splay-tree node. One shared sentinel node, the global `null` declared right
// after this struct, stands in for every missing child or parent, so code can
// read e.g. `x->ch[0]->size` without nullptr checks.
struct Node{
    Node();
    Node *ch[2],*fa;    // ch[0]/ch[1]: left/right child; fa: parent (`null` when absent)
    int d,sum;int size; // d: key; sum: copies of d stored in this node; size: copies in the whole subtree (see count())
//    int set,add[2];   // (disabled) fields used only by the commented-out mark()
//    short vset;
    short pl(){return this == fa->ch[1];}  // 1 if this node is its parent's right child, else 0
    void count(); void push();   // count(): recompute size; push(): declared but never defined in this file
//    void mark(ll,ll,short);
}*null;
// Default constructor: point every link at the global sentinel `null` and zero
// the key and counters. While `null` itself is being allocated it is still a
// nullptr, so that first node's links are nullptr until it is re-assigned.
Node::Node() : fa(null), d(0), sum(0), size(0) {
    ch[0] = null;
    ch[1] = null;
}
/*
void Node::mark(ll val,ll dd,short t){
    if (this==null) return;
    if (!t){            //set
        set = val;
        sum = set*size;
        d = set;
        vset = 1;
        add[0] = add[1] = 0;
    } else{                //add
        add[0]+=val;
        add[1]+=dd;
        sum += val * size;
        sum += dd*(size)*(size-1)/2;
        d += val + dd*(ch[0]->size);
    }
}*/
// Recompute this node's subtree total: the copies stored here plus the
// totals of both children (the sentinel contributes 0).
void Node::count(){
    int total = sum;
    total += ch[0]->size;
    total += ch[1]->size;
    size = total;
}
namespace Splay{
    Node *ROOT;
    Node *build(int l=1,int r=n){
        if (l>r) return null;
        int mid = (l+r)/2;
        Node *ro = new Node;
        ro->d = data[mid];
        ro->ch[0]=build(l,mid-1);
        ro->ch[1]=build(mid+1,r);
        ro->ch[0]->fa=ro;
        ro->ch[1]->fa=ro;
        ro->count();
        return ro;
    }
    void Build(){
        null = new Node;
        *null = Node();
        ROOT = new Node;
        ROOT = null;
    }
    void rotate(Node *k){
        Node *r=k->fa; if (r==null||k==null)return;
//        r->push();k->push();
        int x = k->pl()^1;                            //x为另一子节点 
        r->ch[x^1]=k->ch[x];
        r->ch[x^1]->fa=r;
        if (r->fa!=null) r->fa->ch[r->pl()] = k;
        else ROOT=k;
        k->fa = r->fa;r->fa = k;
        k->ch[x] = r;
        r->count();k->count();
    }
    void splay(Node *r,Node *tar=null){
        for(;r->fa!=tar;rotate(r))
        if(r->fa->fa!=tar)rotate(r->pl()==r->fa->pl()?r->fa:r);
//        r->push();
    }
    Node *find(int x){
        Node *r=ROOT;
        while (r!=null){
            if (r->d==x) break;
            int c;    if (r->d>x) c=0; else c=1;
            r=r->ch[c];
        }
        if (r!=null) splay(r);
        return r;
        
        /*while (r!=null){
            if (x<r->d) r=r->ch[0];
            else if (x>r->d) r=r->ch[1];
            else return r;
        }
        return null;*/
    }
    
    void insert(int x){
        Node *r = ROOT;
        if (ROOT == null){
            ROOT = new Node ;
            ROOT->d = x ;
            ROOT->sum=1;
            ROOT->size=1;
            return;
        }
        while (1){
            int c;
            if (r->d==x) {r->sum++;r->size++;splay(r);break;}
            if (r->d>x) c=0; else c=1;
            if (r->ch[c]==null){
                r->ch[c]=new Node;
                r->ch[c]->d=x;
                r->ch[c]->fa=r;
                r->ch[c]->sum=1;
                r->ch[c]->size=1;
                splay(r->ch[c]);
                break;
            }else r=r->ch[c];
        }
        return;
    }
    
    Node *kth(int k){
        Node *r = ROOT;
        while (r!=null){
//            r->push();
            if (r->ch[0]->size>=k) r=r->ch[0];
            else if(r->ch[0]->size+r->sum>=k) return r;
            else k-=r->ch[0]->size+r->sum,r=r->ch[1];
        }
        return null;
    }
    Node *pack(int l,int r){
        Node *ln = kth(l-1),
        *rn = kth(r+1);
        if (ln==null&&rn==null) return ROOT;
        else if (ln==null){
            splay(rn);
            return rn->ch[0];
        }else if (rn==null){
            splay(ln);
            return ln->ch[1];
        }else{
            splay(ln);
            splay(rn,ROOT);
            return rn->ch[0];
        }
    }
    Node *rightdown(Node *r){
        while (r->ch[1]!=null){
            r=r->ch[1];
        }return r;
    }
    Node *leftdown(Node *r){
        while (r->ch[0]!=null){
            r=r->ch[0];
        }return r;
    }
    
    void earse(int x)
    {
        /*Node *r=find(x);
        splay(r);
        if (r->sum>1){
            splay(r);
            r->size--; r->sum--; return;
        }
        else{
            splay(r);
            if ((r->ch[0]==null)&&(r->ch[1]==null)){
                ROOT=null;
                delete r;
            }else if (r->ch[0]==null){
                r->ch[1]->fa=null;
                ROOT=r->ch[1];
                delete r;
            }else if (r->ch[1]==null){
                r->ch[0]->fa=null;
                ROOT=r->ch[0];
                delete r;
            }else{
                splay(rightdown(r->ch[0]),ROOT);
                r->ch[0]->ch[1]=r->ch[1];
                r->ch[1]->fa=r->ch[0];
                r->ch[0]->fa=null;
                r->ch[0]->count();
                ROOT=r->ch[0];
                delete r;
            }
        }*/
        
        Node *sr;
        splay(find(x));
        Node *r=ROOT;
        if (r==null) return;
        if (!(--r->sum)){
            if (r->ch[1]!=null){
                sr=r->ch[1];
                while (sr->ch[0]!=null) sr=sr->ch[0];
                splay(sr);
                sr->ch[0]=r->ch[0];
                r->ch[0]->fa=sr;
                sr->fa=null;
            }
            else {
                sr=r->ch[0];
                sr->fa=null;
                sr->count();
                ROOT=sr;
            }
        }
        else --r->size;
        return;
    }
    void query(int x)
    {
        Node *r=find(x);
        splay(r);
        printf("%d\n",r->ch[0]->size+1);
        return;
    }
    
    int prev(int x)
    {
        Node *r=ROOT;
        int ans=-0x7fffff;
        while (r!=null)
        {
            if (r->d>=x) r=r->ch[0];
            else {ans=max(ans,r->d);r=r->ch[1];}
        }
        return ans;
    }
    int succ(int x)
    {
        Node *r=ROOT;
        int ans=0x7ffffff;
        while (r!=null)
        {
            if (r->d<=x) r=r->ch[1];
            else {ans=min(ans,r->d);r=r->ch[0];}
        }
        return ans;
    }
    
    void dfs(Node *x)
    {
        if (x->ch[0]!=null) dfs(x->ch[0]);
        printf("%d ",x->d);
        if (x->ch[1]!=null) dfs(x->ch[1]);
    }
    void pr()
    {
        Node *re=ROOT;
        dfs(re);
        printf("\n\n");
    }
}

int main()
{
//    freopen("3224.in","r",stdin);
//    freopen("3224.out","w",stdout);
    // Input: the operation count n, followed by n (opcode, value) pairs.
    n = read();
    Splay::Build();
    for (int step = 0; step < n; ++step)
    {
        const int op = read();
        if (op < 1 || op > 6)
            continue;                              // unknown opcode: no value is consumed
        const int value = read();
        if (op == 1)      Splay::insert(value);
        else if (op == 2) Splay::earse(value);
        else if (op == 3) Splay::query(value);     // query prints the rank itself
        else if (op == 4) printf("%d\n", Splay::kth(value)->d);
        else if (op == 5) printf("%d\n", Splay::prev(value));
        else              printf("%d\n", Splay::succ(value));
//        Splay::pr();                             // debug: dump keys after each step
    }
    return 0;
}

 

posted @ 2016-02-26 21:01  Reterra  阅读(193)  评论(0编辑  收藏  举报