Splay浅析
关于这篇文章
《Splay浅析》,一看标题名就知道这篇文章是关于 Splay Tree (又称 伸展树 )的
另外,如果需要学习 Splay Tree 请务必保证了解 二叉搜索树 以及 平衡树 的概念(可能还需要一点旋转的知识)
Splay 是啥呢
Splay Tree 是一种奇妙无比的数据结构,它是将二叉搜索树改进后产生的
因此,它最重要的目的就是使其各类操作尽量保持在 \(O(\log n)\)
但实际上,Splay Tree 是均摊 \(O(\log n)\) 的,也就可以理解为其常数极大
(AVL树表示:均摊 \(O(\log n)\) 也太逊了,看我最坏复杂度 \(O(\log n)\)))
但相应的,Splay Tree 虽然常数大,但是能够完成一些其他二叉搜索树所不能的区间操作
(当然,万能的 fhq-Treap 也是可以的,只不过比 Splay Tree 常数还大)
另外,Splay Tree 维护 Link-Cut Tree 也是这种数据结构的重要用途之一
要想让操作维持在 \(O(\log n)\) ,可以基本概括为三种方法:
- 让二叉搜索树的节点再储存一些其他数据
- 为二叉搜索树增添一些新的操作
- 增添多余节点(即满足 Leafy )
几乎所有的二叉搜索树改进版都引进了新操作
Treap 家族,红黑树家族就是引进了新操作以及储存了其他数据
WBLT 就是引进了新操作以及增添多余节点
而替罪羊树和 Splay Tree 仅引进新操作,
替罪羊树引入了 rbu
,即重构子树
Splay 引入了 splay
接下来,开始介绍 Splay Tree 最重要的操作——splay
splay
操作 —— Splay Tree 名字的由来
首先在了解 splay
操作前,必须先了解旋转操作,即 rotate
rotate
操作如下图所示(图中展示的是右旋 \(2\) 节点,以及 左旋 \(1\) 节点):
这里引用一段 OI-Wiki 中关于旋转操作的话:
具体分析旋转步骤(假设需要旋转的节点为 \(x\) ,其父亲为 \(y\) ,以右旋为例)
- 将 \(y\) 的左儿子指向 \(x\) 的右儿子,且 \(x\) 的右儿子(如果 \(x\) 有右儿子的话)的父亲指向 \(y\) :
ch[y][0]=ch[x][1]; fa[ch[x][1]]=y;
- 将 \(x\) 的右儿子指向 \(y\) ,且 \(y\) 的父亲指向 \(x\) ;
ch[x][chk^1]=y; fa[y]=x;
- 如果原来的 \(y\) 还有父亲 \(z\) ,那么把 \(z\) 的某个儿子(原来 \(y\) 所在的儿子位置)指向 \(x\) ,且 \(x\) 的父亲指向 \(z\) 。
fa[x]=z; if(z) ch[z][y==ch[z][1]]=x;
左旋与右旋之分在于要旋转的节点是它父亲的左儿子还是右儿子
所以可以实现 get
,求出当前节点是左儿子还是右儿子
如下所示:
#define get(x) x==ch[fa[x]][1]
另外再实现一个 clear
,用于清空节点的值:
#define clear(x) fa[x]=ch[x][0]=ch[x][1]=sz[x]=val[x]=cnt[x]=0
以及 update
函数,用于更新子树的大小:
void update(int x){sz[x]=sz[ch[x][0]]+sz[ch[x][1]]+cnt[x];}
然后就可以完成 rotate
函数:
#define touch(x,y,chk) ch[fa[x]=y][chk]=x
inline void rotate(int x)
{
bool yck=get(x),zck=get(fa[x]),xck=!yck;
int y=fa[x],z=fa[y],xs=ch[x][xck];
touch(xs,y,yck),touch(x,z,zck),touch(y,x,xck);
update(y),update(x);
}
完成了 rotate
函数后就可以写 splay
函数了,这个函数用于将某个节点旋转至根的位置
通过不断将节点旋转至根,就可以保证经常访问的节点离根较近,也就保证了均摊 \(O(\log n)\) 的复杂度
splay
函数的代码如下所示:
void splay(int x)
{
for(int cur=fa[x]; cur; rotate(x),cur=fa[x])
if(fa[cur]) rotate(get(x)==get(cur)?cur:x);
rt=x;
}
还是比较好写的,但是请注意其实 splay
操作有六种情况,请读者自己去尝试
这里只给出一种情况:
当 splay(x)
中 \(x\) 的父亲是根节点时,直接将 \(x\) 旋转
剩下的操作都比较简单了,就是二叉搜索树的基本操作,只是要注意每次操作时要进行 splay
另外,现在应该明白为何 Splay Tree 不是平衡树了吧,因为它靠 splay
操作维持复杂度,
所以其左右子树并不平衡
划重点: 这里讲一种巨恐怖的错误,就是单旋 Splay ,又称 Spaly
这是一种像是 Splay Tree 但却截然不同的数据结构,Spaly 与正版 Splay 只有一个差别:
就是它的 splay
操作的代码时,与原 splay
操作代码相比少了一遍 rotate
,所以称为单旋
Spaly 与 Splay 在时间复杂度上有着天壤之别,Spaly 的复杂度是可以被卡成 \(O(n)\) 的,
但 Splay 的均摊复杂度是 \(O(\log n)\) 的
所以 splay
操作的代码一定要背下来,且不能错
解释一下垃圾回收机制
简而言之,垃圾回收就是为了不浪费内存,将删除的节点编号存进队列里,下次建节点时先考虑循环利用已删除节点的编号的机制
具体代码如下:
queue<int> q;
int newnode()
{
if(q.empty()) return ++tot;
else
{
int ret=q.front();
clear(ret);
q.pop();
return ret;
}
}
代码很简单,运行速度也很快,最关键的是省了不少内存
前驱后继
这两个功能最简单,直接上代码:
int nxt(bool chk)
{
int cur=ch[rt][chk];
if(!cur) return cur;
while(ch[cur][!chk]) cur=ch[cur][!chk];
splay(cur,0);
return cur;
}
就是将当前节点旋转到根,然后搜索左子树的最大值或者右子树的最小值
再将其旋转至根
第k小与排名
排名就直接查找然后将右子树大小加在一起就行了,
第k小的话,就是疯狂减去右子树的大小直到为 \(0\)
int rk(int k)
{
int ret=0,cur=rt;
while(k!=val[cur])
{
if(k<val[cur]) cur=ch[cur][0];
else
{
ret+=sz[ch[cur][0]]+cnt[cur];
cur=ch[cur][1];
}
}
ret+=sz[ch[cur][0]];
splay(cur,0);
return ret+1;
}
int kth(int k)
{
int cur=rt;
while(1)
{
if(ch[cur][0]&&k<=sz[ch[cur][0]]) cur=ch[cur][0];
else
{
k-=cnt[cur]+sz[ch[cur][0]];
if(k<=0)
{
splay(cur,0);
return val[cur];
}
cur=ch[cur][1];
}
}
}
删除与插入
插入就是正常的二叉搜索树的步骤,
删除就是将要删除的节点旋转到根,然后将左右子树合并,最后将将节点放进回收队列里
void erase(int k)
{
rk(k);
if(cnt[rt]>1)
{
--cnt[rt];
update(rt);
return;
}
int cur=rt;
if(!ch[rt][0]&&!ch[rt][1]) rt=0;
else if((!ch[rt][0])||(!ch[rt][1]))
{
rt=ch[rt][0]|ch[rt][1];
fa[rt]=0;
}
else
{
int x=nxt(0);
fa[ch[cur][1]]=x;
ch[x][1]=ch[cur][1];
update(rt);
}
q.push(cur);
}
void insert(int k)
{
if(!rt)
{
rt=newnode();
val[rt]=k;
cnt[rt]=1;
update(rt);
return;
}
int cur=rt,f=0;
while(cur&&val[cur]!=k) f=cur,cur=ch[cur][val[cur]<k];
if(!cur)
{
cur=newnode();
val[cur]=k;
fa[cur]=f;
cnt[cur]=1;
ch[f][val[f]<k]=cur;
}
else if(val[cur]==k) ++cnt[cur];
update(cur),update(f);
splay(cur,0);
}
全部代码
自己看吧
个人觉得还是写得不错的,曾经专门为了提升模板运行速度,减少码量,将各个模板综合成了以下这个模板:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<queue>
using namespace std;
const int MAXN=1e7+7;
int rt,tot,fa[MAXN],ch[MAXN][2],sz[MAXN],val[MAXN],cnt[MAXN];
void update(int x){sz[x]=sz[ch[x][0]]+sz[ch[x][1]]+cnt[x];}
#define get(x) x==ch[fa[x]][1]
#define clear(x) fa[x]=ch[x][0]=ch[x][1]=sz[x]=val[x]=cnt[x]=0
queue<int> q;
int newnode()
{
if(q.empty()) return ++tot;
else
{
int ret=q.front();
clear(ret);
q.pop();
return ret;
}
}
#define touch(x,y,chk) ch[fa[x]=y][chk]=x
inline void rotate(int x)
{
bool yck=get(x),zck=get(fa[x]),xck=!yck;
int y=fa[x],z=fa[y],xs=ch[x][xck];
touch(xs,y,yck),touch(x,z,zck),touch(y,x,xck);
update(y),update(x);
}
inline void splay(int x,int y)
{
for(int cur=fa[x]; cur!=y; rotate(x),cur=fa[x])
if(fa[cur]!=y) rotate(get(x)==get(cur)?cur:x);
if(y==0) rt=x;
}
void insert(int k)
{
if(!rt)
{
rt=newnode();
val[rt]=k;
cnt[rt]=1;
update(rt);
return;
}
int cur=rt,f=0;
while(cur&&val[cur]!=k) f=cur,cur=ch[cur][val[cur]<k];
if(!cur)
{
cur=newnode();
val[cur]=k;
fa[cur]=f;
cnt[cur]=1;
ch[f][val[f]<k]=cur;
}
else if(val[cur]==k) ++cnt[cur];
update(cur),update(f);
splay(cur,0);
}
int rk(int k)
{
int ret=0,cur=rt;
while(k!=val[cur])
{
if(k<val[cur]) cur=ch[cur][0];
else
{
ret+=sz[ch[cur][0]]+cnt[cur];
cur=ch[cur][1];
}
}
ret+=sz[ch[cur][0]];
splay(cur,0);
return ret+1;
}
int kth(int k)
{
int cur=rt;
while(1)
{
if(ch[cur][0]&&k<=sz[ch[cur][0]]) cur=ch[cur][0];
else
{
k-=cnt[cur]+sz[ch[cur][0]];
if(k<=0)
{
splay(cur,0);
return val[cur];
}
cur=ch[cur][1];
}
}
}
int nxt(bool chk)
{
int cur=ch[rt][chk];
if(!cur) return cur;
while(ch[cur][!chk]) cur=ch[cur][!chk];
splay(cur,0);
return cur;
}
void erase(int k)
{
rk(k);
if(cnt[rt]>1)
{
--cnt[rt];
update(rt);
return;
}
int cur=rt;
if(!ch[rt][0]&&!ch[rt][1]) rt=0;
else if((!ch[rt][0])||(!ch[rt][1]))
{
rt=ch[rt][0]|ch[rt][1];
fa[rt]=0;
}
else
{
int x=nxt(0);
fa[ch[cur][1]]=x;
ch[x][1]=ch[cur][1];
update(rt);
}
q.push(cur);
}
int n,m,last,ans;
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
{
int a;
scanf("%d",&a);
insert(a);
}
for(int i=1;i<=m;i++)
{
int opt,x;
scanf("%d%d",&opt,&x);
x^=last;
if(opt==1) insert(x);
else if(opt==2) erase(x);
else if(opt==3) insert(x),last=rk(x),erase(x);
else if (opt==4) last=kth(x);
else if (opt==5) insert(x),last=val[nxt(0)],erase(x);
else insert(x),last=val[nxt(1)],erase(x);
if(opt<=6&&opt>=3) ans^=last;
}
printf("%d",ans);
return 0;
}
这是这道 模板题 的AC代码