[模板] 普通平衡树

传送门

您需要一种数据结构:

  • 插入一个数\(x\)
  • 删除一个数\(x\)
  • 查询\(x\)这个数在所有数中的排名
  • 查询排名为\(x\)的数
  • \(x\)这个数的前驱(前驱定义为小于\(x\)的最大数)‘
  • \(x\)这个数的后继(后继定义为大于\(x\)的最小数)

平衡树

Splay

#include <cstdio>
#include <cstring>
#include <algorithm>
#define MAXN 1000005
#define lson ch[rt][0]
#define rson ch[rt][1]
#define INF 2147483647

struct IO {
    char ibuf[1 << 25],*s;
    IO() {
        fread(s = ibuf,1,1<<25,stdin);
    }
    inline int read() {
        int num = 0,f = 1;
        while(*s<'0'||*s>'9') if((*s++)=='-') f = -1;
        while( *s>='0' && *s<='9' ) num = num*10 + (*s++) - 48;
        return num*f;
    }
}ip;
#define read ip.read

int size[MAXN],fa[MAXN],ch[MAXN][2],count[MAXN],val[MAXN];
int root,tot;
int N,opt,x;

inline void maintain(int rt) {
    size[rt] = size[lson] + size[rson] + count[rt];
}

inline void rotate(int rt) {
    int y = fa[rt]; int z = fa[fa[rt]];
    bool pos = (ch[y][1]==rt); int son = ch[rt][pos^1];
    fa[rt] = z; if(z) ch[z][ch[z][1]==y] = rt;
    ch[y][pos] = son; if(son) fa[son] = y;
    fa[y] = rt; ch[rt][pos^1] = y;
    maintain(y); maintain(rt);
}

inline void Splay(int rt,int goal) {
    for(int y=fa[rt],z=fa[fa[rt]];y!=goal;y=fa[rt],z=fa[fa[rt]]) {
        if(z!=goal) ((ch[y][1]==rt)==(ch[z][1]==y)) ? rotate(y) : rotate(rt);
        rotate(rt);
    }
    root = rt;
}

inline void insert(int num) {
    int rt = root,y = 0;
    while(rt&&val[rt]!=num) y = rt, rt = ch[rt][num>val[rt]];
    if(rt) count[rt] ++;
    else {
        val[rt = ++tot] = num; size[rt] = count[rt] = 1;
        ch[rt][0] = ch[rt][1] = 0; fa[rt] = y;
        if(y) ch[y][num>val[y]] = rt;
    }
    Splay(rt,0);
}

inline void find(int num) {
    int rt = root; if(!rt) return;
    while(val[rt]!=num&&ch[rt][num>val[rt]]) rt = ch[rt][num>val[rt]];
    Splay(rt,0);
}

inline int Next(int num,bool flag) {
    find(num); if((val[root]>num&&flag)||(val[root]<num&&!flag)) return root;
    int rt = ch[root][flag]; while(ch[rt][flag^1]) rt = ch[rt][flag^1];
    return rt;
}

inline int Kth(int k) {
    int rt = root;
    while(true) {
        if(k<=size[lson]) rt = lson;
        else if(k>size[lson]+count[rt]) k -= size[lson]+count[rt], rt = rson;
        else return rt;
    }
}

inline void Del(int num) {
    int _last = Next(num,0); int _next = Next(num,1);
    Splay(_last,0); Splay(_next,_last);
    int del = ch[_next][0];
    if(count[del]>1) count[del] --,Splay(del,0);
    else ch[_next][0] = 0,Splay(_next,0);
}

int main() {

    N = read(); insert(INF); insert(-INF);
    for(int i=1;i<=N;++i) {
        opt = read(); x = read();
        if(opt==1) insert(x);
        else if(opt==2) Del(x);
        else if(opt==3) {find(x);printf("%d\n",size[ch[root][0]]);}//需要把-INF扣掉
        else if(opt==4) printf("%d\n",val[Kth(x+1)]);//-INF
        else if(opt==5) printf("%d\n",val[Next(x,0)]);
        else printf("%d\n",val[Next(x,1)]);
    }

    return 0;
}

线段树

  • 不会
  • 先离散化一下,因为数字太大了
  • 线段树维护区间内有多少个数
  • 离散化的新编号就是线段树对应区间的位置

插入ins()

  • 计算\(x\)离散化后的对应编号\(num\)
  • 单点修改区间\([num,num]\),+1即可

删除del()

  • 计算\(x\)离散化后的对应编号\(num\)
  • 单点修改区间\([num,num]\),-1即可

查询排名clac_rk()

  • 计算\(x\)离散化后的对应编号\(num\)
  • 查询区间\([1,num]\)的和,+1即是答案,因为\(x\)前面有若干个数,若干个数+1就\(x\)的排名

查询对应排名的数find_rk()

  • 二分查找,判断\([1,mid]\)的和与给定排名的大小关系。

查前驱

  • find_rk(calc_rk(x)-1)

查后继

  • find_rk(calc_rk(x)+N+1),其中\(N\)\(x\)的数量

代码

#include <cstdio>
#include <cstring>
#include <algorithm>
#define MAXN 100005
#define lson (rt<<1)
#define rson (rt<<1|1)

int seg[MAXN<<2];
inline void pushup(int rt) {
    seg[rt] = seg[lson] + seg[rson];
}

void update(int C,int L,int rt,int l,int r) {
    if(l==r&&l==L) {
        seg[rt] += C;
        return;
    }
    int mid = (l+r)>>1;
    if(L<=mid) update(C,L,lson,l,mid);
    else update(C,L,rson,mid+1,r);
    pushup(rt);
}

int query(int L,int R,int rt,int l,int r) {
    if(L>R) return 0;
    if(L<=l&&R>=r) return seg[rt];
    if(L>r||R<l) return 0;
    int mid = (l+r)>>1,ans = 0;
    if(L<=mid) ans += query(L,R,lson,l,mid);
    if(R>mid) ans += query(L,R,rson,mid+1,r);
    return ans;
}

struct Discretization {
    int temp[MAXN],num[MAXN],tot,cnt;
    void reset() {
        tot = 0;
        cnt = 0;
    }
    int get_num(int u) {
        int l = 1, r = tot;
        while(l<r) {
            int mid = (l+r+1)>>1;
            if(num[mid]>u) r = mid - 1;
            else l = mid;
        }
        return l;
    }
    int find(int rk) {
        int l = 1,r = tot;
        while(l<r) {
            int mid = (l+r)>>1;
            int x = query(1,mid,1,1,tot);
            if(x>=rk) r = mid;
            else l = mid + 1;
        }
        return num[l];
    }
    void add(int u) {
        temp[++cnt] = u;
    }
    void unique() {
        for(int i=1;i<=cnt;++i) {
            if(temp[i]!=temp[i+1]) num[++tot] = temp[i];
        }
    }
}D;

struct Opt {
    int opt,num;
}G[MAXN];

int N;

int main() {

    D.reset();

    scanf("%d",&N);
    for(int i=1;i<=N;++i) {
        scanf("%d%d",&G[i].opt,&G[i].num);
        if(G[i].opt!=4) D.add(G[i].num);
    }

    std::sort(D.temp+1,D.temp+1+D.cnt);
    D.unique();

    std::memset(seg,0,sizeof(seg));
    for(int i=1;i<=N;++i) {
        if(G[i].opt==1) update(1,D.get_num(G[i].num),1,1,D.tot);
        else if(G[i].opt==2) update(-1,D.get_num(G[i].num),1,1,D.tot);
        else if(G[i].opt==3) printf("%d\n",query(1,D.get_num(G[i].num)-1,1,1,D.tot) + 1);
        else if(G[i].opt==4) printf("%d\n",D.find(G[i].num));
        else if(G[i].opt==5) {
            int rk = query(1,D.get_num(G[i].num)-1,1,1,D.tot);
            printf("%d\n",D.find(rk));
        }
        else if(G[i].opt==6) {
            int x = D.get_num(G[i].num);
            int rk = query(1,x-1,1,1,D.tot) + query(x,x,1,1,D.tot) + 1;
            printf("%d\n",D.find(rk));
        }
    }

    return 0;
}

树状数组

代码思路和线段树差不多,除了给定排名查数有个骚操作,先贴代码。

#include <cstdio>
#include <cstring>
#include <algorithm>
#define MAXN 100005
#define lowbit(x) (x&(-x))

struct disc {
    int temp[MAXN],a[MAXN],cnt,tot;
    disc() : cnt(0),tot(0) {}
    void add(int x) {
        temp[++tot] = x;
    }
    void unique() {
        std::sort(temp+1,temp+1+tot);
        for(int i=1;i<=tot;++i) 
            if(temp[i]!=temp[i-1]) a[++cnt] = temp[i];
    }
    int number(int x) {
        int l = 1,r = cnt;
        while(l<r) {
            int mid = (l+r+1)>>1;
            if(a[mid]>x) r = mid - 1;
            else l = mid;
        }
        return l;
    }
}d;

struct q {
    int opt,x;
}Q[MAXN];

int C[MAXN];

inline int query(int x) {
    int ans = 0;
    for(;x>=1;x-=lowbit(x)) ans += C[x];
    return ans;
}

inline void update(int x,int u) {
    for(;x<=d.cnt;x+=lowbit(x)) C[x] += u;
}

inline int work(int rk) {
    int ans = 0,count = 0;
    for(int i=1<<17;i>=1;i>>=1) {
        ans += i;
        if(ans>d.cnt||C[ans]+count>=rk) ans -= i;
        else count += C[ans];
    }
    return ans + 1;
}

int N;

int main() {

    scanf("%d",&N);
    for(int i=1;i<=N;++i) {
        scanf("%d%d",&Q[i].opt,&Q[i].x);
        if(Q[i].opt!=4) d.add(Q[i].x);
    }
    d.unique();

    for(int i=1;i<=N;++i) {
        if(Q[i].opt==1) update(d.number(Q[i].x),1);
        else if(Q[i].opt==2) update(d.number(Q[i].x),-1);
        else if(Q[i].opt==3) printf("%d\n",query(d.number(Q[i].x)-1) + 1);
        else if(Q[i].opt==4) printf("%d\n",d.a[work(Q[i].x)]);
        else if(Q[i].opt==5) printf("%d\n",d.a[work(query(d.number(Q[i].x)-1))]);
        else printf("%d\n",d.a[work(query(d.number(Q[i].x))+1)]);
    }

    return 0;
}
posted @ 2018-11-28 19:39  Neworld1111  阅读(318)  评论(0编辑  收藏  举报