[笔记]Splay树
前置知识:树的左旋、右旋。
Splay树是一种平衡树。能够做到每个操作均摊\(O(\log N)\)。
前言
与上文AVL树不同之处在于,AVL树在任何操作结束后,都能保证每个节点的左右子树高度相差不超过\(1\)。相应地,每个操作都是严格的\(O(\log N)\)。而Splay树并没有对“平衡”的确切定义,任何结构的树都可能是Splay树(甚至可以是一条链)。相应地,单次操作的时间复杂度是无法计算的,只能知道每个操作的均摊时间复杂度为\(O(\log N)\)。
Splay树的核心思想在于“Splay操作”。对节点\(x\)进行Splay操作,即通过一系列“zig”、“zag”及其组合操作,将\(x\)挪到根节点的位置(同时对树的结构进行一些调整,使每次操作最坏复杂度更接近\(O(\log N)\))。
增、删、查、改……几乎每一个操作完成后都要跟着一个Splay操作,把操作的节点移到根节点。注意这样不一定能像AVL树那样,保证这棵树每个节点左右子树高度相差不超过\(1\),不过可以证明,这种修改可以让每个操作的均摊时间复杂度变为\(O(\log N)\)。
这样,Splay的优势显现了出来:越是最近操作过的节点,到根节点的距离就越近,访问就会越快。而在实际应用中,我们频繁访问到的,可能只是所有数据的一小部分(比如我们用的输入法)。
具体操作方式见下。
代码来自Splay 详细图解 & 轻量级代码实现 - 樱雪喵,十分感谢!
各种操作
准备
#define lc(x) tr[x].ch[0] #define rc(x) tr[x].ch[1] #define ch(x,y) tr[x].ch[y] #define fa(x) tr[x].fa #define v(x) tr[x].v #define siz(x) tr[x].siz struct tree{ int ch[2],siz,fa,v; void clear(){ch[0]=ch[1]=siz=fa=v=0;} }tr[N]; int t,cnt,root; void clear(int u){tr[u].clear();} int newnode(int v){v(++cnt)=v,siz(cnt)=1;return cnt;} void update(int u){siz(u)=siz(lc(u))+siz(rc(u))+1;}//更新状态 bool get(int u){return u==rc(fa(u));}//查询u是其父节点的左子结点还是右子节点
与AVL不同,Splay不需要记录树高。
旋转 & Splay操作
void rot(int x){ int y=fa(x),z=fa(y);//保证y!=0 bool dir=get(x),tdir=get(y); if(ch(x,!dir)) fa(ch(x,!dir))=y; ch(y,dir)=ch(x,!dir); ch(x,!dir)=y; fa(y)=x,fa(x)=z; if(z) ch(z,tdir)=x; update(y),update(x); }
旋转操作把左旋和右旋写在了一起,其含义也较AVL有改变。rot(x)
表示通过旋转将\(x\)提升一层。
接下来讲解算法的精髓——Splay操作。
Splay操作就是不断调用rot(x)
,将节点\(x\)移到根节点的位置。
这里设\(x\)的父节点为\(f\)(下图中是\(p\))。
但旋转方式根据树的结构也有所不同,一共分为\(6\)种:zig
、zig-zig
、zig-zag
、zag
、 zag-zag
、zag-zig
。由于后面\(3\)种与前面\(3\)种操作是对称的,所以就只说前\(3\)种了。
下面\(3\)张图片来自OI Wiki。
1 - zig/zag
zig
/zag
操作,只有Splay过程中的最后一次旋转才可能会使用。
条件就是\(x\)到当前根节点的距离正好是\(1\)。
这种情况下,调用一次rot(x)
即可。
2 - zig-zig/zag-zag
zig-zig
/zag-zag
操作,是在\(x\)到根节点距离超过\(1\),而且\(x,y\)都是左子结点/右子节点的情况下使用的。
这种情况下,需要先调用rot(p)
,再调用rot(x)
。
3 - zig-zag/zag-zig
zig-zag
/zag-zig
操作,是在\(x\)到根节点距离超过\(1\),而且\(x,y\)中\(1\)个是左子节点,\(1\)个是右子节点的情况下使用的。
这种情况下,需要连续调用\(2\)次rot(x)
。
为什么zig-zag
就是两次rot(x)
,而zig-zig
必须先调用rot(p)
再调用rot(x)
呢?
(当然,你可以发现在zig-zag
的情况下是做不到rot(p)
再rot(x)
的)
这是为了保证时间复杂度。当然我们也可以感性理解一下:
如你所见,这是一条链(没错,如果依次插入\(16,15,\dots ,1\)就会这样)。
现在我们想对节点\(16\)执行Splay操作,那显然每次都是zig-zig
了。
出现链不要担心,我们按照正确的操作(rot(p)
再rot(x)
)跑一遍:
我们发现,树的高度的数量级直接减少了一半。如果再对节点\(15\)进行一个Splay操作,树的高度会更加接近\(\log N\)。
但如果我们按错误的操作(两次rot(x)
)来跑:
还是。。。一条链。
代码可以写出来了。
void splay(int x){ for(int f;(f=fa(x));rot(x)) if(fa(f)) rot(get(f)==get(x)?f:x); root=x; }
把节点\(x\)通过Splay转移到顶部的过程中,\(x\)以及\(x\)的祖先节点都会通过update()
被更新。其他节点则不会调用到update()
。所以Splay操作后节点信息正确的前提是除了\(x\)和\(x\)的祖先节点,其他节点的信息都应该是正确的。
其他操作
和普通BST一样了(删除有些不同,是先把要删除的节点Splay到顶部,然后删掉它,把右子树连接到左子树的最大权值下面)。。。不过要注意每个操作最后都要加Splay。否则每次操作\(O(\log N)\)复杂度会假(想象一下如果你有个操作没加Splay,又正好碰到了上面那条链的样例,它又正好不断调用这个操作,那么这条链就一直得不到调整,每次时间复杂度就是\(O(n)\)了。。。)。
插入
void ins(int v){ int u=root,f=0; while(u) f=u,u=ch(u,v>v(u));//>在右,<=在左 u=newnode(v),fa(u)=f; if(f) ch(f,v>v(f))=u; splay(u); }
删除
void del(int v){ int u=root,f=0; while(v(u)!=v&&u) f=u,u=ch(u,v>v(u));//找到要修改的节点 if(!u){splay(f);return;}//不要忘记Splay! splay(u);//这里Splay到根是为了更方便操作,不用特判根了 int pre=lc(u); if(!pre){root=rc(u),fa(rc(u))=0,clear(u);return;} while(rc(pre)) pre=rc(pre); rc(pre)=rc(u),fa(rc(u))=pre,fa(lc(u))=0,clear(u); splay(pre);//这一次Splay可以解决祖先siz没更新的问题 }
查询\(v\)的排名
int getrank(int v){ int u=root,ran=1,f=0; while(u){ f=u; if(v>v(u)) ran+=siz(lc(u))+1,u=rc(u);//Right else u=lc(u);//Left } splay(f); return ran; }
查询排名为\(ran\)的值
int getnum(int ran){ int u=root; while(u){ int sz=siz(lc(u))+1; if(sz==ran) break; else if(sz>ran) u=lc(u); else ran-=sz,u=rc(u); } splay(u); return v(u); }
前驱
int pre(int u){return getnum(getrank(u)-1);}
后继
int nex(int u){return getnum(getrank(u+1));}
时间复杂度证明
先。。。咕掉吧
模板题 & Code
点击查看代码
#include<bits/stdc++.h> #define int long long #define N 100010 #define lc(x) tr[x].ch[0] #define rc(x) tr[x].ch[1] #define ch(x,y) tr[x].ch[y] #define fa(x) tr[x].fa #define v(x) tr[x].v #define siz(x) tr[x].siz using namespace std; struct tree{ int ch[2],siz,fa,v; void clear(){ch[0]=ch[1]=siz=fa=v=0;} }tr[N]; int t,cnt,root; void clear(int u){tr[u].clear();} int newnode(int v){v(++cnt)=v,siz(cnt)=1;return cnt;} void update(int u){siz(u)=siz(lc(u))+siz(rc(u))+1;} bool get(int u){return u==rc(fa(u));} void rot(int x){ int y=fa(x),z=fa(y);//保证y!=0 bool dir=get(x),tdir=get(y); if(ch(x,!dir)) fa(ch(x,!dir))=y; ch(y,dir)=ch(x,!dir); ch(x,!dir)=y; fa(y)=x,fa(x)=z; if(z) ch(z,tdir)=x; update(y),update(x); } void splay(int x){ for(int f;(f=fa(x));rot(x)) if(fa(f)) rot(get(f)==get(x)?f:x); root=x; } void ins(int v){ int u=root,f=0; while(u) f=u,u=ch(u,v>v(u));//>在右,<=在左 u=newnode(v),fa(u)=f; if(f) ch(f,v>v(f))=u; splay(u); } void del(int v){ int u=root,f=0; while(v(u)!=v&&u) f=u,u=ch(u,v>v(u)); if(!u){splay(f);return;} splay(u); int pre=lc(u); if(!pre){root=rc(u),fa(rc(u))=0,clear(u);return;} while(rc(pre)) pre=rc(pre); rc(pre)=rc(u),fa(rc(u))=pre,fa(lc(u))=0,clear(u); splay(pre); } int getrank(int v){ int u=root,ran=1,f=0; while(u){ f=u; if(v>v(u)) ran+=siz(lc(u))+1,u=rc(u);//Right else u=lc(u);//Left } splay(f); return ran; } int getnum(int ran){ int u=root; while(u){ int sz=siz(lc(u))+1; if(sz==ran) break; else if(sz>ran) u=lc(u); else ran-=sz,u=rc(u); } splay(u); return v(u); } int pre(int u){return getnum(getrank(u)-1);} int nex(int u){return getnum(getrank(u+1));} signed main(){ ios::sync_with_stdio(false); cin.tie(nullptr); cin>>t; while(t--){ int op,x; cin>>op>>x; if(op==1) ins(x); else if(op==2) del(x); else if(op==3) cout<<getrank(x)<<"\n"; else if(op==4) cout<<getnum(x)<<"\n"; else if(op==5) cout<<pre(x)<<"\n"; else if(op==6) cout<<nex(x)<<"\n"; } return 0; }
附:相同节点合并代码
也可以把值相同的节点合并为\(1\)个,用\(cnt\)记一下数。
然后newnode
、update
、ins
、del
、getrank
、getnum
函数需要做相应的修改。具体见代码。
#include<bits/stdc++.h> #define int long long #define N 100010 #define lc(x) tr[x].ch[0] #define rc(x) tr[x].ch[1] #define ch(x,y) tr[x].ch[y] #define fa(x) tr[x].fa #define v(x) tr[x].v #define siz(x) tr[x].siz #define cnt(x) tr[x].cnt using namespace std; struct tree{ int ch[2],siz,fa,v,cnt; void clear(){ch[0]=ch[1]=siz=fa=v=0;} }tr[N]; int t,cnt,root; void clear(int u){tr[u].clear();} int newnode(int v){v(++cnt)=v,siz(cnt)=1,cnt(cnt)=1;return cnt;} void update(int u){siz(u)=siz(lc(u))+siz(rc(u))+cnt(u);} bool get(int u){return u==rc(fa(u));} void rot(int x){ int y=fa(x),z=fa(y);//保证y!=0 bool dir=get(x),tdir=get(y); if(ch(x,!dir)) fa(ch(x,!dir))=y; ch(y,dir)=ch(x,!dir); ch(x,!dir)=y; fa(y)=x,fa(x)=z; if(z) ch(z,tdir)=x; update(y),update(x); } void splay(int x){ for(int f;(f=fa(x));rot(x)) if(fa(f)) rot(get(f)==get(x)?f:x); root=x; } void ins(int v){ int u=root,f=0; while(u&&v(u)!=v) f=u,u=ch(u,v>v(u));//>在右,<=在左 if(u) cnt(u)++; else{ u=newnode(v),fa(u)=f; if(f) ch(f,v>v(f))=u; } splay(u); } void del(int v){ int u=root,f=0; while(v(u)!=v&&u) f=u,u=ch(u,v>v(u)); if(!u){splay(f);return;} splay(u); if(cnt(u)>1){cnt(u)--;return;} int pre=lc(u); if(!pre){root=rc(u),fa(rc(u))=0,clear(u);return;} while(rc(pre)) pre=rc(pre); rc(pre)=rc(u),fa(rc(u))=pre,fa(lc(u))=0,clear(u); splay(pre); } int getrank(int v){ int u=root,ran=1,f=0; while(u){ f=u; if(v>v(u)) ran+=siz(lc(u))+cnt(u),u=rc(u);//Right else u=lc(u);//Left } splay(f); return ran; } int getnum(int ran){ int u=root; while(u){ int sz=siz(lc(u))+cnt(u); if(ran<=siz(lc(u))) u=lc(u); else if(ran>siz(lc(u))+cnt(u)) ran-=sz,u=rc(u); else break;//如果ran在[siz[l]+1,siz[l]+cnt[u]]的区间内,就说明第ran名就是u } splay(u); return v(u); } int pre(int u){return getnum(getrank(u)-1);} int nex(int u){return getnum(getrank(u+1));} signed main(){ ios::sync_with_stdio(false); cin.tie(nullptr); cin>>t; while(t--){ int op,x; cin>>op>>x; if(op==1) ins(x); else if(op==2) del(x); else if(op==3) cout<<getrank(x)<<"\n"; else if(op==4) cout<<getnum(x)<<"\n"; else if(op==5) cout<<pre(x)<<"\n"; else if(op==6) cout<<nex(x)<<"\n"; } return 0; }
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 单线程的Redis速度为什么快?
· 展开说说关于C#中ORM框架的用法!
· Pantheons:用 TypeScript 打造主流大模型对话的一站式集成库
· SQL Server 2025 AI相关能力初探
· 为什么 退出登录 或 修改密码 无法使 token 失效