C++Splay实现过程(附代码)
前言:
作者的splay写法主要受到https://www.luogu.org/blog/tiger0132/slay-notes这篇文章的影响,所以代码实现基本重复,这里主要把那篇文章中没讲清楚的地方详细讲讲
作者水平有限,可能部分描述有误,欢迎指出!
另外作者码风略丑,见谅QWQ
Tips:本文的Splay不含维护区间的方法,以后会另开文章介绍
_____________________________________________________________________________________________________
Splay,又称伸展树,是平衡树的一种(蒟蒻我也只会写一种QWQ)
在讲平衡树之前,得先了解什么是二叉查找树
二叉搜索树,又称二叉搜索树,是一种优化查询数字的复杂度的数据结构,他的结点严格满足左结点小于根节点,右结点大于根节点
因此每次查找某个特定值,只要不断将它与当前结点比较大小,选择向左儿子走或者右儿子走,就一定能找到要找到的数字,下图就是一个标准的二叉搜索树:
(画图随便画的,很丑QWQ)
上图就是一个优秀的例子,为什么?因为这恰好是一个完全二叉树,也就是说,你查询任何数的复杂度都能稳定在O(log n)级别
但是真的每次都这样吗?:
上图的这颗搜索树就由于插入顺序的原因退化成了一条链,最差情况下查询一个数字居然达到了O(n)的级别,这显然不是我们能接受的
于是平衡树这一数据结构就诞生了,它的本质其实还是一个二叉搜索树,但是它会通过一些操作来维护这棵树保证它的复杂度控制在O(log n)的级别,也就是防止这棵树退化为链。
而我们今天要说的Splay树,则是通过splay(伸展)这一操作来实现这一过程。
_____________________________________________________________________________________________________
首先我们先确认变量的意义:
int son[x][0/1];-->节点x的左儿子(0)和右儿子(1)
int val[x];- ->节点x的值
int cnt[x];-->有多少个这个值
int size[x]; -->节点x所在子树元素个数
int fa[x]; -->节点x的父亲节点的下标
之后再确认些基本操作
1.
1 bool Dir(int x){ return son[fa[x]][1]==x; }
返回0/1,表示节点x位于父节点的左儿子还是右儿子,其实本质上就是问节点x是否大于父节点,所以也可以写成:
1 bool Dir(int x){ return val[x]>val[fa[x]]; }
Tips:由于bool类型的性质,当返回的逻辑式是正确的就会返回1,否则就会返回0
2.
void pushup(int x){ size[x]=size[son[x][0]]+size[son[x][1]]+cnt[x]; }
更新节点x的size值,在节点旋转或者删除的时候起作用。
_____________________________________________________________________________________________________
刚才也说了,splay操作是Splay树的重中之重,所以我们先来讲这个操作,但是再讲这个之前,必须先介绍一下旋转(rotate)这一操作
什么是旋转?看了好几篇文章都是一笔带过,这里详细讲讲:
旋转就是在不破坏二叉搜索树的性质的前提下将当前节点与父亲节点进行交换
在这一过程中 ,其实就是三条树边连接发生的变化:
1.当前节点的爷爷节点,即父节点的父节点,原来指向父节点的连边现在指向了当前节点,毕竟两点位置交换了;
2.当前节点的儿子变成了父节点,那么父节点应该插在左边还是右边呢?运用刚才上面所说的Dir操作,得出当前节点在父节点的位置情况,也就是大小关系,反过来就可以了,比如当前节点原来是父节点的左儿子,那旋转后父节点一定会成为自己的右儿子,因为右儿子一定是比自己大的嘛,而自己原来是左儿子说明父节点比自己大。
3.当前节点的儿子由于要变成父节点,原来这个位置的子节点就要变成父节点的子节点,插入的位置与当前节点原来在父节点的位置是相同的。
这三个过程搞不懂建议在纸上进行模拟,这里就不画图了,其实并不难理解,另外,由于位置的更改,当前节点和父节点的size值也要更新,这里就可以运用刚才介绍的pushup操作,上代码:
1 void rotate(int x){
2
3 int y=fa[x],z=fa[y],k=Dir(x),w=son[x][k^1];
4 son[z][Dir(y)]=x; fa[x]=z;
5 son[x][k^1]=y; fa[y]=x;
6 son[y][k]=w; fa[w]=y;
7 pushup(x); pushup(y);
8
9 }
_____________________________________________________________________________________________________
讲完这些,终于可以讲讲伸展(splay)操作了
伸展操作一共有两个参数,一个是要伸展的点,一个是要到达的目标,对于splay(x,goal)表示x这个节点要成为goal这个节点的儿子,若goal=0则表示当前节点要成为根节点。而实现过程其实很暴力,就是判断当前该节点的父亲是否为goal,不是的话就不断向上旋转。
Splay就是通过这一操作不断在每次操作后将操作的节点伸展至根从而来维护这条路径上的平衡(不理解可以自己列个图),由于每次都这样保证了整棵树的平衡,伸展操作一般都只会花费O(log n)的时间。
不过这样会出现一个新的问题,当祖宗三代都在同一方向即如下图:
当我们旋转过后:
没错,居然还是形成了链,这样显然会影响整棵树的平衡,所以当我们遇到三代连成一条线的时候,可以先旋转父节点,再旋转当前节点,这样就会保持它的平衡了。
代码如下:
1 inline void splay(int x,int goal=0){//没有填写伸展目标是默认为根节点
2
3 while (fa[x]!=goal){
4 int y=fa[x],z=fa[y];
5 if (z!=goal) {
6 if(Dir(x)==Dir(y)) rotate(y);
7 else rotate(x);
8 }
9 rotate(x);
10 }
11 if(!goal) Root=x;//当前节点伸展目标为根,更新根节点位置
12 }
_____________________________________________________________________________________________________
这一个操作介绍完,剩下的操作其实都不难理解。
接下来以洛谷上的P3369 【模板】普通平衡树为例来介绍剩下的一些操作
1.find(查询)操作
通过二叉搜索树的性质找到特定的某个数字x,并将他伸展至根节点,方便操作。
注意:如果要查找的数字不存在,就会返回这个数字的前驱(前驱定义为小于x,且最大的数) 或者后继(后继定义为大于x,且最小的数)。
1 void find(int x){
2 if(!Root) return;
3 int cur=Root;
4 while(son[cur][x>val[cur]]&&val[cur]!=x)
5 cur=son[cur][x>val[cur]];
6 splay(cur);
7 }
2.insert(插入)操作
插入一个数字,过程类似于find操作,不断往下找同时判断当前节点是否为目标节点,如果树中已有该值则将其的cnt++,否则当走到空节点时就将该数保存在此节点,并将各个数字初始化。
最后将此节点伸展至根,在维护平衡的同时更新各个点的size值。
1 void Insert(int x){
2 int cur=Root,p=0;
3 while(cur&&val[cur]!=x)
4 {
5 p=cur;
6 cur=son[cur][x>val[cur]];
7 }
8 if(cur) cnt[cur]++;
9 else {
10 cur=++tot;
11 if(p) son[p][x>val[p]]=cur;
12 son[cur][0]=son[cur][1]=0;
13 val[cur]=x;fa[cur]=p;
14 cnt[cur]=size[cur]=1;
15 }
16 splay(cur);
17 }
3.求x的前驱/后驱操作
前驱 后继的定义前面find操作已经介绍了,这里就不过多赘述了
由于二者实现过程基本一样,放在一起讲。
过程很简单,只要把x伸展至根后,求前驱就是x最大也就是最右边的左儿子,而后继就是x最小也就是最左边的右儿子。
注意,如果x不存在,前驱/后继之一可能会被splay至根,所以额外进行一次特判
1 inline int Last(int x)//求前驱 (前驱定义为小于x,且最大的数)
2 {
3 find(x);
4 if(val[Root]<x) return Root;//x不存在,根就可能为前驱/后驱
5 int cur=son[Root][0];
6 while(son[cur][1]) cur=son[cur][1];
7 splay(cur);
8 return cur;
9
10 }
11
12 inline int Next(int x)//求后驱 (后继定义为大于x,且最小的数)
13 {
14 find(x);
15 if(val[Root]>x) return Root;
16 int cur=son[Root][1];
17 while(son[cur][0]) cur=son[cur][0];
18 splay(cur);
19 return cur;
20 }
21
4.Delete(删除)操作
这个删除操作非常简单,可以说是我已知平衡树中最方便的了。
首先不难发现,对于某个数x,x的前驱与后继之间的数字一定只有x。
那么我们就可以通过这一特点来删除一个点。
首先我们找到这个数字的后继同时将他伸展至根,然后再找到他的后继使他成为根的儿子,最后结果就会像这样:
此时我们只要把后继的左儿子删去就好了。如果x的数量不止为1就将数量-1同时将x伸展至根来更新size值。
最后更新根和后继的size值即可。
1 void Delete(int x){
2 int L=Last(x),R=Next(x);
3 splay(L);splay(R,L);
4 int del=son[R][0];
5 if(cnt[del]>1) <% cnt[del]--;splay(del); %>
6 else son[R][0]=0;
7 pushup(R);pushup(Root);
8 }
5.GetRank(查询排名)操作
首先我们先find(x)来把要查询的数字伸展至根。
因为排名的定义是比自己小的数字的数量+1,所以输出只要输出根节点左子树节点的数量,然后+1。
但是!如果查询排名的数字不存在,就会把该数的前驱或后继挪到根上,如果是后继还好,如果是前驱的话,由于前驱也比x小,按照上面的输出方法会把根给漏掉,所以特判即可。
但是最后一句,由于后面求K小的操作可能会越界导致死循环,所以个人会在代码开始时先插入一个极大和极小值,所以所有数字的排名都会多1,故所有排名输出时我都会-1,代码如下:
1 int GetRank(int x)
2 {
3 find(x);
4 if(val[Root]<x) return size[son[Root][0]]+cnt[Root];//如果为前驱就把根节点的数量也算上
5 else return size[son[Root][0]];//此处和上面原本应该+1但是如上的原因-1所以抵消了
6 }
6.Kth(查询第K小)操作
从根开始往下找,如果当前节点的左子树大于等于k就往左走,否则就看往右走,同时减去左子树节点数量,直到走不了,由于没有节点这个操作会死循环,所以先插入一个极大和极小值,同样的,由于多一个数字排名多1,所以查询的时候要写成Kth(k+1)不能直接写k。
1 int Kth(int k){
2 int cur=Root;
3 while(1){
4 if(k<=size[son[cur][0]]&&son[cur][0])
5 cur=son[cur][0];
6 else
7 if(k>size[son[cur][0]]+cnt[cur]){
8 k-=size[son[cur][0]]+cnt[cur];
9 cur=son[cur][1];
10 }
11 else{
12 splay(cur);
13 return cur;
14 }
15 }
16 }
_____________________________________________________________________________________________________
至此,普通平衡树的所有操作都写完了,之后有空还会在写一个文艺平衡树也就是平衡树维护区间的博客,希望大家滋瓷QWQ!
下面是这题的总的AC代码:
1 #include<bits/stdc++.h>
2 using namespace std;
3 const int inf=6*200100;
4 int tot,Root;
5
6 int son[inf][2];// 0->left 1->right
7 int val[inf];// 节点的值
8 int cnt[inf];//有多少个这个值
9 int size[inf];// 所在子树元素个数
10 int fa[inf];// 父亲节点下标
11
12 inline bool Dir(int x){ return son[fa[x]][1]==x; }
13
14 inline void pushup(int x)<% size[x]=size[son[x][0]]+size[son[x][1]]+cnt[x]; %>
15
16 inline void rotate(int x){//旋转
17 int y=fa[x],z=fa[y],k=Dir(x),w=son[x][k^1];
18 son[y][k]=w; fa[w]=y;
19 son[z][Dir(y)]=x; fa[x]=z;
20 son[x][k^1]=y; fa[y]=x;
21 pushup(y);pushup(x);
22 }
23
24 inline void splay(int x,int goal=0){//没有填写伸展目标是默认为根节点
25
26 while (fa[x]!=goal){
27 int y=fa[x],z=fa[y];
28 if (z!=goal) {
29 if(Dir(x)==Dir(y)) rotate(y);
30 else rotate(x);
31 }
32 rotate(x);
33 }
34 if(!goal) Root=x;//当前节点伸展目标为根,更新根节点位置
35 }
36
37 void find(int x){//寻找x并将他伸展至根,如果x不存在,就会找到当前数的前驱或后继
38 int cur=Root;
39 while(son[cur][x>val[cur]]&&val[cur]!=x)
40 cur=son[cur][x>val[cur]];
41 splay(cur);
42 }
43
44 inline void Insert(int x){//插入
45 int cur=Root,p=0;
46 while(cur&&val[cur]!=x)
47 {
48 p=cur;
49 cur=son[cur][x>val[cur]];
50 }
51 if(cur) cnt[cur]++;
52 else {
53 cur=++tot;
54 if(p) son[p][x>val[p]]=cur;
55 son[cur][0]=son[cur][1]=0;
56 val[cur]=x;fa[cur]=p;
57 cnt[cur]=size[cur]=1;
58 }
59 splay(cur);
60 }
61
62 inline int Last(int x)//求前驱 (前驱定义为小于x,且最大的数)
63 {
64 find(x);
65 if(val[Root]<x) return Root;//x不存在,根就可能为前驱/后驱
66 int cur=son[Root][0];
67 while(son[cur][1]) cur=son[cur][1];
68 splay(cur);
69 return cur;
70
71 }
72
73 inline int Next(int x)//求后驱 (后继定义为大于x,且最小的数)
74 {
75 find(x);
76 if(val[Root]>x) return Root;
77 int cur=son[Root][1];
78 while(son[cur][0]) cur=son[cur][0];
79 splay(cur);
80 return cur;
81 }
82
83 void Delete(int x){
84 int L=Last(x),R=Next(x);
85 splay(L);splay(R,L);
86 int del=son[R][0];
87 if(cnt[del]>1) <% cnt[del]--;splay(del); %>
88 else son[R][0]=0;
89 pushup(R);pushup(Root);
90 }
91
92 int Kth(int k){
93 int cur=Root;
94 while(1){
95 if(k<=size[son[cur][0]]&&son[cur][0])
96 cur=son[cur][0];
97 else
98 if(k>size[son[cur][0]]+cnt[cur]){
99 k-=size[son[cur][0]]+cnt[cur];
100 cur=son[cur][1];
101 }
102 else{
103 splay(cur);
104 return cur;
105 }
106 }
107 }
108
109 int GetRank(int x)
110 {
111 find(x);
112 if(val[Root]<x) return size[son[Root][0]]+cnt[Root];
113 else return size[son[Root][0]];
114 }
115 int main()
116 {
117 int N;
118 scanf("%d",&N);
119 Insert(INT_MAX);
120 Insert(-INT_MAX);
121 for(int i=1;i<=N;i++){
122 int opt,x;
123 scanf("%d%d",&opt,&x);
124 if(opt==1) Insert(x);
125 if(opt==2) Delete(x);
126 if(opt==3) <% find(x); printf("%d\n",GetRank(x));%>
127 if(opt==4) <% printf("%d\n",val[Kth(x+1)]);%>
128 if(opt==5) <% printf("%d\n",val[Last(x)]);%>
129 if(opt==6) <% printf("%d\n",val[Next(x)]);%>
130 }
131 return 0;
132 }