【Splay】【Luogu3369】普通平衡树
#include<bits/stdc++.h> #define rt T[0].ch[1] using namespace std; const int maxn=1e5+4,INF=0x7f7f7f7f; struct node{ int val,fa,ch[2],rec,sz; //权值,父亲节点,儿子,这个权值的节点出现的次数,size }T[maxn]; int tot=0; inline void update(int x){T[x].sz=T[T[x].ch[0]].sz+T[T[x].ch[1]].sz+T[x].rec;} inline int ident(int x){return T[T[x].fa].ch[0]==x?0:1;} inline void connect(int x,int fa,int how){T[fa].ch[how]=x;T[x].fa=fa;} void rotate(int x){ int y=T[x].fa,z=T[y].fa; int yson=ident(x),zson=ident(y); connect(T[x].ch[yson^1],y,yson); connect(y,x,yson^1); connect(x,z,zson); update(y);update(x); } void splay(int x,int to){ to=T[to].fa; while(T[x].fa!=to){ int y=T[x].fa; if(T[y].fa==to) rotate(x); else if(ident(x)==ident(y)) rotate(y),rotate(x); else rotate(x),rotate(x); } } inline int newpoint(int v,int fa){ T[++tot].fa=fa; T[tot].val=v; T[tot].sz=T[tot].rec=1; return tot; } inline void insert(int x){ int now=rt; if(rt==0) {newpoint(x,0);rt=tot;} else{ while(1){ T[now].sz++; if(T[now].val==x) {T[now].rec++;splay(now,rt);return;} int nxt=x<T[now].val?0:1; if(!T[now].ch[nxt]){ int p=newpoint(x,now); T[now].ch[nxt]=p; splay(p,rt);return; } now=T[now].ch[nxt]; } } } inline int find(int x){ int now=rt; while(1){ if(!now) return 0; if(T[now].val==x) {splay(now,rt);return now;} int nxt=x<T[now].val?0:1; now=T[now].ch[nxt]; } } inline void del(int x){ int pos=find(x); if(!pos) return; if(T[pos].rec>1) {T[pos].rec--;T[pos].sz--;return;} else{ if(!T[pos].ch[0]&&!T[pos].ch[1]) {rt=0;return;} else if(!T[pos].ch[0]){rt=T[pos].ch[1];T[rt].fa=0;return;} else{ int left=T[pos].ch[0]; while(T[left].ch[1]) left=T[left].ch[1]; splay(left,T[pos].ch[0]); connect(T[pos].ch[1],left,1); connect(left,0,1); update(left); } } } inline int rak(int val){ int pos=find(val); return T[T[pos].ch[0]].sz+1; } inline int kth(int x){ int now=rt; while(1){ int u=T[now].sz-T[T[now].ch[1]].sz; if(T[T[now].ch[0]].sz<x&&x<=u) {splay(now,rt);return T[now].val;} if(x<u) now=T[now].ch[0]; else now=T[now].ch[1],x-=u; } } inline int lower(int x){ int now=rt,ans=-INF; while(now){ if(T[now].val<x) ans=max(ans,T[now].val); int nxt=x<=T[now].val?0:1; now=T[now].ch[nxt]; } return ans; } inline int upper(int x){ int now=rt,ans=INF; while(now){ if(T[now].val>x) ans=min(ans,T[now].val); int nxt=x<T[now].val?0:1; now=T[now].ch[nxt]; } return ans; } int main(){ int T;scanf("%d",&T); while(T--){ int opt,x;scanf("%d%d",&opt,&x); if(opt==1) insert(x); else if(opt==2) del(x); else if(opt==3) printf("%d\n",rak(x)); else if(opt==4) printf("%d\n",kth(x)); else if(opt==5) printf("%d\n",lower(x)); else if(opt==6) printf("%d\n",upper(x)); } return 0; }