loj 107 维护全序集
loj 107 维护全序集
本题是平衡树的模板题,我写了treap和splay
#include<bits/stdc++.h>
using namespace std;
int const N=3e5+10;
int const inf=2e9;
int ch[N][2],f[N],sz[N],same[N],v[N],n,rt,cnt;
inline int typ(int x){ return x==ch[f[x]][1];}
void update(int x){
sz[x]=same[x]+sz[ch[x][0]]+sz[ch[x][1]];
}
int newnode(int val){
v[++cnt]=val;
same[cnt]=sz[cnt]=1;
ch[cnt][0]=ch[cnt][1]=0;
return cnt;
}
void rotate(int x){
int fa=f[x],ffa=f[fa],w=typ(x),ww=typ(fa);
ch[fa][w]=ch[x][w^1];
ch[x][w^1]=fa;
f[fa]=x;
f[x]=ffa;
if(ch[fa][w]) f[ch[fa][w]]=fa;
if(ffa) ch[ffa][ww]=x;
update(fa);
update(x);
}
void splay(int x){
while ( f[x]!=0){
int fa=f[x];
if(f[fa]==0)
rotate(x);
else if(typ(x)==typ(fa))
rotate(fa),rotate(x);
else
rotate(x),rotate(x);
}
rt=x;
}
void ins(int x){
if(!rt){
rt=newnode(x); return ;
}
int k=rt,last;
while (k){
last=k;
if(v[k]==x){
same[k]++; update(k);
splay(k);
return;
}else if(v[k]>x) k=ch[k][0];
else k=ch[k][1];
}
k=newnode(x);
f[k]=last;
ch[last][x>v[last]]=k;
splay(k);
}
int find(int x){
int k=rt;
while (k){
if(v[k]==x){
splay(k);
return k;
}else if(v[k]>x)
k=ch[k][0];
else
k=ch[k][1];
}
}
int smaller(int x){
int res=-1;
int k=rt;
while (k){
if(v[k]<x) res=max(res,v[k]),k=ch[k][1];
else k=ch[k][0];
}
return res;
}
int bigger(int x){
int res=inf;
int k=rt;
while (k){
if(v[k]>x) res=min(res,v[k]),k=ch[k][0];
else k=ch[k][1];
}
if(res==inf) res=-1;
return res;
}
void del(int x){
find(x);
if(same[rt]>1){
same[rt]--;update(rt);
}else if(!ch[rt][0] && !ch[rt][1])
rt=0;
else if(!ch[rt][0] && ch[rt][1])
rt=ch[rt][1],f[rt]=0;
else if(ch[rt][0] && !ch[rt][1])
rt=ch[rt][0],f[rt]=0;
else {
int t=smaller(x);
find(t);
int k=ch[rt][1];
while(ch[k][0]) k=ch[k][0];
int w=typ(k);
ch[f[k]][w]=ch[k][1];
if(ch[k][1]) f[ch[k][1]]=f[k];
sz[f[k]]--;
update(f[k]);
splay(f[k]);
}
}
int kth(int x){
int k=rt;
while (k){
int num=sz[ch[k][0]];
if(num>=x) k=ch[k][0];
else if(num+same[k]>=x) return v[k];
else x-=num+same[k],k=ch[k][1];
}
}
int count(int x){
int res=0;
int k=rt;
while (k){
if(v[k]<x) res+=sz[ch[k][0]]+same[k],k=ch[k][1];
else k=ch[k][0];
}
return res;
}
int main(){
scanf("%d",&n);
while (n--){
int x,y;
scanf("%d%d",&x,&y);
if(x==0) ins(y);
if(x==1) del(y);
if(x==2) printf("%d\n",kth(y));
if(x==3) printf("%d\n",count(y));
if(x==4) printf("%d\n",smaller(y));
if(x==5) printf("%d\n",bigger(y));
}
return 0;
}
#include<bits/stdc++.h>
using namespace std;
int const N=3e5+10;
int son[N][2],sz[N],v[N],sum[N],n,cnt,rt,rd[N];
int newnode(int val){
++cnt; son[cnt][0]=son[cnt][0]=0;
v[cnt]=val;
sum[cnt]=sz[cnt]=1;
rd[cnt]=rand()+1;
return cnt;
}
void update(int o){
sz[o]=sum[o]+sz[son[o][0]]+sz[son[o][1]];
}
void rotate(int &o,int k){
int t=son[o][k^1];
son[o][k^1]=son[t][k];
son[t][k]=o;
update(o); update(t); o=t;
}
void ins(int &o,int val){
if(!o) {
o=newnode(val);
return ;
}
if(v[o]==val){
sum[o]++; update(o); return ;
}
if(v[o]>val) ins(son[o][0],val);
else ins(son[o][1],val);
int k=rd[son[o][1]]>rd[son[o][1]];
if(rd[son[o][k]]>rd[o]) rotate(o,k^1);
update(o);
}
void del(int &o,int val){
if(v[o]==val){
if(sum[o]==1){
if(!son[o][0] && !son[o][1]) o=0;
else if(son[o][0] && !son[o][1]) {
rotate(o,1);
del(son[o][1],val);
update(o);
}else if(!son[o][0] && son[o][1]){
rotate(o,0);
del(son[o][0],val);
update(o);
}else {
int k=rd[son[o][1]]>rd[son[o][0]];
rotate(o,k^1);
del(son[o][k^1],val);
update(o);
}
}
else sum[o]--,update(o);
return ;
}
if(v[o]>val) del(son[o][0],val);
else del(son[o][1],val);
update(o);
}
int kth(int x){
int k=rt;
while (1){
int t=sz[son[k][0]];
if(t>=x) k=son[k][0];
else if(t+sum[k]>=x) return v[k];
else x-=t+sum[k],k=son[k][1];
}
}
int lesth(int x){
int res=0;
int k=rt;
while (k){
if(v[k]<x) res+=sz[son[k][0]]+sum[k],k=son[k][1];
else if(v[k]>=x) k=son[k][0];
}
return res;
}
int smaller(int x){
int res=-1;
int k=rt;
while (k){
if(v[k]<x) res=max(v[k],res),k=son[k][1];
else k=son[k][0];
}
return res;
}
int bigger(int x){
int res=2e9;
int k=rt;
while (k){
if(v[k]>x) res=min(res,v[k]),k=son[k][0];
else k=son[k][1];
}
if(res==2e9) res=-1;
return res;
}
int main(){
srand(time(0));
scanf("%d",&n);
while (n--){
int x,y;
scanf("%d%d",&x,&y);
if(x==0) ins(rt,y);
if(x==1) del(rt,y);
if(x==2) printf("%d\n",kth(y));
if(x==3) printf("%d\n",lesth(y));
if(x==4) printf("%d\n",smaller(y));
if(x==5) printf("%d\n",bigger(y));
}
return 0;
}