[bzoj3224] 普通平衡树
Description
您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
1. 插入x数
2. 删除x数(若有多个相同的数,因只删除一个)
3. 查询x数的排名(若有多个相同的数,因输出最小的排名)
4. 查询排名为x的数
5. 求x的前驱(前驱定义为小于x,且最大的数)
6. 求x的后继(后继定义为大于x,且最小的数)
Input
第一行为n,表示操作的个数,下面n行每行有两个数opt和x,opt表示操作的序号(1<=opt<=6)
Output
对于操作3,4,5,6每行输出一个数,表示对应答案
Sample Input
10
1 106465
4 1
1 317721
1 460929
1 644985
1 84185
1 89851
6 81968
1 492737
5 493598
1 106465
4 1
1 317721
1 460929
1 644985
1 84185
1 89851
6 81968
1 492737
5 493598
Sample Output
106465
84185
492737
84185
492737
HINT
1.n的数据范围:n<=100000
2.每个数的数据范围:[-2e9,2e9]
题解:
splay
操作前要设一个上界-inf,防止访问上界时把空结点转到根......
#include<cstdio> #include<cstdlib> #include<cstring> #include<iostream> #include<cmath> #include<algorithm> #define RG register using namespace std; const int maxn = 100010; int n,id,root,inf=1<<30; int ch[maxn][2],key[maxn],pre[maxn],siz[maxn]; void newnode(int &x, int fa, int val) { x=++id,key[x]=val,pre[x]=fa; ch[x][1]=ch[x][0]=0; } void rotate(int x, int kd) { int y=pre[x]; ch[y][!kd]=ch[x][kd]; pre[ch[x][kd]]=y; if(pre[y]) ch[pre[y]][ch[pre[y]][1]==y]=x; pre[x]=pre[y],ch[x][kd]=y,pre[y]=x; siz[x]=siz[ch[x][0]]+siz[ch[x][1]]+1; siz[y]=siz[ch[y][0]]+siz[ch[y][1]]+1; } void splay(int x, int goal) { while(pre[x]!=goal) { if(pre[pre[x]]==goal) rotate(x,ch[pre[x]][0]==x); else { int y=pre[x],kd=ch[pre[y]][0]==y; if(ch[y][!kd]==x) rotate(y,kd),rotate(x,kd); else rotate(x,!kd),rotate(x,kd); } } if(goal==0) root=x; } void insert(int k) { while(ch[root][k>key[root]]) { siz[root]++; root=ch[root][k>key[root]]; } siz[root]++; newnode(ch[root][k>key[root]],root,k); siz[ch[root][k>key[root]]]++; splay(ch[root][k>key[root]],0); } void find_pre(int x, int k, int &ans) { if(!x) return; if(k>key[x]) {ans=x,find_pre(ch[x][1],k,ans);return;} else {find_pre(ch[x][0],k,ans);return;} } void find_nxt(int x, int k, int &ans) { if(!x) return; if(k<key[x]) {ans=x,find_nxt(ch[x][0],k,ans);return;} else {find_nxt(ch[x][1],k,ans);return;} } void erase(int k) { int prek=root; find_pre(root,k,prek); int nxtk=root; find_nxt(root,k,nxtk); splay(prek,0),splay(nxtk,prek); int tmp=ch[nxtk][0]; while(ch[tmp][0] || ch[tmp][1]) { siz[tmp]--; if(ch[tmp][0]) tmp=ch[tmp][0]; else tmp=ch[tmp][1]; } if(ch[pre[tmp]][0]==tmp) ch[pre[tmp]][0]=0; else ch[pre[tmp]][1]=0; } void query_rank(int k) { int prek=root; find_pre(root,k,prek); splay(prek,0); printf("%d\n", siz[ch[root][0]]+1); } void query_num(int k) {//k是排名 int rt=root; while(rt) { if(siz[ch[rt][0]]+1==k) {printf("%d\n", key[rt]);return;} if(siz[ch[rt][0]]<k) k-=siz[ch[rt][0]]+1,rt=ch[rt][1]; else rt=ch[rt][0]; } } void query_pre(int x, int k, int &ans) { if(!x) return; if(k>key[x]) {ans=key[x],query_pre(ch[x][1],k,ans);return;} else {query_pre(ch[x][0],k,ans);return;} } void query_nxt(int x, int k, int &ans) { if(!x) return; if(k<key[x]) {ans=key[x],query_nxt(ch[x][0],k,ans);return;} else {query_nxt(ch[x][1],k,ans);return;} } int main() { scanf("%d", &n); newnode(root,0,-inf),insert(inf); for(int i=1; i<=n; i++) { int kd,x; scanf("%d%d", &kd, &x); if(kd==1) insert(x); if(kd==2) erase(x); if(kd==3) query_rank(x); if(kd==4) query_num(x+1); if(kd==5) { int ans=-1; query_pre(root,x,ans); printf("%d\n", ans); } if(kd==6) { int ans=-1; query_nxt(root,x,ans); printf("%d\n", ans); } } return 0; }