【洛谷P3369 普通平衡树】
题目描述
您需要写一种数据结构,来维护一些数,其中需要提供以下操作:
-
插入x数
-
删除x数(若有多个相同的数,因只删除一个)
-
查询x数的排名(排名定义为比当前数小的数的个数+1。若有多个相同的数,因输出最小的排名)
-
查询排名为x的数
-
求x的前驱(前驱定义为小于x,且最大的数)
- 求x的后继(后继定义为大于x,且最小的数)
输入格式:
第一行为n,表示操作的个数,下面n行每行有两个数opt和x,opt表示操作的序号
输出格式:
对于操作3,4,5,6每行输出一个数,表示对应答案
输入样例:
10 1 106465 4 1 1 317721 1 460929 1 644985 1 84185 1 89851 6 81968 1 492737 5 493598
输出样例:
106465 84185 492737
题解:
这题就在这里写下单点问题的splay问题总结吧。(一个一个问题来)
1.旋转
1 inline void update(int r) 2 { 3 if(r!=-1){ 4 sz[r]=num[r]; 5 if(ch[r][0]!=-1) sz[r]+=sz[ch[r][0]]; 6 if(ch[r][1]!=-1) sz[r]+=sz[ch[r][1]]; 7 } 8 } 9 inline void rotate(int x,int k) 10 { 11 int y=fa[x];int z=fa[y]; 12 ch[y][1-k]=ch[x][k];if(ch[x][k]!=-1) fa[ch[x][k]]=y; 13 fa[x]=z;if(z!=-1) ch[z][y==ch[z][1]]=x; 14 ch[x][k]=y;fa[y]=x; 15 update(y);update(x); 16 }
因为题目说会有相同的数字所以开一个num数组记录这个节点相同sum的数量,sz为这个子树的大小(num总和),sum为这个节点代表的值(就是splay维护的key),每次旋转完要更新与rotate有关两个点的sz,所有要update(重点:所有改变某个节点的num操作都要把这个节点以及它的fa去update,否则的话进行splay会出问题,我也是想了很久的)
2.splay
1 inline void splay(int p) 2 { 3 while(fa[p]!=-1) 4 { 5 int y=fa[p];int z=fa[y]; 6 if(z==-1) rotate(p,ch[y][1]!=p); 7 else 8 { 9 if(ch[z][0]==y){ 10 if(ch[y][0]==p) rotate(y,1),rotate(p,1); 11 else rotate(p,0),rotate(p,1); 12 } 13 else{ 14 if(ch[y][0]==p) rotate(p,1),rotate(p,0); 15 else rotate(y,0),rotate(p,0); 16 } 17 } 18 } 19 rt=p; 20 }
这个理解的话也不难,伸展操作然后把某个节点调到rt位置。(重点:关于所有操作你找到点都要把这个点splay了,这样才能维护树的平衡)
3.插入
1 inline void build(int v,int f){ 2 fa[++cnt]=f,ch[cnt][0]=ch[cnt][1]=-1,sum[cnt]=v; 3 sz[cnt]=1,num[cnt]=1; 4 } 5 inline void insert(int v) 6 { 7 int r=rt; 8 while(1){ 9 if(v==sum[r]){ 10 num[r]++; 11 update(r),update(fa[r]); 12 splay(r);return; 13 } 14 if(v<sum[r]){ 15 if(ch[r][0]==-1){ 16 build(v,r); 17 ch[r][0]=cnt; 18 update(cnt),update(r); 19 splay(cnt);return; 20 } 21 else r=ch[r][0]; 22 } 23 else 24 { 25 if(ch[r][1]==-1){ 26 build(v,r); 27 ch[r][1]=cnt; 28 update(cnt),update(r); 29 splay(cnt);return; 30 } 31 else r=ch[r][1]; 32 } 33 } 34 }
build是对新建的点进行处理,插入跟普通的二叉树没有什么区别,记得两个重点。
4.询问第k小数以及x是第几小
1 inline int find_pos(int x) 2 { 3 int r=rt; 4 while(1) 5 { 6 if(sum[r]==x)break; 7 if(x<sum[r]) r=ch[r][0]; 8 else r=ch[r][1]; 9 } 10 splay(r); 11 return sz[ch[r][0]]+1; 12 } 13 inline int find_num(int x) 14 { 15 int r=rt; 16 while(1){ 17 if(sz[ch[r][0]]+1<=x && x<=sz[ch[r][0]]+num[r]) 18 { 19 splay(r); 20 return sum[r]; 21 } 22 if(sz[ch[r][0]]+num[r]<x) 23 { 24 x-=(sz[ch[r][0]]+num[r]); 25 r=ch[r][1]; 26 } 27 else r=ch[r][0]; 28 } 29 }
对于询问x是第几小,很简单找到这个点把它splay到根然后输出ch[x][0]的sz,不要忘了+1.
对于询问第k小数,真的难讲,看代码吧。(记得if(sz[ch[r][0]]+1<=x && x<=sz[ch[r][0]]+num[r]),r节点就是答案)
5.前驱后继
1 inline int pre(int x) 2 { 3 insert(x); 4 int r=ch[rt][0]; 5 while(ch[r][1]!=-1) r=ch[r][1]; 6 return sum[r]; 7 } 8 inline int suc(int x) 9 { 10 insert(x); 11 int r=ch[rt][1]; 12 while(ch[r][0]!=-1) r=ch[r][0]; 13 return sum[r]; 14 }
先加入这个数x,把它splay到根,然后左儿子开始一直往右就是前驱,右儿子开始一直往左就是后继。然后把这个数删了。
6.删除
1 inline void del(int x) 2 { 3 int r=rt,p; 4 while(1){ 5 if(sum[r]==x){ 6 num[r]--; 7 if(num[r]>=1){ 8 update(r);update(fa[r]); 9 splay(r);return; 10 } 11 else{ 12 p=r; 13 break; 14 } 15 } 16 if(x<sum[r]) r=ch[r][0]; 17 else r=ch[r][1]; 18 } 19 splay(p); 20 if(ch[p][0]==-1 && ch[p][1]==-1){ 21 rt=-1; 22 return; 23 } 24 if(ch[p][0]==-1 && ch[p][1]!=-1){ 25 rt=ch[p][1]; 26 fa[ch[p][1]]=-1; 27 return; 28 } 29 if(ch[p][0]!=-1 && ch[p][1]==-1){ 30 rt=ch[p][0]; 31 fa[ch[p][0]]=-1; 32 return; 33 } 34 int mx=ch[p][0]; 35 while(ch[mx][1]!=-1) mx=ch[mx][1]; 36 fa[ch[p][0]]=-1; 37 splay(mx); 38 ch[mx][1]=ch[p][1]; 39 fa[ch[p][1]]=mx; 40 }
如果这个点的num大于1就说明这个节点不用删掉只需把num--就好(记得两个重点)。否则,删除节点其实是两个操作,先删后并。先特判没有左右儿子的情况(easy),然后就不舒服了。如何使得删掉这个值还能满足平衡树性质?先把rt的左儿子与它断开(fa[ch[rt][0]=-1),在左儿子中找到最大的数(就是一直往右),把它splay到根节点然后把原本的右儿子接进去就好了。(不理解好好想想)
总结:1.所有改变某个节点的num(或者sz)操作或者改变子节点操作都要把这个节点以及它的fa去update
2.关于所有操作你找到点都要把这个点splay了,这样才能维护树的平衡(否则要splay何用?)
3.num修改操作都是先把它与它父亲的sz更新后再把这个节点splay
完整代码
1 #include<iostream> 2 #include<cstdlib> 3 #include<cstdio> 4 #include<cstring> 5 using namespace std; 6 int ch[100005][3],sum[100005],sz[100005],num[100005],fa[100005]; 7 int rt=-1,cnt; 8 inline void update(int r) 9 { 10 if(r!=-1){ 11 sz[r]=num[r]; 12 if(ch[r][0]!=-1) sz[r]+=sz[ch[r][0]]; 13 if(ch[r][1]!=-1) sz[r]+=sz[ch[r][1]]; 14 } 15 } 16 inline void rotate(int x,int k) 17 { 18 int y=fa[x];int z=fa[y]; 19 ch[y][1-k]=ch[x][k];if(ch[x][k]!=-1) fa[ch[x][k]]=y; 20 fa[x]=z;if(z!=-1) ch[z][y==ch[z][1]]=x; 21 ch[x][k]=y;fa[y]=x; 22 update(y);update(x); 23 } 24 inline void splay(int p) 25 { 26 while(fa[p]!=-1) 27 { 28 int y=fa[p];int z=fa[y]; 29 if(z==-1) rotate(p,ch[y][1]!=p); 30 else 31 { 32 if(ch[z][0]==y){ 33 if(ch[y][0]==p) rotate(y,1),rotate(p,1); 34 else rotate(p,0),rotate(p,1); 35 } 36 else{ 37 if(ch[y][0]==p) rotate(p,1),rotate(p,0); 38 else rotate(y,0),rotate(p,0); 39 } 40 } 41 } 42 rt=p; 43 } 44 inline void build(int v,int f){ 45 fa[++cnt]=f,ch[cnt][0]=ch[cnt][1]=-1,sum[cnt]=v; 46 sz[cnt]=1,num[cnt]=1; 47 } 48 inline void insert(int v) 49 { 50 int r=rt; 51 while(1){ 52 if(v==sum[r]){ 53 num[r]++; 54 update(r),update(fa[r]); 55 splay(r);return; 56 } 57 if(v<sum[r]){ 58 if(ch[r][0]==-1){ 59 build(v,r); 60 ch[r][0]=cnt; 61 update(cnt),update(r); 62 splay(cnt);return; 63 } 64 else r=ch[r][0]; 65 } 66 else 67 { 68 if(ch[r][1]==-1){ 69 build(v,r); 70 ch[r][1]=cnt; 71 update(cnt),update(r); 72 splay(cnt);return; 73 } 74 else r=ch[r][1]; 75 } 76 } 77 } 78 inline void del(int x) 79 { 80 int r=rt,p; 81 while(1){ 82 if(sum[r]==x){ 83 num[r]--; 84 if(num[r]>=1){ 85 update(r);update(fa[r]); 86 splay(r);return; 87 } 88 else{ 89 p=r; 90 break; 91 } 92 } 93 if(x<sum[r]) r=ch[r][0]; 94 else r=ch[r][1]; 95 } 96 splay(p); 97 if(ch[p][0]==-1 && ch[p][1]==-1){ 98 rt=-1; 99 return; 100 } 101 if(ch[p][0]==-1 && ch[p][1]!=-1){ 102 rt=ch[p][1]; 103 fa[ch[p][1]]=-1; 104 return; 105 } 106 if(ch[p][0]!=-1 && ch[p][1]==-1){ 107 rt=ch[p][0]; 108 fa[ch[p][0]]=-1; 109 return; 110 } 111 int mx=ch[p][0]; 112 while(ch[mx][1]!=-1) mx=ch[mx][1]; 113 fa[ch[p][0]]=-1; 114 splay(mx); 115 ch[mx][1]=ch[p][1]; 116 fa[ch[p][1]]=mx; 117 } 118 inline int pre(int x) 119 { 120 insert(x); 121 int r=ch[rt][0]; 122 while(ch[r][1]!=-1) r=ch[r][1]; 123 return sum[r]; 124 } 125 inline int suc(int x) 126 { 127 insert(x); 128 int r=ch[rt][1]; 129 while(ch[r][0]!=-1) r=ch[r][0]; 130 return sum[r]; 131 } 132 inline int find_pos(int x) 133 { 134 int r=rt; 135 while(1) 136 { 137 if(sum[r]==x)break; 138 if(x<sum[r]) r=ch[r][0]; 139 else r=ch[r][1]; 140 } 141 splay(r); 142 return sz[ch[r][0]]+1; 143 } 144 inline int find_num(int x) 145 { 146 int r=rt; 147 while(1){ 148 if(sz[ch[r][0]]+1<=x && x<=sz[ch[r][0]]+num[r]) 149 { 150 splay(r); 151 return sum[r]; 152 } 153 if(sz[ch[r][0]]+num[r]<x) 154 { 155 x-=(sz[ch[r][0]]+num[r]); 156 r=ch[r][1]; 157 } 158 else r=ch[r][0]; 159 } 160 } 161 int main() 162 { 163 int n,op,x; 164 scanf("%d",&n); 165 for(int i=1;i<=n;i++) 166 { 167 scanf("%d%d",&op,&x); 168 if(op==1){ 169 if(rt==-1){ 170 build(x,-1); 171 rt=cnt; 172 } 173 else insert(x); 174 } 175 if(op==2) del(x); 176 if(op==3) printf("%d\n",find_pos(x)); 177 if(op==4) printf("%d\n",find_num(x)); 178 if(op==5) printf("%d\n",pre(x)),del(x); 179 if(op==6) printf("%d\n",suc(x)),del(x); 180 } 181 return 0; 182 }