学习笔记-平衡树
BST
平衡树是为了维护二叉搜索树的“平衡”而存在的
BST(二叉搜索树)的定义如下:
-
空树是二叉搜索树。
-
若二叉搜索树的左子树不为空,则其左子树上所有点的附加权值均小于其根节点的值。
-
若二叉搜索树的右子树不为空,则其右子树上所有点的附加权值均大于其根节点的值。
-
二叉搜索树的左右子树均为二叉搜索树。
可以发现,BST的基本操作(插入,删除,查找)的时间复杂度最优为 \(O(\log n)\) 最差为 \(O(n)\)(一条链)
为了降低时间复杂度,所以让树尽可能“平衡”,即层数差异尽量减小
Splay
Splay树是通过旋转操作和Splay操作维持树的平衡
旋转
旋转操作,即在不破坏BST本身性质的情况下使某节点上移一个位置
点击查看代码
//基础操作:更新节点siz值;判断是哪个儿子
void update(int x){
siz[x]=siz[lc(x)]+siz[rc(x)]+cnt[x];
}
bool get(int x){
return x==rc(fa[x]);
}
void rotate(int x){
int f=fa[x],gf=fa[f],son=get(x),son1=get(f);
ch[f][son]=ch[x][son^1];
if(ch[x][son^1]) fa[ch[x][son^1]]=f;
ch[x][son^1]=f;
fa[f]=x;fa[x]=gf;
if(gf) ch[gf][son1]=x;
update(x);update(f);
}
Splay
Splay操作,即把某一结点一直旋转到根节点,这一操作维护了树的平衡,同时可以在做完splay之后直接对根节点进行操作
为了使树平衡,旋转顺序有3种情况
点击查看代码
void splay(int x){
for(int i=fa[x];i=fa[x],i;rotate(x))
if(fa[i]) rotate(get(x)==get(i)?i:x);
// 三种情况:
// 第一种:父亲是根,直接旋转x即可
// 第二种:父亲,祖父和x共线,为避免破坏链结构,需要先转父亲再转x,可以起到对单链的折叠作用,使得单链变成第三种情况
// 第三种:父亲,祖父和x不共线,直接旋转即可,可以减少层数
rt=x;
}
查找
和BST一样
插入
如果没有根节点,直接插入,否则按查找的形式向下走,遇到等于的cnt加1,走到空的新增节点
点击查看代码
void ins(int val){
// 如果没有根节点,直接插入
if(!rt){
key[++tot]=val;
++cnt[tot];
rt=tot;
update(rt);
return;
}
// 否则按查找的形式向下走,遇到等于的cnt加1,走到空的新增节点
int pos=rt,pre=0;
while(true){
if(key[pos]==val){
cnt[pos]++;
update(pos);
update(fa[pos]);
splay(pos);
return;
}
pre=pos;
pos=ch[pos][key[pos]<val];
if(!pos){
key[++tot]=val;
++cnt[tot];
fa[tot]=pre;
ch[pre][key[pre]<val]=tot;
update(tot);
update(pre);
splay(tot);
return;
}
}
// cout<<"ins_out:"<<val<<endl;
}
查询k的排名
把小于k的子树大小加和
点击查看代码
int query_rk(int val){
int pos=rt,ans=0;
while(true){
if(val<key[pos]) pos=ch[pos][0];
else{
ans+=siz[ch[pos][0]];
if(val==key[pos]){
splay(pos);
return ans+1;
}
ans+=cnt[pos];
pos=ch[pos][1];
}
}
}
查询排名为k的值
和权值线段树类似
点击查看代码
int query_kth(int val){
// cout<<"query_kth:"<<val<<endl;
int pos=rt;
while(true){
if(ch[pos][0]&&val<=siz[ch[pos][0]]) pos=ch[pos][0];
else{
val-=cnt[pos]+siz[ch[pos][0]];
if(val<=0){
splay(pos);
return key[pos];
}
pos=ch[pos][1];
}
}
}
删除
分类讨论
点击查看代码
void del(int val){
// 先把x旋转到根的位置,如果cnt>1,cnt-1;否则,合并左右子树
query_rk(val);// 通过查询排名的操作找到这个节点并splay
if(cnt[rt]>1){
cnt[rt]--;
update(rt);
return;
}
if(!ch[rt][0]&&!ch[rt][1]){
clear(rt);
rt=0;
return;
}
if(!ch[rt][0]){
int tmp=rt;
rt=ch[rt][1];
fa[rt]=0;
clear(tmp);
return;
}
if(!ch[rt][1]){
int tmp=rt;
rt=ch[rt][0];
fa[rt]=0;
clear(tmp);
return;
}
int pos=rt,tmp=getpre();
fa[ch[pos][1]]=tmp;
ch[tmp][1]=ch[pos][1];
clear(pos);
update(rt);
// cout<<"del_out:"<<val<<endl;
}
查询前趋/后继
把要查询的值放在根上,然后操作
点击查看代码
int getpre(){
int pos=ch[rt][0];
if(!pos) return pos;
while(ch[pos][1]) pos=ch[pos][1];
splay(pos);
return pos;
}
int getnxt(){
int pos=ch[rt][1];
if(!pos) return pos;
while(ch[pos][0]) pos=ch[pos][0];
splay(pos);
return pos;
}
完整模板(P3369):
点击查看代码
#include<bits/stdc++.h>
#define N 100005
#define lc(x) ch[x][0]
#define rc(x) ch[x][1]
using namespace std;
int rt,tot,n;
int fa[N],ch[N][2],key[N],siz[N],cnt[N];
void update(int x);
bool get(int x);
void clear(int x);
void rotate(int x);
void splay(int x);
void ins(int val);
void del(int val);
int query_rk(int val);
int query_kth(int val);
int getpre();
int getnxt();
int main(){
cin>>n;
for(int i=1;i<=n;i++){
int opt,x;
scanf("%d%d",&opt,&x);
switch(opt){
case 1: ins(x);break;
case 2: del(x);break;
case 3: printf("%d\n",query_rk(x));break;
case 4: printf("%d\n",query_kth(x));break;
case 5: ins(x);printf("%d\n",key[getpre()]);del(x);break;
case 6: ins(x);printf("%d\n",key[getnxt()]);del(x);break;
}
}
}
//基础操作:更新节点siz值;判断是哪个儿子;销毁节点所有信息
void update(int x){
siz[x]=siz[lc(x)]+siz[rc(x)]+cnt[x];
}
bool get(int x){
return x==rc(fa[x]);
}
void clear(int x){
lc(x)=rc(x)=fa[x]=key[x]=siz[x]=cnt[x]=0;
}
//维持平衡用操作:旋转;splay操作
void rotate(int x){
// cout<<"rotate:"<<x<<endl;
int f=fa[x],gf=fa[f],son=get(x),son1=get(f);
ch[f][son]=ch[x][son^1];
if(ch[x][son^1]) fa[ch[x][son^1]]=f;
ch[x][son^1]=f;
fa[f]=x;fa[x]=gf;
if(gf) ch[gf][son1]=x;
update(x);update(f);
// 就是在保证二叉搜索树性质不变的前提下使一个节点向上移一层
}
void splay(int x){
// cout<<"splay:"<<x<<endl;
// 其实就是不断旋转某节点使其成为根节点并在途中通过旋转让途中的节点平衡
for(int i=fa[x];i=fa[x],i;rotate(x))
if(fa[i]) rotate(get(x)==get(i)?i:x);
// 三种情况:
// 第一种:父亲是根,直接旋转x即可
// 第二种:父亲,祖父和x共线,为避免破坏链结构,需要先转父亲再转x
// 第三种:父亲,祖父和x不共线,直接旋转即可,可以减少层数
rt=x;
// cout<<"splay_out:"<<x<<endl;
}
//查询操作:插入;删除;查询数的排名;查询第k个数;查前驱;查后缀
void ins(int val){
// 如果没有根节点,直接插入
if(!rt){
key[++tot]=val;
++cnt[tot];
rt=tot;
update(rt);
return;
}
// 否则按查找的形式向下走,遇到等于的cnt加1,走到空的新增节点
int pos=rt,pre=0;
while(true){
if(key[pos]==val){
cnt[pos]++;
update(pos);
update(fa[pos]);
splay(pos);
return;
}
pre=pos;
pos=ch[pos][key[pos]<val];
if(!pos){
key[++tot]=val;
++cnt[tot];
fa[tot]=pre;
ch[pre][key[pre]<val]=tot;
update(tot);
update(pre);
splay(tot);
return;
}
}
}
int query_rk(int val){
int pos=rt,ans=0;
while(true){
if(val<key[pos]) pos=ch[pos][0];
else{
ans+=siz[ch[pos][0]];
if(val==key[pos]){
splay(pos);
return ans+1;
}
ans+=cnt[pos];
pos=ch[pos][1];
}
}
}
int query_kth(int val){
int pos=rt;
while(true){
if(ch[pos][0]&&val<=siz[ch[pos][0]]) pos=ch[pos][0];
else{
val-=cnt[pos]+siz[ch[pos][0]];
if(val<=0){
splay(pos);
return key[pos];
}
pos=ch[pos][1];
}
}
}
int getpre(){
int pos=ch[rt][0];
if(!pos) return pos;
while(ch[pos][1]) pos=ch[pos][1];
splay(pos);
return pos;
}
int getnxt(){
int pos=ch[rt][1];
if(!pos) return pos;
while(ch[pos][0]) pos=ch[pos][0];
splay(pos);
return pos;
}
void del(int val){
// 先把x旋转到根的位置,如果cnt>1,cnt-1;否则,合并左右子树
query_rk(val);// 通过查询排名的操作找到这个节点并splay
if(cnt[rt]>1){
cnt[rt]--;
update(rt);
return;
}
if(!ch[rt][0]&&!ch[rt][1]){
clear(rt);
rt=0;
return;
}
if(!ch[rt][0]){
int tmp=rt;
rt=ch[rt][1];
fa[rt]=0;
clear(tmp);
return;
}
if(!ch[rt][1]){
int tmp=rt;
rt=ch[rt][0];
fa[rt]=0;
clear(tmp);
return;
}
int pos=rt,tmp=getpre();
fa[ch[pos][1]]=tmp;
ch[tmp][1]=ch[pos][1];
clear(pos);
update(rt);
}
无旋Treap
无旋treap是一种只有两种基本操作——分裂和合并的平衡树,但功能十分强大,方便维护区间问题,且好写
treap是一种通过随机来维护平衡的平衡树(分配一个随机的pri值给每个元素),它同时满足BST和堆的性质,搜索树层数的期望值是 \(\log n\) (感性理解)
split
分裂操作分为两种,一种为按值分裂,一种是按排名分裂
第一种适合处理数值更改问题,第二种适合处理区间问题
按值分裂:
就是一棵树分为两颗新的树
点击查看代码
void split(int pos,int &l,int &r,int val){
if(!pos){
l=0,r=0;
return;
// 如果分裂到底,返回
}
if(key[pos]<=val){
l=pos;
split(ch[pos][1],ch[pos][1],r,val);
// 如果当前权值小于分裂所需权值,所以右子树一定不在分裂后总的左子树中
// 传右子树的参数到下一个的l中,因为之后的左子树也有可能出现小于val的节点
}
else{
r=pos;
split(ch[pos][0],l,ch[pos][0],val);
// 与上面完全相反
}
update(pos);
}
按排名分裂:
与按值分裂相似,只不过换成了排名
点击查看代码
void split(int pos,int &l,int &r,int cnt){
if(!pos){
l=0,r=0;
return;
}
if(siz[ch[pos][0]]<cnt){
l=pos;
split(ch[pos][1],ch[pos][1],r,cnt-siz[ch[pos][0]]-1);
}
else{
r=pos;
split(ch[pos][0],l,ch[pos][0],cnt);
}
update(pos);
}
merge
点击查看代码
int merge(int l,int r){
if(!l||!r) return l+r;
if(pri[l]<pri[r]){
ch[l][1]=merge(ch[l][1],r);
update(l);return l;
// l的pri值小于r的,满足小根堆性质,为了满足平衡树性质,把r与l右子树合并
}
else{
ch[r][0]=merge(l,ch[r][0]);
update(r);return r;
// 与上面相反
}
}
例题
P1486 郁闷的出纳员
利用treap按值分裂把不满足条件的去除即可
点击查看代码
#include<bits/stdc++.h>
#define N 100005
using namespace std;
int n,minn,tot,root,leave;
int ch[N][2],key[N],pri[N],siz[N];
void update(int x);
void split(int pos,int &l,int &r,int val);
int merge(int l,int r);
void add(int val);
void ins(int val);
void del(int val);
int getkth(int val);
inline int read(){
int x=0,f=1;char ch=getchar();
while (ch<'0'||ch>'9'){if (ch=='-') f=-1;ch=getchar();}
while (ch>='0'&&ch<='9'){x=(x<<3)+(x<<1)+(ch^48);ch=getchar();}
return x*f;
}
int main(){
srand(time(0));
cin>>n>>minn;
getchar();
for(int i=1;i<=n;i++){
char opt=getchar();
int k=read();
if(opt=='I'&&k>=minn) ins(k);
if(opt=='A')
for(int i=1;i<=tot;i++)
key[i]+=k;
if(opt=='S')
for(int i=1;i<=tot;i++)
key[i]-=k;
if(opt=='F') printf("%d\n",(k<=siz[root]?getkth(siz[root]-k+1):-1));
int tmp1,tmp2;
leave+=siz[root];
split(root,tmp1,tmp2,minn-1);
root=tmp2;
leave-=siz[root];
}
cout<<leave;
}
void update(int x){
siz[x]=1+siz[ch[x][0]]+siz[ch[x][1]];
}
void split(int pos,int &l,int &r,int val){
if(!pos){
l=0,r=0;
return;
}
if(key[pos]<=val){
l=pos;
split(ch[pos][1],ch[pos][1],r,val);
}
else{
r=pos;
split(ch[pos][0],l,ch[pos][0],val);
}
update(pos);
}
int merge(int l,int r){
if(!l||!r) return l+r;
if(pri[l]<pri[r]){
ch[l][1]=merge(ch[l][1],r);
update(l);return l;
}
else{
ch[r][0]=merge(l,ch[r][0]);
update(r);return r;
}
}
void add(int val){
siz[++tot]=1;
key[tot]=val;
pri[tot]=rand();
}
void ins(int val){
add(val);
int x,y,z=tot;
split(root,x,y,val);
root=merge(merge(x,z),y);
}
int getkth(int val){
int pos=root;
while(true){
if(val<=siz[ch[pos][0]]) pos=ch[pos][0];
else{
val-=1+siz[ch[pos][0]];
if(val==0) return key[pos];
pos=ch[pos][1];
}
}
}
P3391 文艺平衡树
首先可以发现一个性质:对于一个BST来说,如果反转所有节点的左右子树,则它的中序遍历也将反转(感性理解:可以看作是BST所维护的大小关系颠倒,变为左子树比当前节点大,右子树比当前节点小,故所维护的中序遍历也倒过来)
在这道题中,每个节点在treap中的排名意义为在序列中的位置,所以区间操作直接把对应区间的树按排名分裂出来即可,然后按照类似线段树的操作打懒标记,来确保时间复杂度的正确性,在合并和分裂的时候做好pushdown即可,最后输出中序遍历即为序列
点击查看代码
#include<bits/stdc++.h>
#define N 100005
using namespace std;
int n,m,tot,rt;
int ch[N][2],key[N],pri[N],siz[N],tag[N];
void update(int x);
void pushdown(int x);
void split(int pos,int &l,int &r,int cnt);
int merge(int l,int r);
void add(int val);
void ins(int val);
void rever(int l,int r);
void print(int pos);
int main(){
srand((unsigned)time(0));
cin>>n>>m;
for(int i=1;i<=n;i++) ins(i);
for(int i=1;i<=m;i++){
int l,r;
scanf("%d%d",&l,&r);
rever(l,r);
}
print(rt);
}
void update(int x){
siz[x]=1+siz[ch[x][0]]+siz[ch[x][1]];
}
void pushdown(int x){
if(!tag[x]) return;
swap(ch[x][0],ch[x][1]);
tag[ch[x][0]]^=1;
tag[ch[x][1]]^=1;
tag[x]=0;
}
void split(int pos,int &l,int &r,int cnt){
if(!pos){
l=0,r=0;
return;
}
pushdown(pos);
if(siz[ch[pos][0]]<cnt){
l=pos;
split(ch[pos][1],ch[pos][1],r,cnt-siz[ch[pos][0]]-1);
}
else{
r=pos;
split(ch[pos][0],l,ch[pos][0],cnt);
}
update(pos);
}
int merge(int l,int r){
if(!l||!r) return l+r;
if(pri[l]<pri[r]){
pushdown(l);
ch[l][1]=merge(ch[l][1],r);
update(l);return l;
}
else{
pushdown(r);
ch[r][0]=merge(l,ch[r][0]);
update(r);return r;
}
}
void add(int val){
siz[++tot]=1;
key[tot]=val;
pri[tot]=rand()*rand();
}
void ins(int val){
add(val);
int x,y,z=tot;
split(rt,x,y,val);
rt=merge(merge(x,z),y);
}
void rever(int l,int r){
int x,y,z;
split(rt,x,y,l-1);
split(y,y,z,r-l+1);
tag[y]^=1;
merge(x,merge(y,z));
}
void print(int pos){
if(!pos) return;
pushdown(pos);
print(ch[pos][0]);
printf("%d ",key[pos]);
print(ch[pos][1]);
}
线段树套平衡树
对于区间上进行平衡树操作可以考虑将线段树的每个节点都变成平衡树维护
- 操作一:对 \([l,r]\) 有多少个比 \(x\) 小求和再加1
- 操作二:二分答案,然后通过操作一计算排名来进行判断
- 操作三:和线段树单点修改类似,把含有这一位置的节点的平衡树全部修改
- 操作四和五:可以通过操作二实现(还有一种实现方法为对区间的前趋/后继取最大/最小值)
点击查看代码
#include<bits/stdc++.h>
#define lc (pos<<1)
#define rc ((pos<<1)|1)
#define N 50005
#define mid ((l+r)>>1)
using namespace std;
int n,m,a[N];
int siz[N<<5],key[N<<5],pri[N<<5],ch[N<<5][2],fa[N<<5],tot;
struct Treap{
int rt;
void update(int x){
siz[x]=1+siz[ch[x][0]]+siz[ch[x][1]];
}
void split(int x,int &l,int &r,int val){
if(!x){
l=r=0;
return;
}
if(key[x]<=val){
l=x;
split(ch[x][1],ch[x][1],r,val);
}
else{
r=x;
split(ch[x][0],l,ch[x][0],val);
}
update(x);
}
int merge(int l,int r){
if(!l||!r) return l+r;
if(pri[l]<pri[r]){
ch[l][1]=merge(ch[l][1],r);
update(l);return l;
}
else{
ch[r][0]=merge(l,ch[r][0]);
update(r);return r;
}
}
void add(int val){
key[++tot]=val;
siz[tot]=1;
pri[tot]=rand()*rand();
}
void ins(int val){
add(val);int x,y;
split(rt,x,y,val);
rt=merge(merge(x,tot),y);
}
void del(int val){
int x,y,z;
split(rt,x,z,val);
split(x,x,y,val-1);
y=merge(ch[y][0],ch[y][1]);
rt=merge(merge(x,y),z);
}
int getrk(int val){
int x,y,ans;
split(rt,x,y,val-1);
ans=siz[x];
rt=merge(x,y);
return ans;
}
};
struct SMT{
Treap node[N<<2];
void build(int pos,int l,int r){
for(int i=l;i<=r;i++)
node[pos].ins(a[i]);
if(l==r) return;
build(lc,l,mid);
build(rc,mid+1,r);
}
void update(int pos,int l,int r,int x,int val){
node[pos].del(a[x]);
node[pos].ins(val);
if(l==r) return;
if(x<=mid) update(lc,l,mid,x,val);
else update(rc,mid+1,r,x,val);
}
int query_rk(int pos,int l,int r,int L,int R,int val){
if(r<L||R<l) return 0;
if(L<=l&&r<=R) return node[pos].getrk(val);
return query_rk(lc,l,mid,L,R,val)+query_rk(rc,mid+1,r,L,R,val);
}
int kth(int L,int R,int val){
int l=-1e8,r=1e8,ans;
while(l<=r){
if(query_rk(1,1,n,L,R,mid)+1>val){
ans=mid;
r=mid-1;
}
else l=mid+1;
}
return ans;
}
int getpre(int L,int R,int val){
return kth(L,R,query_rk(1,1,n,L,R,val));
}
int getnxt(int L,int R,int val){
return kth(L,R,query_rk(1,1,n,L,R,val+1)+1);
}
}T;
int main(){
srand(time(0));
cin>>n>>m;
for(int i=1;i<=n;i++)
scanf("%d",a+i);
T.build(1,1,n);
for(int i=1;i<=m;i++){
int opt;
scanf("%d",&opt);
int l,r,pos,k;
switch(opt){
case 1: scanf("%d%d%d",&l,&r,&k);printf("%d\n",T.query_rk(1,1,n,l,r,k)+1);break;
case 2: scanf("%d%d%d",&l,&r,&k);printf("%d\n",T.kth(l,r,k)-1);break;
case 3: scanf("%d%d",&pos,&k);T.update(1,1,n,pos,k);a[pos]=k;break;
case 4: scanf("%d%d%d",&l,&r,&k);printf("%d\n",T.getpre(l,r,k)-1);break;
case 5: scanf("%d%d%d",&l,&r,&k);printf("%d\n",T.getnxt(l,r,k)-1);break;
}
}
}