普通平衡树学习笔记之Splay算法
前言
今天不容易有一天的自由学习时间,当然要用来“学习”。在此记录一下今天学到的最基础的平衡树。
定义
平衡树是二叉搜索树和堆合并构成的数据结构,它是一 棵空树或它的左右两个子树的高度差的绝对值不超过1,并且左右两个子树都是一棵平衡二叉树。
这里仅仅说明一下平衡树中的\(Splay\)算法
进入正题
平衡树中有许多种类:红黑树、\(AVL\)树,伸展树(\(Splay\)),\(Treap\)等等,但是\(Splay\)算法算是可用性很强的一种了。也就是说比较稳定。
在\(Splay\)算法中,一个处处都要用到的东西就是旋转,即将当前节点与其前边一个节点依次旋转到目标位置。由于这个树是一个二叉搜索树,所以旋转之后要保证性质不变。我们就需要找到当前节点的父亲和爷爷节点,然后先更新爷爷节点与当前节点之间的关系,然后将父亲节点与该节点所属关系的另一个子树连起来,最后再处理一下该节点和父亲节点的关系。
我们举个例子(毕竟不太好理解)我们设这三个点分别为\(x,y,z\),从左往右分别是后一个的儿子,我们先把图画一下(\(x\)的子节点我随便用一堆东西表示):
在这里\(x\)是\(y\)的左节点。那么我们下一步就是把\(x\)变为\(z\)的子节点,也就是把\(y\)换下来。这一步很简单,变成了下边这样:
然后就是关键了,因为把\(x\)旋转上来,\(x\)也是有子节点的,所以我们需要进一步处理。首先记录一下\(x\)是\(y\)节点的左还是右儿子,在这个例子里是左儿子,因为我们不能破坏这个树的顺序和性质,所以就需要让\(y\)的右儿子不变且所有满足性质的比\(x\)小的他的左儿子还是\(x\)的左儿子,而\(y\)的左儿子变成\(x\)的右儿子,也就是下边这个图:(刚刚没有\(y\)的右节点,现在加上,编号与之前的不同,自己理解理解\(qwq\)):
类似的一个图,这样就完成了一次旋转。最后不样忘记也是需要\(pushup\)的,来维护点的儿子数量和自己的数量。
下边放一下\(Splay\)和旋转的代码.
void rotate(int x){//旋转
int y=t[x].fa;
int z=t[y].fa;
int k=t[y].ch[1]==x;//找到x是y的左还是右节点,便于进行上文中所说操作
t[z].ch[t[z].ch[1]==y]=x;//以下都是上边说过的操作。
t[x].fa=z;
t[y].ch[k]=t[x].ch[k^1];
t[t[x].ch[k^1]].fa=y;
t[x].ch[k^1]=y;
t[y].fa=x;
pushup(y);
pushup(x);
}
void splay(int x,int goal){//旋转
while(t[x].fa!=goal){//父亲节点不是目标
int y=t[x].fa;
int z=t[y].fa;
if(z!=goal){//爷爷节点也不是目标
(t[y].ch[0]==x)^(t[z].ch[0]==y)?rotate(x):rotate(y);//y左儿子是x就翻转x,不然翻转y(因为树有序,不能破坏性质)
}//上边这句话还需要细细钻研,本题解以后还会完善
rotate(x);
}
if(goal == 0){//到了根节点的话让根节点更新
root = x;
}
}
一些操作
上边把\(Splay\)的操作说了一遍,也就是关键的旋转应该比较清晰了,下边我们就需要进行一些操作了。
平衡树有许多可以进行的操作:删除,插入,查询一个值\(x\)的排名,查询\(x\)排名的值,还有值\(x\)的前驱和后继。(前驱定义为小于 \(x\),且最大的数,后继定义为大于 \(x\)且最小的数)。这些都需要上边的旋转,所以旋转明白了,这些也就比较容易了。首先是结构体的定义:
struct Node{
int ch[2];//子节点
int fa;//父节点
int cnt;//数量
int val;//值
int son;//儿子数量
}t[maxn];
1、插入
我们肯定要从根节点开始向下找到符合这个值的点,或者新建一个点,那么我们首先另\(u\)为根节点,其父亲为\(0\),然后如果有根且插入的值不是当前节点的值,那么我们就需要向子节点扩展,这里我的扩展比较巧妙,因为儿子只有左右分别用\(0\)、\(1\)表示,所以直接用判断来找到底是左还是右,也就是这样的式子:\(x>t[u].val\)。如果大于的话,就是右儿子。否则就是左儿子。所以直接更新到当前节点的儿子节点。
当跳出向下扩展的循环,就说明当前点值就是要插入的值,我们只需要将大小加一就行。如果没有这样的节点,那么就新建一个节点,然后将父亲的儿子节点置为当前节点,而左右儿子的判断如上文所说。其余的东西都是初始化,具体看代码,最后不要忘了再将当前节点\(Splay\)到根(几乎每种操作都需要\(Splay\),查询前驱后继和排名的值不用):
void Add(int x){
int u=root,fa=0;
while(u&&t[u].val!=x){
fa=u;
u=t[u].ch[x>t[u].val];
}
if(u){//有的话直接大小加一
t[u].cnt++;
}
else{//没有该值的节点
u=++tot;
if(fa)t[fa].ch[x>t[fa].val]=u;
t[tot].ch[1]=0;
t[tot].ch[0]=0;
t[tot].val=x;
t[tot].fa=fa;
t[tot].cnt=1;
t[tot].son=1;
}
splay(u,0);
}
2、查找值为x的排名:
根据这个判断\(x>t[u].val\)依次找\(x\)的位置,最后\(Splay\)一下就好了:
void Find(int x){
int u=root;//从根开始
if(!u)return;//没有树就直接跳出
while(t[u].ch[x>t[u].val] && x!=t[u].val){//依次向下找到当前值的点
u=t[u].ch[x>t[u].val];//更新
}
splay(u,0);//旋转
}
这个到最后把这个位置\(Splay\)到了根,所以答案就是当前\(Find\)之后根的左儿子的儿子数,注意如果根节点的值小于\(x\),要加上根节点的数量(\(diss\ YYB\))。
3、求前驱后继(这个操作求出来的是节点编号)
我们首先需要找到排名,也就是操作\(2\),然后\(Splay\)到根节点,如果值大于当前值且查找的是后继或者小于当前且找前驱就直接返回,否则就向子节点转移。找到转移后最接近当前值的点,也就是说,如果第一次不满足,假如找前驱就向左走一个,然后找到左儿子的右节点的最下边的点,也就是最接近这个查找的值的点。
int Fr_last(int x,int flag){//前驱flag为0,后继为1
Find(x);//找到位置
int u=root;//根开始
if((t[u].val>x&&flag) || (t[u].val<x && !flag)){//当前点满足就直接返回
return u;
}
u=t[u].ch[flag];//向目标点(左或右)转移
while(t[u].ch[flag^1])u=t[u].ch[flag^1];//找到转移后最接近当前值的点
return u;//返回
}
4、查找排名为x的值
首先从根节点开始,如果一共都没有\(x\)个数,那么就直接返回\(0\),不然的话就分别记录一下当前点的左右节点,然后判断,如果当前点的子节点树加上当前点的值的数量小于查找的排名,直接减去然后走到右儿子,不然就走到左儿子就行了。
int Find_thval(int x){
int u=root;//根开始
if(t[u].son<x){//如果没有这么多,直接返回0
return 0;
}
while(666666){//一直循环
int y=t[u].ch[0];//记录左儿子
if(x>t[y].son+t[u].cnt){//排名大就减去,然后走到右节点
x-=t[y].son+t[u].cnt;
u=t[u].ch[1];
}
else{
if(x<=t[y].son){//否则走到左节点
u=y;
}
else return t[u].val;//如果排名比上边的小,且比左节点的值大,这就是满足的价值,直接返回
}
}
}
我的这个代码有一些等号的取舍不同,所以在查找的时候传递参数需要加上一。
5、删除
我们需要首先找出这个点的前驱和后继,然后旋转下去,要删除的就是后继的左儿子,假如这个点的数量大于\(1\),就直接数量减一就好了,然后翻转到根节点,如果小于等于\(1\),那么就把这个点变成\(0\),结束!
void Delete(int x){
int Front=Fr_last(x,0);//前驱
int Last=Fr_last(x,1);//后继
splay(Front,0);//旋转
splay(Last,Front);
int del=t[Last].ch[0];//找到需要删除的点
if(t[del].cnt>1){//大于1直接减
t[del].cnt--;
splay(del,0);
}
else{//否则直接删除
t[Last].ch[0]=0;
}
}
总结
以上就是\(Splay\)的一些实现和操作,以后博客还会进行修改和完善,这些只是暂时自学时的理解,如果有神犇能给蒟蒻一些指导那就更好了。
完结撒花\(qwqq\)。
下边推荐一个板子题普通平衡树板子
板子题代码:
细节的注释上边都写过了,祝愿大家学习愉快\(qwq\)。
#include<bits/stdc++.h>
using namespace std;
const int maxn = 5e5+10;
const int Inf = 2147483647;
struct Node{
int son,ch[2],fa,cnt,val;
}t[maxn];
int n,tot,root;
void pushup(int x){
t[x].son = t[t[x].ch[0]].son+t[t[x].ch[1]].son+t[x].cnt;
}
void rotate(int x){
int y=t[x].fa;
int z=t[y].fa;
int k=t[y].ch[1]==x;
t[z].ch[t[z].ch[1]==y]=x;
t[x].fa=z;
t[y].ch[k]=t[x].ch[k^1];
t[t[x].ch[k^1]].fa=y;
t[x].ch[k^1]=y;
t[y].fa=x;
pushup(y);
pushup(x);
}
void splay(int x,int goal){
while(t[x].fa!=goal){
int y=t[x].fa;
int z=t[y].fa;
if(z!=goal){
(t[y].ch[0]==x)^(t[z].ch[0]==y)?rotate(x):rotate(y);
}
rotate(x);
}
if(goal == 0){
root = x;
}
}
void Find(int x){
int u=root;
if(!u)return;
while(t[u].ch[x>t[u].val] && x!=t[u].val){
u=t[u].ch[x>t[u].val];
}
splay(u,0);
}
void Add(int x){
int u=root,fa=0;
while(u&&t[u].val!=x){
fa=u;
u=t[u].ch[x>t[u].val];
}
if(u){
t[u].cnt++;
}
else{
u=++tot;
if(fa)t[fa].ch[x>t[fa].val]=u;
t[tot].ch[1]=0;
t[tot].ch[0]=0;
t[tot].val=x;
t[tot].fa=fa;
t[tot].cnt=1;
t[tot].son=1;
}
splay(u,0);
}
int Fr_last(int x,int flag){
Find(x);
int u=root;
if((t[u].val>x&&flag) || (t[u].val<x && !flag)){
return u;
}
u=t[u].ch[flag];
while(t[u].ch[flag^1])u=t[u].ch[flag^1];
return u;
}
void Delete(int x){
int Front=Fr_last(x,0);
int Last=Fr_last(x,1);
splay(Front,0);
splay(Last,Front);
int del=t[Last].ch[0];
if(t[del].cnt>1){
t[del].cnt--;
splay(del,0);
}
else{
t[Last].ch[0]=0;
}
}
int Find_thval(int x){
int u=root;
if(t[u].son<x){
return 0;
}
while(666666){
int y=t[u].ch[0];
if(x>t[y].son+t[u].cnt){
x-=t[y].son+t[u].cnt;
u=t[u].ch[1];
}
else{
if(x<=t[y].son){
u=y;
}
else return t[u].val;
}
}
}
int main(){
int n;
Add(Inf);
Add(-Inf);
scanf("%d",&n);
while(n--){
int opt;
scanf("%d",&opt);
int x;
if(opt == 1){
scanf("%d",&x);
Add(x);
}
if(opt == 2){
scanf("%d",&x);
Delete(x);
}
if(opt == 3){
scanf("%d",&x);
Find(x);
int ans = t[t[root].ch[0]].son + (t[root].val < x ? t[root].cnt : 0);
printf("%d\n",ans);
}
if(opt == 4){
int ans;
scanf("%d",&x);
ans = Find_thval(x+1);
printf("%d\n",ans);
}
if(opt == 5){
scanf("%d",&x);
int ans = Fr_last(x,0);
printf("%d\n",t[ans].val);
}
if(opt == 6){
scanf("%d",&x);
int ans = Fr_last(x,1);
printf("%d\n",t[ans].val);
}
}
return 0;
}