「Splay」普通平衡树模板
口诀:
$rotate$:先上再下,最后自己
$splay$:祖父不是目标时旋两次——三点一线先旋父亲,三点折线先旋自己,第二次都旋自己;祖父就是目标时只旋自己一次。
$delete$:先把目标转到根,计数大于一只减计数。没有儿子就删光。单个儿子删自己(让儿子做根)。两个儿子找前驱(左子树最大值),把它转到根的左儿子位置,再接上右子树。
易错点:
$rotate$:祖父不在自己做根
$delete$:儿子(或前驱)做新根时,要把新根的父亲置为0
$kth$:先减排名后转移
/* By DennyQi 2018 */
#include <cstdio>
#include <queue>
#include <cstring>
#include <algorithm>

// Splay-tree template for the "ordinary balanced tree" operations:
//   1 x: insert x          2 x: delete one copy of x
//   3 x: rank of x         4 k: k-th smallest value
//   5 x: predecessor of x  6 x: successor of x
//
// Nodes live in global arrays indexed from 1. Index 0 is the null node:
// sz[0], cnt[0] and ch[0][*] must stay 0 because update()/Rnk/Kth read them.

const int MAXN = 100010;     // node capacity; each Insert allocates at most one node
const int INF = 0x3f3f3f3f;  // "no answer" sentinel returned by Kth / Pre / Nxt

// Reads one (possibly negative) decimal integer from stdin.
// Skips characters that are neither a digit nor '-'; returns 0 if EOF is
// reached before a number starts.
inline int read(){
    int x = 0, w = 1;
    int c = std::getchar();  // int, not char, so EOF can be told apart from data
    while(c != EOF && c != '-' && (c < '0' || c > '9')) c = std::getchar();
    if(c == '-'){ w = -1; c = std::getchar(); }
    while(c >= '0' && c <= '9'){ x = x * 10 + (c - '0'); c = std::getchar(); }
    return x * w;
}

int num_node, root;  // nodes allocated so far; current root (0 = empty tree)
// ch: left/right child, fa: parent, val: key,
// sz: number of copies in the subtree, cnt: copies of this node's key.
// Named `sz` (not `size`) so it can never collide with std::size (C++17).
int ch[MAXN][2], fa[MAXN], val[MAXN], sz[MAXN], cnt[MAXN];

struct Splay{
    // True if x is the right child of f.
    bool rson(int f, int x){ return ch[f][1] == x; }

    // Recompute x's subtree size from its children and its own count.
    void update(int x){ sz[x] = sz[ch[x][0]] + sz[ch[x][1]] + cnt[x]; }

    // Reset every field of node x.
    void clear(int x){ val[x] = cnt[x] = sz[x] = fa[x] = ch[x][0] = ch[x][1] = 0; }

    // Rotate x above its parent f.
    // Order: link grandparent -> x, hand x's inner child to f, hang f under x.
    void rotate(int x){
        int f = fa[x], gf = fa[f];
        bool p = rson(f, x), q = !p;
        if(gf) ch[gf][rson(gf, f)] = x;
        else root = x;                  // no grandparent: x becomes the root
        fa[x] = gf;
        ch[f][p] = ch[x][q];
        if(ch[x][q]) fa[ch[x][q]] = f;  // never write into the null node
        ch[x][q] = f;
        fa[f] = x;
        update(f);                      // f is now below x, so update it first
        update(x);
    }

    // Rotate x upward until its parent is `target` (0 = make x the root).
    void splay(int x, int target){
        while(fa[x] != target){
            int f = fa[x], gf = fa[f];
            if(gf == target){ rotate(x); break; }
            // Zig-zig (straight line): rotate the parent first; zig-zag: rotate x.
            if(rson(gf, f) == rson(f, x)) rotate(f);
            else rotate(x);
            rotate(x);
        }
    }

    // Insert one copy of v and splay its node to the root.
    void Insert(int v){
        if(!root){
            root = ++num_node;
            val[root] = v;
            cnt[root] = sz[root] = 1;
            return;
        }
        int o = root;
        for(;;){
            if(v == val[o]){
                ++cnt[o];
                ++sz[o];
                splay(o, 0);  // rotations recompute the ancestors' sizes on the way up
                return;
            }
            bool b = v > val[o];
            if(!ch[o][b]){
                int nw = ++num_node;
                val[nw] = v;
                cnt[nw] = sz[nw] = 1;
                fa[nw] = o;
                ch[o][b] = nw;
                splay(nw, 0);
                return;
            }
            o = ch[o][b];
        }
    }

    // Splay the node holding v to the root. If v is absent, splay the last
    // node on the search path instead; its value is v's nearest smaller or
    // larger stored value.
    void Find(int v){
        if(!root) return;
        int o = root;
        while(val[o] != v && ch[o][v > val[o]]) o = ch[o][v > val[o]];
        splay(o, 0);
    }

    // Remove one copy of v; does nothing if v is not stored.
    void Delete(int v){
        Find(v);
        if(!root || val[root] != v) return;
        int o = root;
        if(cnt[o] > 1){ --cnt[o]; --sz[o]; return; }
        if(!ch[o][0] && !ch[o][1]){      // last node: tree becomes empty
            root = 0;
            clear(o);
            return;
        }
        if(!ch[o][0] || !ch[o][1]){      // one child: it becomes the root
            root = ch[o][0] ? ch[o][0] : ch[o][1];
            fa[root] = 0;
            clear(o);
            return;
        }
        // Two children: bring the predecessor (max of the left subtree) up to be
        // o's left child; it then has no right child, so o's right subtree fits there.
        int l_max = ch[o][0];
        while(ch[l_max][1]) l_max = ch[l_max][1];
        splay(l_max, o);
        ch[l_max][1] = ch[o][1];
        fa[ch[o][1]] = l_max;
        fa[l_max] = 0;
        root = l_max;
        update(l_max);                   // its subtree just gained o's right subtree
        clear(o);
    }

    // 1 + number of stored values strictly less than v (v need not be stored).
    int Rnk(int v){
        Find(v);
        int less = sz[ch[root][0]];
        // After a miss the root may be v's predecessor: count its copies too.
        if(root && val[root] < v) less += cnt[root];
        return less + 1;
    }

    // k-th smallest stored value (1-based, counting duplicates).
    // Returns INF if k is out of range.
    int Kth(int k){
        int o = root, last = 0;
        while(o){
            last = o;
            int left = sz[ch[o][0]];
            if(k <= left) o = ch[o][0];
            else if(k > left + cnt[o]){
                k -= left + cnt[o];      // subtract the skipped rank before moving right
                o = ch[o][1];
            }
            else{
                splay(o, 0);
                return val[o];
            }
        }
        if(last) splay(last, 0);         // splay the deepest visited node to pay for the descent
        return INF;
    }

    // Largest stored value strictly less than v, or -INF if there is none.
    int Pre(int v){
        int o = root, last = 0, best = 0;
        while(o){
            last = o;
            if(val[o] < v){ best = o; o = ch[o][1]; }
            else o = ch[o][0];
        }
        if(last) splay(last, 0);         // splay the deepest visited node to pay for the descent
        return best ? val[best] : -INF;
    }

    // Smallest stored value strictly greater than v, or INF if there is none.
    int Nxt(int v){
        int o = root, last = 0, best = 0;
        while(o){
            last = o;
            if(val[o] > v){ best = o; o = ch[o][0]; }
            else o = ch[o][1];
        }
        if(last) splay(last, 0);         // splay the deepest visited node to pay for the descent
        return best ? val[best] : INF;
    }
} qxz;

int main(){
    // freopen(".in", "r", stdin);
    int n = read();
    for(int i = 1; i <= n; ++i){
        int opt = read(), x = read();
        switch(opt){
            case 1: qxz.Insert(x); break;
            case 2: qxz.Delete(x); break;
            case 3: std::printf("%d\n", qxz.Rnk(x)); break;
            case 4: std::printf("%d\n", qxz.Kth(x)); break;
            case 5: std::printf("%d\n", qxz.Pre(x)); break;
            case 6: std::printf("%d\n", qxz.Nxt(x)); break;
            default: break;
        }
    }
    return 0;
}