[模板] 普通平衡树
https://www.luogu.org/problemnew/show/P3369
// Splay-tree balanced BST: insert, delete one copy, rank of a value,
// k-th smallest, predecessor and successor. Equal values share one node
// and are counted in cnt[].
#include <cstdio>

// Node capacity. tn only ever increases (deleted nodes are not reused), so this
// assumes fewer than N insertions in total -- TODO confirm against input limits.
const int N = 100010;

// fa[x]: parent; ch[x][0]/ch[x][1]: left/right child; siz[x]: number of stored
// values (duplicates included) in x's subtree; cnt[x]: copies of data[x] held
// by node x. Index 0 is the null node: siz[0] and cnt[0] must stay 0.
int fa[N], ch[N][2], siz[N], cnt[N], data[N];
int tn, root;  // tn: last allocated node index; root: current root (0 = empty)

// Reads one decimal integer (an optional leading '-' is honoured) from stdin.
// Returns 0 if end of input is reached before any digit.
inline int read() {
    int x = 0;
    bool neg = false;
    int c = getchar();  // int, not char, so EOF is distinguishable
    while (c < '0' || c > '9') {
        if (c == EOF) return 0;
        neg = (c == '-');  // only a '-' directly before the digits counts
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return neg ? -x : x;
}

// 1 if x is its parent's right child, 0 otherwise (also 0 for the root).
inline int son(int x) { return x == ch[fa[x]][1]; }

// Recomputes siz[x] from its children and its own copy count.
inline void pushup(int x) { siz[x] = siz[ch[x][0]] + siz[ch[x][1]] + cnt[x]; }

// Rotates x above its parent y, preserving in-order order; updates `root`
// when y was the root.
void rotate(int x) {
    int y = fa[x], z = fa[y];
    int b = son(x), c = son(y);
    int a = ch[x][!b];  // x's inner child, re-attached under y
    if (z) ch[z][c] = x; else root = x;
    fa[x] = z;
    if (a) fa[a] = y;   // never write through node 0
    ch[y][b] = a;
    ch[x][!b] = y;
    fa[y] = x;
    pushup(y);  // y is now x's child, so recompute it before x
    pushup(x);
}

// Rotates x upward until its parent is rt (rt == 0 makes x the root).
// Every node on the x-to-rt path is recomputed by the rotations.
inline void splay(int x, int rt) {
    while (fa[x] != rt) {
        int y = fa[x], z = fa[y];
        if (z != rt) {
            // zig-zig: rotate the parent first; zig-zag: rotate x twice
            rotate(son(x) == son(y) ? y : x);
        }
        rotate(x);
    }
}

// Largest node in the root's left subtree. `x` is unused: the caller must have
// just splayed the query value to the root (main does this via Insert).
// Returns 0 if the root has no left subtree.
inline int getpre(int x) {
    int p = ch[root][0];
    while (ch[p][1]) p = ch[p][1];
    return p;
}

// Smallest node in the root's right subtree; same precondition as getpre.
// Returns 0 if the root has no right subtree.
inline int getsuc(int x) {
    int p = ch[root][1];
    while (ch[p][0]) p = ch[p][0];
    return p;
}

// Rank of k within the subtree rooted at rt: 1 + number of stored values < k.
// If k is present its node is splayed to the root; otherwise the last node
// visited is splayed (so the walk stays amortized) and the rank k would have
// is returned. Callers pass rt == root.
int getk(int rt, int k) {
    int p = rt, last = 0, less = 0;
    while (p) {
        last = p;
        if (k == data[p]) {
            splay(p, 0);
            return siz[ch[p][0]] + 1;
        }
        if (k < data[p]) {
            p = ch[p][0];
        } else {
            less += siz[ch[p][0]] + cnt[p];
            p = ch[p][1];
        }
    }
    if (last) splay(last, 0);
    return less + 1;
}

// k-th smallest stored value (1-based, duplicates counted) in rt's subtree.
// The found node is splayed to the root. Returns 0 if k is out of range.
int getkth(int rt, int k) {
    int p = rt;
    while (p) {
        int l = ch[p][0];
        if (k <= siz[l]) {
            p = l;
        } else if (k <= siz[l] + cnt[p]) {
            splay(p, 0);
            return data[p];
        } else {
            k -= siz[l] + cnt[p];
            p = ch[p][1];
        }
    }
    return 0;
}

// Allocates a fresh leaf holding one copy of x, attached to parent pa
// (the caller links it into ch[pa][]).
inline int newNode(int x, int pa) {
    ++tn;
    ch[tn][0] = ch[tn][1] = 0;
    fa[tn] = pa;
    siz[tn] = cnt[tn] = 1;
    data[tn] = x;
    return tn;
}

// Inserts one copy of x and splays its node to the root.
inline void Insert(int x) {
    if (!root) {
        root = newNode(x, 0);
        return;
    }
    int p = root;
    while (true) {
        if (x == data[p]) {
            cnt[p]++;
            pushup(p);    // splay() recomputes every ancestor on the path
            splay(p, 0);
            return;
        }
        int dir = x > data[p];
        if (!ch[p][dir]) {
            int q = newNode(x, p);
            ch[p][dir] = q;
            splay(q, 0);  // rotations refresh siz[] of p and its ancestors
            return;
        }
        p = ch[p][dir];
    }
}

// Resets every field of node x to zero.
inline void Clear(int x) {
    ch[x][0] = ch[x][1] = fa[x] = siz[x] = cnt[x] = data[x] = 0;
}

// Removes one copy of x; does nothing if x is not stored.
inline void Delete(int x) {
    if (!root) return;
    getk(root, x);               // brings x (or a neighbour if absent) to the root
    if (data[root] != x) return; // x not present: do not delete a different value
    if (cnt[root] > 1) {
        cnt[root]--;
        siz[root]--;
        return;
    }
    int old = root;
    if (!ch[old][0] || !ch[old][1]) {
        // At most one child: it becomes the new root (0 if old was the only node).
        root = ch[old][0] ? ch[old][0] : ch[old][1];
        fa[root] = 0;
        Clear(old);
        return;
    }
    // Two children: splay old's predecessor to the root. old then becomes its
    // right child with an empty left subtree, so old's right subtree can be
    // re-hung directly under the new root.
    int pre = ch[old][0];
    while (ch[pre][1]) pre = ch[pre][1];
    splay(pre, 0);
    ch[root][1] = ch[old][1];
    fa[ch[old][1]] = root;
    Clear(old);
    pushup(root);
}

// Reads n operations "opt x" and answers them:
// 1 insert, 2 delete, 3 rank of x, 4 x-th smallest, 5 predecessor, other: successor.
int main() {
    int n = read();
    while (n--) {
        int opt = read();
        int x = read();
        switch (opt) {
        case 1:
            Insert(x);
            break;
        case 2:
            Delete(x);
            break;
        case 3:
            printf("%d\n", getk(root, x));
            break;
        case 4:
            printf("%d\n", getkth(root, x));
            break;
        case 5:
            // Temporarily insert x so it sits at the root, then read its predecessor.
            Insert(x);
            printf("%d\n", data[getpre(x)]);
            Delete(x);
            break;
        default:
            Insert(x);
            printf("%d\n", data[getsuc(x)]);
            Delete(x);
            break;
        }
    }
    return 0;
}