树套树浅谈
今天来说一下线段树套Splay。顺便我也来重新敲一遍模板。
首先,明确一下Splay套线段树用来处理什么问题。它可以支持:插入x,删除x,单点修改,查询x在区间[l,r]的排名,查询区间[l,r]中排名为k的数,以及一个数在区间[l,r]中的前驱,后继。(应该还可以查询区间和等东西,还没写过)
其实它的常数非常大,但是这是树套树中最容易理解的一种。
首先,我们知道,对于一个区间,我们给它建一棵线段树,它的每个节点维护的是区间[l,r]的信息。所以,我们对于每一个节点都建一棵Splay.(听起来就很暴力啊……)
不知道为啥我的指针版Splay套炸了,这里用普通版吧。
首先,我们先来完成最基本的Splay操作。
#include<cstdio> #include<iostream> #include<cstring> #include<string> #define inf (2147483647) using namespace std; const int MAXN=4e6+2; inline void IN(int &x){ int s=0,w=1; char ch=getchar(); while(ch<'0'||ch>'9'){ if(ch=='-')w=-1; ch=getchar(); }while(ch>=48&&ch<='9'){ s=(s<<1)+(s<<3)+(ch^48); ch=getchar(); }x=s*w; } int n,m,a[MAXN],ans,MX; /*-------------------------Splay-------------------------------------*/ int f[MAXN],c[MAXN],s[MAXN],v[MAXN],ch[MAXN][2],rt[MAXN],tot; inline void Splay_del(int x){f[x]=s[x]=c[x]=v[x]=ch[x][1]=ch[x][0]=0;} inline void Splay_pushup(int x){s[x]=(ch[x][0]?s[ch[x][0]]:0)+(ch[x][1]?s[ch[x][1]]:0)+c[x];} inline void Splay_rotate(int x){ int y=f[x],z=f[y],k=ch[y][1]==x,v=ch[x][k^1]; ch[y][k]=v;if(v)f[v]=y;f[x]=z;if(z)ch[z][ch[z][1]==y]=x; f[y]=x;ch[x][k^1]=y;Splay_pushup(y),Splay_pushup(x); }
pushup、rotate和del我就不解释了,根据代码理解吧。
然后我们考虑Splay操作的写法。
实际上并没有什么区别,因为是多颗Splay,所以我们要多传一个当前Splay所在区间的编号,对应地,rt[i]就是原来Splay的rt。
代码:
inline void Splay(int i,int x,int top=0){ while(f[x]!=top){ int y=f[x],z=f[y]; if(z!=top)(ch[z][1]==y)^(ch[y][1]==x)?Splay_rotate(y):Splay_rotate(x); Splay_rotate(x); }if(!top)rt[i]=x; }
接下来,来讲插入。
对于插入,还是一样的流程,往下找,如果是空树,新建节点,赋值;找到值一样的,c[x]++,pushup;没有的话,新建即可。
注意这里也要传区间编号。
inline void Splay_Insert(int i,int x){ int pos=rt[i]; if(!rt[i]){ rt[i]=pos=++tot;v[tot]=x;s[pos]=c[pos]=1; f[pos]=ch[pos][0]=ch[pos][1]=0;return; }int last=0; while(727){ if(v[pos]==x){++c[pos];Splay_pushup(last);break;} last=pos;pos=ch[pos][x>v[pos]]; if(!pos){ pos=++tot;v[pos]=x;s[pos]=c[pos]=1; ch[last][x>v[last]]=pos; f[pos]=last;ch[pos][0]=ch[pos][1]=0; Splay_pushup(last);break; } }Splay(i,pos);return; }
727是生日不要在意……
然后是求某个值在区间[l,r]里的rank.
这个东西,我们把它放到线段树里面理解:如果当前x在区间[l,r]中的一段区间[L,R]中排名为k,在区间[L1,R1]中排名为t,并且这两个区间可以合并成大区间[l,r]的话,那么,当前值x在大区间[l,r]的排名就是k+t.
所以,一个Splay_rank支持查询当前rank,一个Seg函数查询总区间的合并之后的rank。
inline int Splay_rank(int i,int k){ int x=rt[i],cal=0; while(x){ if(v[x]==k)return cal+((ch[x][0])?s[ch[x][0]]:0); else if(v[x]<k){ cal+=((ch[x][0])?s[ch[x][0]]:0)+c[x];x=ch[x][1]; }else x=ch[x][0]; }return cal; } inline void Seg_rank(int x,int l,int r,int L,int R,int Kth){ if(l==L&&r==R){ans+=Splay_rank(x,Kth);return;} if(R<=mid)Seg_rank(lc,l,mid,L,R,Kth); else if(mid<L)Seg_rank(rc,mid+1,r,L,R,Kth); else Seg_rank(lc,l,mid,L,mid,Kth),Seg_rank(rc,mid+1,r,mid+1,R,Kth); }
注意,查询排名的时候,有三种情况,一个是全在左区间,一个是全在右区间,还有横跨两个区间的,不要忘记区分。
Splay找前驱后继就不说了吧……
inline int Splay_Get_pre(int i,int x){ int pos=rt[i];while(pos){ if(v[pos]<x){if(ans<v[pos])ans=v[pos];pos=ch[pos][1];} else pos=ch[pos][0]; }return ans; } inline int Splay_Get_suc(int i,int x){ int pos=rt[i];while(pos){ if(v[pos]>x){if(ans>v[pos])ans=v[pos];pos=ch[pos][0];} else pos=ch[pos][1]; }return ans; }
然后是删除。
先上代码:
inline int Splay_find(int i,int x){ int pos=rt[i];while(x){ if(v[pos]==x){Splay(i,pos);return pos;} pos=ch[pos][x>v[pos]]; }return 0; } inline int Splay_pre(int i){int x=ch[rt[i]][0];while(ch[x][1])x=ch[x][1];return x;} inline int Splay_suc(int i){int x=ch[rt[i]][1];while(ch[x][0])x=ch[x][0];return x;} inline void Splay_Delete(int i,int key){ int x=Splay_find(i,key); if(c[x]>1){--c[x];Splay_pushup(x);return;} if(!ch[x][0]&&!ch[x][1]){Splay_del(rt[i]);rt[i]=0;return;} if(!ch[x][0]){int y=ch[x][1];rt[i]=y;f[y]=0;return;} if(!ch[x][1]){int y=ch[x][0];rt[i]=y;f[y]=0;return;} int p=Splay_pre(i);int lastrt=rt[i]; Splay(i,p,0);ch[rt[i]][1]=ch[lastrt][1];f[ch[lastrt][1]]=rt[i]; Splay_del(lastrt);Splay_pushup(rt[i]); }
先找点,find函数,找完之后判断它是否还存在,如果存在,分类讨论:
左右孩子都没有,直接删掉。
只有右孩子,连接x的右孩子成为根的右孩子
只有左孩子,同理。
如果两个都有,则:
找到区间i的严格小于key的最大数,并记录根。
把这个数旋转到根,然后把之前记录的根的孩子和当前根的孩子连起来,认个爹,然后把之前根删掉,再pushup即可。
下面该线段树的查询操作了。
改天再补一发线段树的入门讲解……
首先是线段树的节点插入。
我们对每一个线段树的节点都插入v,注意是插入进它的Splay里面。
#define lc ((x)<<1) #define rc ((x)<<1|1) #define mid ((l+r)>>1) inline void Seg_Insert(int x,int l,int r,int pos,int val){ Splay_Insert(x,val);if(l==r)return; if(pos<=mid)Seg_Insert(lc,l,mid,pos,val); else Seg_Insert(rc,mid+1,r,pos,val); }
Easily!
继续。
来看看单点修改怎样维护。
注意此处,是要先把要改的点删掉在把修改的点加进来。
并且同时更新Splay.
inline void Seg_change(int x,int l,int r,int pos,int val){ Splay_Delete(x,a[pos]);Splay_Insert(x,val); if(l==r){a[pos]=val;return;} if(pos<=mid)Seg_change(lc,l,mid,pos,val); else Seg_change(rc,mid+1,r,pos,val); }
然后是查询区间[l,r]中v的前驱,后继。
对于每一个包含的区间中,都求一遍当前Splay中v的前驱后继,每次统计max或min即可,注意处理区间问题,三种情况分类讨论,见上。
inline void Seg_pre(int x,int l,int r,int L,int R,int val){ if(l==L&&r==R){ans=max(ans,Splay_Get_pre(x,val));return;} if(R<=mid)Seg_pre(lc,l,mid,L,R,val); else if(mid<L)Seg_pre(rc,mid+1,r,L,R,val); else Seg_pre(lc,l,mid,L,mid,val),Seg_pre(rc,mid+1,r,mid+1,R,val); } inline void Seg_suc(int x,int l,int r,int L,int R,int val){ if(l==L&&r==R){ans=min(ans,Splay_Get_suc(x,val));return;} if(R<=mid)Seg_suc(lc,l,mid,L,R,val); else if(mid<L)Seg_suc(rc,mid+1,r,L,R,val); else Seg_suc(lc,l,mid,L,mid,val),Seg_suc(rc,mid+1,r,mid+1,R,val); }
最后,查询区间排名为k的数。
此处,我们要用二分。
输入时,记录区间最大值。
L=0,R=MX+1。
二分。
每一次用ans记录当前二分值(M=(L+R)>>1)的排名。
如果ans<k,L=M,else R=mid+1.
最后返回的是L-1。
/*----------------ask-------------*/ inline int Get_Kth(int x,int y,int k){ int L=0,R=MX+1,M; while(L<R){ M=(L+R)>>1; ans=0;Seg_rank(1,1,n,x,y,M); if(ans<k)L=M+1;else R=M; }return L-1; }
最后上完整代码。
#include<cstdio> #include<iostream> #include<cstring> #include<string> #define inf (2147483647) using namespace std; const int MAXN=4e6+2; inline void IN(int &x){ int s=0,w=1; char ch=getchar(); while(ch<'0'||ch>'9'){ if(ch=='-')w=-1; ch=getchar(); }while(ch>=48&&ch<='9'){ s=(s<<1)+(s<<3)+(ch^48); ch=getchar(); }x=s*w; } int n,m,a[MAXN],ans,MX; /*-------------------------Splay-------------------------------------*/ int f[MAXN],c[MAXN],s[MAXN],v[MAXN],ch[MAXN][2],rt[MAXN],tot; inline void Splay_del(int x){f[x]=s[x]=c[x]=v[x]=ch[x][1]=ch[x][0]=0;} inline void Splay_pushup(int x){s[x]=(ch[x][0]?s[ch[x][0]]:0)+(ch[x][1]?s[ch[x][1]]:0)+c[x];} inline void Splay_rotate(int x){ int y=f[x],z=f[y],k=ch[y][1]==x,v=ch[x][k^1]; ch[y][k]=v;if(v)f[v]=y;f[x]=z;if(z)ch[z][ch[z][1]==y]=x; f[y]=x;ch[x][k^1]=y;Splay_pushup(y),Splay_pushup(x); } inline void Splay(int i,int x,int top=0){ while(f[x]!=top){ int y=f[x],z=f[y]; if(z!=top)(ch[z][1]==y)^(ch[y][1]==x)?Splay_rotate(y):Splay_rotate(x); Splay_rotate(x); }if(!top)rt[i]=x; } inline void Splay_Insert(int i,int x){ int pos=rt[i]; if(!rt[i]){ rt[i]=pos=++tot;v[tot]=x;s[pos]=c[pos]=1; f[pos]=ch[pos][0]=ch[pos][1]=0;return; }int last=0; while(727){ if(v[pos]==x){++c[pos];Splay_pushup(last);break;} last=pos;pos=ch[pos][x>v[pos]]; if(!pos){ pos=++tot;v[pos]=x;s[pos]=c[pos]=1; ch[last][x>v[last]]=pos; f[pos]=last;ch[pos][0]=ch[pos][1]=0; Splay_pushup(last);break; } }Splay(i,pos);return; } inline int Splay_rank(int i,int k){ int x=rt[i],cal=0; while(x){ if(v[x]==k)return cal+((ch[x][0])?s[ch[x][0]]:0); else if(v[x]<k){ cal+=((ch[x][0])?s[ch[x][0]]:0)+c[x];x=ch[x][1]; }else x=ch[x][0]; }return cal; } inline int Splay_find(int i,int x){ int pos=rt[i];while(x){ if(v[pos]==x){Splay(i,pos);return pos;} pos=ch[pos][x>v[pos]]; }return 0; } inline int Splay_pre(int i){int x=ch[rt[i]][0];while(ch[x][1])x=ch[x][1];return x;} inline int Splay_suc(int i){int x=ch[rt[i]][1];while(ch[x][0])x=ch[x][0];return x;} inline int Splay_Get_pre(int i,int x){ int pos=rt[i];while(pos){ if(v[pos]<x){if(ans<v[pos])ans=v[pos];pos=ch[pos][1];} else pos=ch[pos][0]; }return ans; } inline int Splay_Get_suc(int i,int x){ int pos=rt[i];while(pos){ if(v[pos]>x){if(ans>v[pos])ans=v[pos];pos=ch[pos][0];} else pos=ch[pos][1]; }return ans; } inline void Splay_Delete(int i,int key){ int x=Splay_find(i,key); if(c[x]>1){--c[x];Splay_pushup(x);return;} if(!ch[x][0]&&!ch[x][1]){Splay_del(rt[i]);rt[i]=0;return;} if(!ch[x][0]){int y=ch[x][1];rt[i]=y;f[y]=0;return;} if(!ch[x][1]){int y=ch[x][0];rt[i]=y;f[y]=0;return;} int p=Splay_pre(i);int lastrt=rt[i]; Splay(i,p,0);ch[rt[i]][1]=ch[lastrt][1];f[ch[lastrt][1]]=rt[i]; Splay_del(lastrt);Splay_pushup(rt[i]); } /*----------------------Seg_Tree---------------------------*/ #define lc ((x)<<1) #define rc ((x)<<1|1) #define mid ((l+r)>>1) inline void Seg_Insert(int x,int l,int r,int pos,int val){ Splay_Insert(x,val);if(l==r)return; if(pos<=mid)Seg_Insert(lc,l,mid,pos,val); else Seg_Insert(rc,mid+1,r,pos,val); } inline void Seg_rank(int x,int l,int r,int L,int R,int Kth){ if(l==L&&r==R){ans+=Splay_rank(x,Kth);return;} if(R<=mid)Seg_rank(lc,l,mid,L,R,Kth); else if(mid<L)Seg_rank(rc,mid+1,r,L,R,Kth); else Seg_rank(lc,l,mid,L,mid,Kth),Seg_rank(rc,mid+1,r,mid+1,R,Kth); } inline void Seg_change(int x,int l,int r,int pos,int val){ Splay_Delete(x,a[pos]);Splay_Insert(x,val); if(l==r){a[pos]=val;return;} if(pos<=mid)Seg_change(lc,l,mid,pos,val); else Seg_change(rc,mid+1,r,pos,val); } inline void Seg_pre(int x,int l,int r,int L,int R,int val){ if(l==L&&r==R){ans=max(ans,Splay_Get_pre(x,val));return;} if(R<=mid)Seg_pre(lc,l,mid,L,R,val); else if(mid<L)Seg_pre(rc,mid+1,r,L,R,val); else Seg_pre(lc,l,mid,L,mid,val),Seg_pre(rc,mid+1,r,mid+1,R,val); } inline void Seg_suc(int x,int l,int r,int L,int R,int val){ if(l==L&&r==R){ans=min(ans,Splay_Get_suc(x,val));return;} if(R<=mid)Seg_suc(lc,l,mid,L,R,val); else if(mid<L)Seg_suc(rc,mid+1,r,L,R,val); else Seg_suc(lc,l,mid,L,mid,val),Seg_suc(rc,mid+1,r,mid+1,R,val); } /*----------------ask-------------*/ inline int Get_Kth(int x,int y,int k){ int L=0,R=MX+1,M; while(L<R){ M=(L+R)>>1; ans=0;Seg_rank(1,1,n,x,y,M); if(ans<k)L=M+1;else R=M; }return L-1; } /*-----------------Main--------------------*/ int main(){ IN(n),IN(m); for(register int i=1;i<=n;++i){IN(a[i]);Seg_Insert(1,1,n,i,a[i]);MX=max(MX,a[i]);} while(m--){ int op,x,y,v;IN(op),IN(x),IN(y); switch(op){ case 1:{IN(v);ans=0;Seg_rank(1,1,n,x,y,v);printf("%d\n",ans+1);}break; case 2:{IN(v);printf("%d\n",Get_Kth(x,y,v));}break; case 3:{Seg_change(1,1,n,x,y);}break; case 4:{IN(v);ans=-inf;Seg_pre(1,1,n,x,y,v);printf("%d\n",ans);}break; case 5:{IN(v);ans=inf;Seg_suc(1,1,n,x,y,v);printf("%d\n",ans);}break; } }return 0; }
对于习题可以去洛谷的试炼场里找。我这个蒟蒻还是要多练啊……