[笔记]Splay树
前置知识:树的左旋、右旋。
Splay树是一种平衡树。能够做到每个操作均摊\(O(\log N)\)。
前言
与上文AVL树不同之处在于,AVL树在任何操作结束后,都能保证每个节点的左右子树高度相差不超过\(1\)。相应地,每个操作都是严格的\(O(\log N)\)。而Splay树并没有对“平衡”的确切定义,任何结构的树都可能是Splay树(甚至可以是一条链)。相应地,单次操作的时间复杂度是无法计算的,只能知道每个操作的均摊时间复杂度为\(O(\log N)\)。
Splay树的核心思想在于“Splay操作”。对节点\(x\)进行Splay操作,即通过一系列“zig”、“zag”及其组合操作,将\(x\)挪到根节点的位置(同时对树的结构进行一些调整,使每次操作最坏复杂度更接近\(O(\log N)\))。
增、删、查、改……几乎每一个操作完成后都要跟着一个Splay操作,把操作的节点移到根节点。注意这样不一定能像AVL树那样,保证这棵树每个节点左右子树高度相差不超过\(1\),不过可以证明,这种修改可以让每个操作的均摊时间复杂度变为\(O(\log N)\)。
这样,Splay的优势显现了出来:越是最近操作过的节点,到根节点的距离就越近,访问就会越快。而在实际应用中,我们频繁访问到的,可能只是所有数据的一小部分(比如我们用的输入法)。
具体操作方式见下。
代码来自Splay 详细图解 & 轻量级代码实现 - 樱雪喵,十分感谢!
各种操作
准备
#define lc(x) tr[x].ch[0]
#define rc(x) tr[x].ch[1]
#define ch(x,y) tr[x].ch[y]
#define fa(x) tr[x].fa
#define v(x) tr[x].v
#define siz(x) tr[x].siz
struct tree{
int ch[2],siz,fa,v;
void clear(){ch[0]=ch[1]=siz=fa=v=0;}
}tr[N];
int t,cnt,root;
void clear(int u){tr[u].clear();}
int newnode(int v){v(++cnt)=v,siz(cnt)=1;return cnt;}
void update(int u){siz(u)=siz(lc(u))+siz(rc(u))+1;}//更新状态
bool get(int u){return u==rc(fa(u));}//查询u是其父节点的左子结点还是右子节点
与AVL不同,Splay不需要记录树高。
旋转 & Splay操作
void rot(int x){
int y=fa(x),z=fa(y);//保证y!=0
bool dir=get(x),tdir=get(y);
if(ch(x,!dir)) fa(ch(x,!dir))=y;
ch(y,dir)=ch(x,!dir);
ch(x,!dir)=y;
fa(y)=x,fa(x)=z;
if(z) ch(z,tdir)=x;
update(y),update(x);
}
旋转操作把左旋和右旋写在了一起,其含义也较AVL有改变。rot(x)
表示通过旋转将\(x\)提升一层。
接下来讲解算法的精髓——Splay操作。
Splay操作就是不断调用rot(x)
,将节点\(x\)移到根节点的位置。
这里设\(x\)的父节点为\(f\)(下图中是\(p\))。
但旋转方式根据树的结构也有所不同,一共分为\(6\)种:zig
、zig-zig
、zig-zag
、zag
、 zag-zag
、zag-zig
。由于后面\(3\)种与前面\(3\)种操作是对称的,所以就只说前\(3\)种了。
下面\(3\)张图片来自OI Wiki。
1 - zig/zag
zig
/zag
操作,只有Splay过程中的最后一次旋转才可能会使用。
条件就是\(x\)到当前根节点的距离正好是\(1\)。
这种情况下,调用一次rot(x)
即可。
2 - zig-zig/zag-zag
zig-zig
/zag-zag
操作,是在\(x\)到根节点距离超过\(1\),而且\(x,y\)都是左子结点/右子节点的情况下使用的。
这种情况下,需要先调用rot(p)
,再调用rot(x)
。
3 - zig-zag/zag-zig
zig-zag
/zag-zig
操作,是在\(x\)到根节点距离超过\(1\),而且\(x,y\)中\(1\)个是左子节点,\(1\)个是右子节点的情况下使用的。
这种情况下,需要连续调用\(2\)次rot(x)
。
为什么zig-zag
就是两次rot(x)
,而zig-zig
必须先调用rot(p)
再调用rot(x)
呢?
(当然,你可以发现在zig-zag
的情况下是做不到rot(p)
再rot(x)
的)
这是为了保证时间复杂度。当然我们也可以感性理解一下:
如你所见,这是一条链(没错,如果依次插入\(16,15,\dots ,1\)就会这样)。
现在我们想对节点\(16\)执行Splay操作,那显然每次都是zig-zig
了。
出现链不要担心,我们按照正确的操作(rot(p)
再rot(x)
)跑一遍:
我们发现,树的高度的数量级直接减少了一半。如果再对节点\(15\)进行一个Splay操作,树的高度会更加接近\(\log N\)。
但如果我们按错误的操作(两次rot(x)
)来跑:
还是。。。一条链。
代码可以写出来了。
void splay(int x){
for(int f;(f=fa(x));rot(x))
if(fa(f)) rot(get(f)==get(x)?f:x);
root=x;
}
把节点\(x\)通过Splay转移到顶部的过程中,\(x\)以及\(x\)的祖先节点都会通过update()
被更新。其他节点则不会调用到update()
。所以Splay操作后节点信息正确的前提是除了\(x\)和\(x\)的祖先节点,其他节点的信息都应该是正确的。
其他操作
和普通BST一样了(删除有些不同,是先把要删除的节点Splay到顶部,然后删掉它,把右子树连接到左子树的最大权值下面)。。。不过要注意每个操作最后都要加Splay。否则每次操作\(O(\log N)\)复杂度会假(想象一下如果你有个操作没加Splay,又正好碰到了上面那条链的样例,它又正好不断调用这个操作,那么这条链就一直得不到调整,每次时间复杂度就是\(O(n)\)了。。。)。
插入
void ins(int v){
int u=root,f=0;
while(u) f=u,u=ch(u,v>v(u));//>在右,<=在左
u=newnode(v),fa(u)=f;
if(f) ch(f,v>v(f))=u;
splay(u);
}
删除
void del(int v){
int u=root,f=0;
while(v(u)!=v&&u) f=u,u=ch(u,v>v(u));//找到要修改的节点
if(!u){splay(f);return;}//不要忘记Splay!
splay(u);//这里Splay到根是为了更方便操作,不用特判根了
int pre=lc(u);
if(!pre){root=rc(u),fa(rc(u))=0,clear(u);return;}
while(rc(pre)) pre=rc(pre);
rc(pre)=rc(u),fa(rc(u))=pre,fa(lc(u))=0,clear(u);
splay(pre);//这一次Splay可以解决祖先siz没更新的问题
}
查询\(v\)的排名
int getrank(int v){
int u=root,ran=1,f=0;
while(u){
f=u;
if(v>v(u)) ran+=siz(lc(u))+1,u=rc(u);//Right
else u=lc(u);//Left
}
splay(f);
return ran;
}
查询排名为\(ran\)的值
int getnum(int ran){
int u=root;
while(u){
int sz=siz(lc(u))+1;
if(sz==ran) break;
else if(sz>ran) u=lc(u);
else ran-=sz,u=rc(u);
}
splay(u);
return v(u);
}
前驱
int pre(int u){return getnum(getrank(u)-1);}
后继
int nex(int u){return getnum(getrank(u+1));}
时间复杂度证明
先。。。咕掉吧
模板题 & Code
点击查看代码
#include<bits/stdc++.h>
#define int long long
#define N 100010
#define lc(x) tr[x].ch[0]
#define rc(x) tr[x].ch[1]
#define ch(x,y) tr[x].ch[y]
#define fa(x) tr[x].fa
#define v(x) tr[x].v
#define siz(x) tr[x].siz
using namespace std;
struct tree{
int ch[2],siz,fa,v;
void clear(){ch[0]=ch[1]=siz=fa=v=0;}
}tr[N];
int t,cnt,root;
void clear(int u){tr[u].clear();}
int newnode(int v){v(++cnt)=v,siz(cnt)=1;return cnt;}
void update(int u){siz(u)=siz(lc(u))+siz(rc(u))+1;}
bool get(int u){return u==rc(fa(u));}
void rot(int x){
int y=fa(x),z=fa(y);//保证y!=0
bool dir=get(x),tdir=get(y);
if(ch(x,!dir)) fa(ch(x,!dir))=y;
ch(y,dir)=ch(x,!dir);
ch(x,!dir)=y;
fa(y)=x,fa(x)=z;
if(z) ch(z,tdir)=x;
update(y),update(x);
}
void splay(int x){
for(int f;(f=fa(x));rot(x))
if(fa(f)) rot(get(f)==get(x)?f:x);
root=x;
}
void ins(int v){
int u=root,f=0;
while(u) f=u,u=ch(u,v>v(u));//>在右,<=在左
u=newnode(v),fa(u)=f;
if(f) ch(f,v>v(f))=u;
splay(u);
}
void del(int v){
int u=root,f=0;
while(v(u)!=v&&u) f=u,u=ch(u,v>v(u));
if(!u){splay(f);return;}
splay(u);
int pre=lc(u);
if(!pre){root=rc(u),fa(rc(u))=0,clear(u);return;}
while(rc(pre)) pre=rc(pre);
rc(pre)=rc(u),fa(rc(u))=pre,fa(lc(u))=0,clear(u);
splay(pre);
}
int getrank(int v){
int u=root,ran=1,f=0;
while(u){
f=u;
if(v>v(u)) ran+=siz(lc(u))+1,u=rc(u);//Right
else u=lc(u);//Left
}
splay(f);
return ran;
}
int getnum(int ran){
int u=root;
while(u){
int sz=siz(lc(u))+1;
if(sz==ran) break;
else if(sz>ran) u=lc(u);
else ran-=sz,u=rc(u);
}
splay(u);
return v(u);
}
int pre(int u){return getnum(getrank(u)-1);}
int nex(int u){return getnum(getrank(u+1));}
signed main(){
ios::sync_with_stdio(false);
cin.tie(nullptr);
cin>>t;
while(t--){
int op,x;
cin>>op>>x;
if(op==1) ins(x);
else if(op==2) del(x);
else if(op==3) cout<<getrank(x)<<"\n";
else if(op==4) cout<<getnum(x)<<"\n";
else if(op==5) cout<<pre(x)<<"\n";
else if(op==6) cout<<nex(x)<<"\n";
}
return 0;
}
附:相同节点合并代码
也可以把值相同的节点合并为\(1\)个,用\(cnt\)记一下数。
然后newnode
、update
、ins
、del
、getrank
、getnum
函数需要做相应的修改。具体见代码。
#include<bits/stdc++.h>
#define int long long
#define N 100010
#define lc(x) tr[x].ch[0]
#define rc(x) tr[x].ch[1]
#define ch(x,y) tr[x].ch[y]
#define fa(x) tr[x].fa
#define v(x) tr[x].v
#define siz(x) tr[x].siz
#define cnt(x) tr[x].cnt
using namespace std;
struct tree{
int ch[2],siz,fa,v,cnt;
void clear(){ch[0]=ch[1]=siz=fa=v=0;}
}tr[N];
int t,cnt,root;
void clear(int u){tr[u].clear();}
int newnode(int v){v(++cnt)=v,siz(cnt)=1,cnt(cnt)=1;return cnt;}
void update(int u){siz(u)=siz(lc(u))+siz(rc(u))+cnt(u);}
bool get(int u){return u==rc(fa(u));}
void rot(int x){
int y=fa(x),z=fa(y);//保证y!=0
bool dir=get(x),tdir=get(y);
if(ch(x,!dir)) fa(ch(x,!dir))=y;
ch(y,dir)=ch(x,!dir);
ch(x,!dir)=y;
fa(y)=x,fa(x)=z;
if(z) ch(z,tdir)=x;
update(y),update(x);
}
void splay(int x){
for(int f;(f=fa(x));rot(x))
if(fa(f)) rot(get(f)==get(x)?f:x);
root=x;
}
void ins(int v){
int u=root,f=0;
while(u&&v(u)!=v) f=u,u=ch(u,v>v(u));//>在右,<=在左
if(u) cnt(u)++;
else{
u=newnode(v),fa(u)=f;
if(f) ch(f,v>v(f))=u;
}
splay(u);
}
void del(int v){
int u=root,f=0;
while(v(u)!=v&&u) f=u,u=ch(u,v>v(u));
if(!u){splay(f);return;}
splay(u);
if(cnt(u)>1){cnt(u)--;return;}
int pre=lc(u);
if(!pre){root=rc(u),fa(rc(u))=0,clear(u);return;}
while(rc(pre)) pre=rc(pre);
rc(pre)=rc(u),fa(rc(u))=pre,fa(lc(u))=0,clear(u);
splay(pre);
}
int getrank(int v){
int u=root,ran=1,f=0;
while(u){
f=u;
if(v>v(u)) ran+=siz(lc(u))+cnt(u),u=rc(u);//Right
else u=lc(u);//Left
}
splay(f);
return ran;
}
int getnum(int ran){
int u=root;
while(u){
int sz=siz(lc(u))+cnt(u);
if(ran<=siz(lc(u))) u=lc(u);
else if(ran>siz(lc(u))+cnt(u)) ran-=sz,u=rc(u);
else break;//如果ran在[siz[l]+1,siz[l]+cnt[u]]的区间内,就说明第ran名就是u
}
splay(u);
return v(u);
}
int pre(int u){return getnum(getrank(u)-1);}
int nex(int u){return getnum(getrank(u+1));}
signed main(){
ios::sync_with_stdio(false);
cin.tie(nullptr);
cin>>t;
while(t--){
int op,x;
cin>>op>>x;
if(op==1) ins(x);
else if(op==2) del(x);
else if(op==3) cout<<getrank(x)<<"\n";
else if(op==4) cout<<getnum(x)<<"\n";
else if(op==5) cout<<pre(x)<<"\n";
else if(op==6) cout<<nex(x)<<"\n";
}
return 0;
}