BZOJ 3224 平衡树模板题
Treap:
//By SiriusRen
#include <cstdio>
#include <algorithm>
using namespace std;
int n,op,xx,ans,size,root;
struct Treap{int ch[2],v,cnt,rnd,sz;}tr[300000];
void Upd(int k){tr[k].sz=tr[k].cnt+tr[tr[k].ch[0]].sz+tr[tr[k].ch[1]].sz;}
void rot(int &k,bool f){int t=tr[k].ch[f];tr[k].ch[f]=tr[t].ch[!f],tr[t].ch[!f]=k,Upd(k),Upd(t),k=t;}
void ins(int &k){
if(!k){k=++size;tr[k].cnt=tr[k].sz=1;tr[k].rnd=rand(),tr[k].v=xx;return;}
tr[k].sz++;
if(tr[k].v==xx){tr[k].cnt++;return;}
bool f=xx>tr[k].v;ins(tr[k].ch[f]);
if(tr[tr[k].ch[f]].rnd<tr[k].rnd)rot(k,f);
}
void del(int &k){
if(tr[k].v==xx){
if(tr[k].cnt>1)tr[k].sz--,tr[k].cnt--;
else if(!(tr[k].ch[0]*tr[k].ch[1]))k=max(tr[k].ch[0],tr[k].ch[1]);
else rot(k,tr[tr[k].ch[0]].rnd>tr[tr[k].ch[1]].rnd),del(k);
}
else tr[k].sz--,del(tr[k].ch[xx>tr[k].v]);
}
int get_rank(int k){
if(tr[k].v==xx)return tr[tr[k].ch[0]].sz+1;
else if(tr[k].v>xx)return get_rank(tr[k].ch[0]);
else return get_rank(tr[k].ch[1])+tr[tr[k].ch[0]].sz+tr[k].cnt;
}
int get_kth(int k,int x){
if(tr[tr[k].ch[0]].sz>=x)return get_kth(tr[k].ch[0],x);
else if(tr[tr[k].ch[0]].sz+tr[k].cnt<x)return get_kth(tr[k].ch[1],x-tr[tr[k].ch[0]].sz-tr[k].cnt);
else return tr[k].v;
}
void get(int k){
if(!k)return;
if(op==5&&tr[k].v<xx)ans=tr[k].v,get(tr[k].ch[1]);
else if(op==5&&tr[k].v>=xx)get(tr[k].ch[0]);
else if(op==6&&tr[k].v>xx)ans=tr[k].v,get(tr[k].ch[0]);
else get(tr[k].ch[1]);
}
int main(){
scanf("%d",&n);
for(int i=1;i<=n;i++){
scanf("%d%d",&op,&xx);
if(op==1)ins(root);
else if(op==2)del(root);
else if(op==3)printf("%d\n",get_rank(root));
else if(op==4)printf("%d\n",get_kth(root,xx));
else get(root),printf("%d\n",ans);
}
}
Splay:
//By SiriusRen
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
int op,xx,root,n,size;
struct Splay{int ch[2],fa,v,sz,cnt;}tr[300500];
void Upd(int x){tr[x].sz=tr[tr[x].ch[0]].sz+tr[tr[x].ch[1]].sz+tr[x].cnt;}
void rot(int x){
int y=tr[x].fa,z=tr[y].fa;
bool f=(tr[y].ch[1]==x);
tr[y].ch[f]=tr[x].ch[!f];
if(tr[y].ch[f])tr[tr[y].ch[f]].fa=y;
tr[x].ch[!f]=y;tr[y].fa=x;tr[x].fa=z;
if(z)tr[z].ch[tr[z].ch[1]==y]=x;
Upd(y);
}
void splay(int x,int tp){
for(int y,z;(y=tr[x].fa)!=tp;rot(x)){
z=tr[y].fa;
if(z==tp)continue;
if((tr[y].ch[0]==x)==(tr[z].ch[0]==y))rot(y);
else rot(x);
}
if(!tp)root=x;
Upd(x);
}
void insert(int x,int num){
int y=0;
while(x&&tr[x].v!=num)y=x,x=tr[x].ch[num>tr[x].v];
if(x)++tr[x].cnt;
else{
x=++size,tr[x].sz=tr[x].cnt=1,tr[x].fa=y,tr[x].v=num;
if(y)tr[y].ch[num>tr[y].v]=x;
}
splay(x,0);
}
void find(int v){
int x=root;
while(tr[x].ch[v>tr[x].v]&&tr[x].v!=v)x=tr[x].ch[v>tr[x].v];
splay(x,0);
}
int next(int v,bool f){
find(v);
if((tr[root].v>v&&f)||(tr[root].v<v&&!f))return root;
int p=tr[root].ch[f];
while(tr[p].ch[!f])p=tr[p].ch[!f];
return p;
}
void del(int v){
int p=next(v,0),s=next(v,1);
splay(p,0),splay(s,p);
p=tr[s].ch[0];
if(tr[p].cnt>1)tr[p].cnt--,splay(p,0);
else tr[s].ch[0]=0;
}
int kth(int x){
int y=root,p;
if(x>tr[root].sz)return 0;
while(1){
p=tr[y].ch[0];
if(tr[p].sz+tr[y].cnt<x)x=x-tr[p].sz-tr[y].cnt,y=tr[y].ch[1];
else if(tr[p].sz>=x)y=p;
else return tr[y].v;
}
}
int main(){
insert(root,0x3fffffff),insert(root,-0x3fffffff);
scanf("%d",&n);
for(int i=1;i<=n;i++){
scanf("%d%d",&op,&xx);
if(op==1)insert(root,xx);
else if(op==2)del(xx);
else if(op==3)find(xx),printf("%d\n",tr[tr[root].ch[0]].sz);
else if(op==4)printf("%d\n",kth(xx+1));
else printf("%d\n",tr[next(xx,op==6)].v);
}
}