[模板] 普通平衡树

https://www.luogu.org/problemnew/show/P3369

#include<cstdio>

// Array capacity for node ids; presumably sized for the problem's 1e5
// operations plus slack (each operation allocates at most one node).
const int N = 100010;

// Per-node storage indexed by node id (1..tn). Id 0 plays the role of "no node".
int fa[N];      // parent id (0 for the root)
int ch[N][2];   // ch[x][0] = left child, ch[x][1] = right child
int siz[N];     // number of stored values (with multiplicity) in x's subtree
int cnt[N];     // multiplicity of data[x] at this node
int data[N];    // key stored at this node

int tn;         // number of node ids allocated so far
int root;       // id of the current tree root (0 = empty tree)

#define gc getchar()

// Reads the next decimal integer from stdin, honouring a leading '-'.
// Non-digit characters before the number are skipped; a '-' counts as the
// sign only when it immediately precedes the first digit. Returns 0 if EOF
// is reached before any digit (instead of spinning forever).
inline int read(){
    int x = 0;
    bool neg = false;
    int c = gc;                    // int, not char, so EOF stays distinguishable
    while (c != EOF && (c < '0' || c > '9')) {
        neg = (c == '-');          // only a '-' directly before a digit sticks
        c = gc;
    }
    while (c >= '0' && c <= '9') x = x * 10 + (c - '0'), c = gc;
    return neg ? -x : x;
}

// Which side of its parent x hangs on: 1 for the right child, 0 otherwise.
inline int son(int x) {
    const int parent = fa[x];
    return ch[parent][1] == x ? 1 : 0;
}

// Recomputes siz[x] as the copies held at x plus the sizes of both child
// subtrees. A missing child is id 0, whose siz is expected to stay 0.
inline void pushup(int x) {
    const int left = ch[x][0];
    const int right = ch[x][1];
    siz[x] = cnt[x] + siz[left] + siz[right];
}

void rotate(int x) {
    int y = fa[x], z = fa[y], b = son(x), c = son(y), a = ch[x][!b];
    if(z) ch[z][c] = x; else root = x; fa[x] = z;
    if(a) fa[a] = y; ch[y][b] = a; ch[x][!b] = y; fa[y] = x;
    pushup(y); pushup(x);
}

// Rotates x upward until its parent is rt (rt == 0: make x the tree root).
// If the grandparent is rt, one rotation of x finishes the job (zig).
// Otherwise, when x and its parent lean the same way, the parent is rotated
// first (zig-zig); when they lean opposite ways, x is rotated twice (zig-zag).
inline void splay(int x,int rt) {
    for (int y = fa[x]; y != rt; y = fa[x]) {
        const int z = fa[y];
        if (z != rt) {
            rotate(son(x) == son(y) ? y : x);
        }
        rotate(x);
    }
}

// Returns the node with the largest key in the root's left subtree, i.e. the
// in-order predecessor of the current root. It yields 0 when the root has no
// left child, assuming the null slot ch[0] stays zeroed. The argument is not
// read: callers must already have splayed the node for x to the root (main
// does this by calling Insert(x) first).
inline int getpre(int x) {
    (void)x;
    int p = ch[root][0];
    for (; ch[p][1]; p = ch[p][1]) {
    }
    return p;
}

// Returns the node with the smallest key in the root's right subtree, i.e.
// the in-order successor of the current root. It yields 0 when the root has
// no right child, assuming the null slot ch[0] stays zeroed. The argument is
// not read: callers must already have splayed the node for x to the root
// (main does this by calling Insert(x) first).
inline int getsuc(int x) {
    (void)x;
    int p = ch[root][1];
    for (; ch[p][0]; p = ch[p][0]) {
    }
    return p;
}

int getk(int rt, int k) {
    if(data[rt] == k) {
        splay(rt, 0);
        return siz[ch[rt][0]] + 1;
    }
    if(k < data[rt]) return getk(ch[rt][0], k);
    else return getk(ch[rt][1], k);
}

int getkth(int rt, int k) {
    int l = ch[rt][0];
    if(siz[l] < k && siz[l] + cnt[rt] >= k) return data[rt];
    else if(siz[l] >= k) return getkth(ch[rt][0], k);
    else return getkth(ch[rt][1], k - siz[l] - cnt[rt]);
}

inline void Insert(int x) { // 插入
    if (root==0) { 
        ++tn; root = tn; ch[tn][1] = ch[tn][0] = fa[tn] = 0; siz[tn] = cnt[tn] = 1; data[tn] = x;
        return;
    }
    int p = root,pa = 0;
    while (true) {
        if (x==data[p]) { cnt[p]++; pushup(p); pushup(pa); splay(p,0); break;}
        pa = p;
        p = ch[p][x > data[p]];
        if (p==0) { 
            tn++; ch[tn][1] = ch[tn][0] = 0; siz[tn] = cnt[tn] = 1; fa[tn] = pa;
            ch[pa][x > data[pa]] = tn; data[tn] = x; pushup(pa),splay(tn,0);
            break;
        }
    }
}

// Zeroes every field of node x so that no stale links or counts remain.
// The id itself is not returned to any free list; tn is left untouched.
inline void Clear(int x) {
    data[x] = 0;
    cnt[x] = 0;
    siz[x] = 0;
    fa[x] = 0;
    ch[x][1] = 0;
    ch[x][0] = 0;
}

inline void Delete(int x) { // 删除
    getk(root, x);
    if (cnt[root] > 1) { cnt[root]--; siz[root] --; return; }
    if (!ch[root][0] && !ch[root][1]) { Clear(root); root = 0; return; }
    if (!ch[root][0]) { int tmp = root; root = ch[root][1]; fa[root] = 0; Clear(tmp); return; } 
    else if (!ch[root][1]) { int tmp = root; root = ch[root][0]; fa[root] = 0; Clear(tmp); return; }
    int tmp = root,pre = ch[root][0];
    while (ch[pre][1]) pre = ch[pre][1];
    splay(pre,0);
    ch[root][1] = ch[tmp][1];
    fa[ch[tmp][1]] = root;
    Clear(tmp);
    pushup(root);
}

// Reads n operations "opt x" and applies them to the splay tree:
//   1 insert x, 2 delete one x, 3 print rank of x, 4 print k-th smallest,
//   5 print predecessor of x, 6 (any other opt) print successor of x.
// Predecessor/successor queries temporarily insert x so it becomes the root,
// read the neighbouring node, then remove x again.
int main() {
    int n = read();
    while (n--) {
        const int opt = read();
        const int x = read();
        switch (opt) {
            case 1:
                Insert(x);
                break;
            case 2:
                Delete(x);
                break;
            case 3:
                printf("%d\n", getk(root, x));
                break;
            case 4:
                printf("%d\n", getkth(root, x));
                break;
            case 5:
                Insert(x);
                printf("%d\n", data[getpre(x)]);
                Delete(x);
                break;
            default:
                Insert(x);
                printf("%d\n", data[getsuc(x)]);
                Delete(x);
                break;
        }
    }
    return 0;
}

 

posted @ 2017-12-05 19:59  xayata  阅读(174)  评论(0)  编辑  收藏  举报