[luogu3369]普通平衡树(treap模板)
解题关键:treap模板保存。
#include<cstdio> #include<cstring> #include<algorithm> #include<cstdlib> #include<iostream> #include<cmath> #include<ctime> #define inf 2e9 using namespace std; const int N=1e6+10; struct tree{ int l,r;//左右儿子节点编号 int val;//当前节点的数字 int size;//以当前节点为根的子树的节点数 int cnt;//当前节点的数字的数量 int rnd;//随机优先级 }tr[N];//下标为节点编号 int n,rt,ncnt; int new_node(int x){ ++ncnt;tr[ncnt].val=x;tr[ncnt].size=tr[ncnt].cnt=1;tr[ncnt].rnd=rand();return ncnt; } void pushup(int &k){ int &l=tr[k].l,&r=tr[k].r; tr[k].size=tr[l].size+tr[r].size+tr[k].cnt; } void lturn(int &k){//右孩子左旋,左孩子右旋,核心操作 int t=tr[k].r;tr[k].r=tr[t].l;tr[t].l=k; tr[t].size=tr[k].size;pushup(k);k=t; } void rturn(int &k){ int t=tr[k].l;tr[k].l=tr[t].r;tr[t].r=k; tr[t].size=tr[k].size;pushup(k);k=t; } void insert(int &k,int x){ if(!k){ k=new_node(x); } tr[k].size++; int &l=tr[k].l,&r=tr[k].r; if(x<tr[k].val){ insert(l,x); if(tr[l].rnd<tr[k].rnd) rturn(k); } else if(x>tr[k].val){ insert(r,x); if(tr[r].rnd<tr[k].rnd) lturn(k); } else{ tr[k].cnt++;return; } } void del(int &k,int x){ if(!k) return; int &l=tr[k].l,&r=tr[k].r; if(x==tr[k].val){ if(tr[k].cnt>1){ tr[k].cnt--;tr[k].size--;return; } if(l*r==0) k=l+r; else{ if(tr[l].rnd<tr[r].rnd) rturn(k); else lturn(k); del(k,x); } } else{ tr[k].size--; if(x>tr[k].val) del(r,x); else del(l,x); } } int rnk(int &k,int x){ if(!k) return 0; int &l=tr[k].l,&r=tr[k].r; if(tr[k].val==x) return tr[l].size+1; if(tr[k].val>x) return rnk(l,x); if(tr[k].val<x) return tr[l].size+tr[k].cnt+rnk(r,x); } int kth(int &k, int x){ if(!k) return 0; int &l=tr[k].l,&r=tr[k].r; if(tr[l].size+1<=x&&tr[l].size+tr[k].cnt>=x) return tr[k].val; if(tr[l].size>=x) return kth(l,x); if(tr[l].size+tr[k].cnt<x) return kth(r,x-tr[l].size-tr[k].cnt); } int pred(int &k,int val){ if(!k) return -inf; int &l=tr[k].l,&r=tr[k].r; if(tr[k].val>=val) return pred(l,val); return max(pred(r,val),tr[k].val); } int succ(int &k,int val){ if(!k) return inf; int &l=tr[k].l,&r=tr[k].r; if(tr[k].val<=val) return succ(r,val); return min(succ(l,val),tr[k].val); } int main(){ srand(time(0)); scanf("%d",&n); for(int i=1,opt,x;i<=n;i++){ scanf("%d%d",&opt,&x); switch(opt){ case 1:insert(rt,x);break; case 2:del(rt,x);break; case 3:printf("%d\n",rnk(rt,x));break; case 4:printf("%d\n",kth(rt,x));break; case 5:printf("%d\n",pred(rt,x));break; case 6:printf("%d\n",succ(rt,x));break; } } return 0; }