平衡树之Splay
首先,splay是一种通过大力证明可证的严格nlogn的数据结构(由于我现在不会所以现在不证)
spaly是一种BST(二叉搜索树),因此它具有BST的所有性质(最主要的就是中序遍历表示节点从小到大);
为了能更好的理解splay,一定要先学BST,否则会不清楚splay的一些由BST得到的性质;
平衡树,首先它是一棵二叉树,因此对于每个节点i,要记录i的左儿子ch[0],和右儿子ch[1];
为了方便splay的操作,我们还应该记录它的父亲节点;
然后可以在每个节点上根据题意维护一些乱七八糟的东西:
接下来就用普通平衡树这道题当作spaly的例题;
第x个点维护sum[x]表示x子树中所有值出现的次数,lazy[x]表示x这个点表示的值出现的次数,v[x]表示编号为x的点所代表的权值是多少;
class node{ public: int v; int father; int ch[3]; int sum,lazy; }BST[100010];
然后介绍splay的函数;
1.就像线段树那样,我们需要一个函数updata(),通过这个点的儿子的信息来更新这个点的信息;
void updata(int x){BST[x].sum=(BST[BST[x].ch[0]].sum+BST[BST[x].ch[1]].sum+BST[x].lazy);}
updata()看不懂还是快回去学线段树吧;
2.为了操作的方便(以后再说),我们需要一个函数identify(),确定一个点和它父亲的父子关系(左儿子还是右儿子);
int identify(int x){return BST[BST[x].father].ch[0]==x?0:1;}
identify()看不懂去学三目运算符吧;
3.为了操作的方便,我们需要一个函数connect(),使得点x和点y产生某种父子关系;
void connect(int x,int fa,int judge){BST[x].father=fa;BST[fa].ch[judge]=x;}
其中judge表示x是fa的那个儿子,左儿子是0,右儿子是1;
4.为了达到nlogn的复杂度,我们需要一个操作,既破坏了原来的结构,又产生了一个新的结构,并且新的结构满足BST的性质,那就是rotate()操作;
如图,上图表示的就是右旋操作;
对于x来说,y,c的值都要比x的值大,所以当x转到y的位置上时,理所当然的,y以及c的子树完全可以保持原有的形状变成x的右儿子
然后再次观察:v[a]<v[x]<v[b]; 好吧,他们不能放到x的同一个方向的子树上,那么就不动了;
欸?等等,x的右儿子是y准没错,左儿子是a也是没错的,那么b呢?中儿子?这是在骗小孩子吗?
所以现在急需一个神奇的操作,把b及其子树安排妥当;
我们发现,v[b]>v[x]; 并且由于v[x]<v[y],v[a]<v[y],v[b]<v[y];所以v[b]<v[y];
因为x是原来y的左儿子,所以右旋后y的左儿子是空的,正好可以把b放到y的左儿子的位置;
好了,右旋操作说完了,左旋操作与其类似;为了方便,采用异或操作把左旋和右旋结合到一个函数中;
void rotate(int x){ int y=BST[x].father,faroot=BST[y].father,fajudge=identify(y),yjudge=identify(x),son=BST[x].ch[yjudge^1]; connect(son,y,yjudge); connect(y,x,yjudge^1); connect(x,faroot,fajudge); updata(y); updata(x); }
5.这个函数便是splay的精华,也是splay名字的来源:splay()函数(伸展操作)
为了达到不知道怎么证明的严格nlogn的复杂度,我们需要咋一些操作后,把操作的点通过旋转操作转到根节点上,这就是splay操作;
对于一个点x,设它的父亲是y,父亲的父亲是z,如果要将他旋转到z那里,存在两种方法:
1).先旋转y,再旋转x; 2).旋转两次x;
如果x与y和y与z的父子关系一样,我们就采用第一种旋转方式,否则采用第二种旋转方式;
这样做的原因就是为了保证时间复杂度是nlogn,否则可能退化成链;
6.我们需要再原有的splay上新建一个节点,所以存在函数crepoint();
int crepoint(int v,int fa){ ++n;BST[n].v=v;BST[n].father=fa;BST[n].sum=BST[n].lazy=1; return n; }
这个函数应该就不用解释了;
7.我们同样需要一个函数destroy(),用来再原splay中摧毁一个节点得到一个新的splay;
void destroy(int x){ BST[x].v=0; BST[x].father=0; BST[x].ch[0]=0; BST[x].ch[1]=0; BST[x].sum=0; BST[x].lazy=0; if(x==n) --n; }
这个函数在最后可以写个啥用没有的优化;
8.根据BST的性质,我们可以定义一个查找函数find(),用来查找值为x的节点的编号是多少;
int find(int goal){ int now=BST[0].ch[1]; while(1){ if(BST[now].v==goal){ splay(now,BST[0].ch[1]); return now; } int nxt=goal<BST[now].v?0:1; if(!BST[now].ch[nxt]) return 0; now=BST[now].ch[nxt]; } }
这个函数要是看不明白就自己回去学BST吧;
注意一个细节,当我们找到了一个节点的值是goal的时候,我们需要将这个点通过伸展操作一路转到根节点;(这是为了保证时间复杂度,也方便了以后的操作)
9.光有crepoint()函数可不够,因为我们不知道新加入的节点在splay中的准确位置,所以我们还需要一个build()函数;(先看代码吧)
int build(int v){ points++; if(n==0){ BST[0].ch[1]=1; crepoint(v,0); } else{ int now=BST[0].ch[1]; while(1){ BST[now].sum++; if(v==BST[now].v){ ST[now].lazy++; return now; } int nxt=v<BST[now].v?0:1; if(!BST[now].ch[nxt]){ crepoint(v,now); BST[now].ch[nxt]=n; return n; } now=BST[now].ch[nxt]; } } return 0; }
我们需要一个小小的特判,如果这个点是第一个加入的点,我们需要把他设为根节点;
然后我们在splay中按照BST的性质不断向下找到属于自己的位置,如果没找到,那么就新建一个节点;
10.如果只有build()函数插入新节点,我们并无法保证时间复杂度是nlogn,所以我们需要push()函数,在build()之后将新加入的点splay(一路旋转)到根节点的位置;
void push(int v){int now=build(v);splay(now,BST[0].ch[1]);}
11.删除操作:pop()函数;
void pop(int v){ int deal=find(v); if(!deal) return; points--; if(BST[deal].lazy>1){ BST[deal].lazy--; BST[deal].sum--; return; } if(!BST[deal].ch[0]){ BST[0].ch[1]=BST[deal].ch[1]; BST[BST[0].ch[1]].father=0; } else{ int lef=BST[deal].ch[0]; while(BST[lef].ch[1]) lef=BST[lef].ch[1]; splay(lef,BST[deal].ch[0]); int rig=BST[deal].ch[1]; connect(rig,lef,1); connect(lef,0,1); updata(lef); } destroy(deal); }
如果我们要删除一个权值为v的节点,首先通过find()函数找到权值为v的点的节点编号x,如果节点x所表示的值的出现次数大于1次,那么就不需要删除节点,仅仅次数减1就好了;
注意到,通过find()函数,我们把要删除的点转到了根节点的位置,如果转完后根节点没有左子树,那么就直接删除根节点就好,把转完后根节点的右儿子当作转完后的根节点。可以发现,这样并不会破坏BST的性质;
如果有左子树呢?删除操作就没法进行了吗?显然不是!
对于转完后的根节点x,它的左儿子是l,右儿子是r;
我们将l子树中权值最大值的点y(一直走右儿子直到走不了)转到l所在的位置,由于y的权值是l子树最大的,所以y节点不存在右儿子;并且我们知道:v[y]<v[x]<v[r];
那么在删除节点x之后,我们将y作为splay的新根,因为y不存在右儿子,所以把r当作y的右儿子;
可以证明,这样做并不会破坏BST的性质;
至此,所有splay的常规操作都已经结束了,完结撒花~
但仅有常规操作无法完成一些题目要求,所以我们需要功能操作,但这与BST的基本功能操作就一至了,所以主要讲解splay的本文就不多介绍了,不会的可以看代码理解;
12.查询排名为x的值: atrank();
int atrank(int x){ if(x>points) return -999999999; int now=BST[0].ch[1]; while(1){ int tmp=BST[now].sum-BST[BST[now].ch[1]].sum; if(x>BST[BST[now].ch[0]].sum&&x<=tmp) break; if(x<tmp){ now=BST[now].ch[0]; } else{ x=x-tmp; now=BST[now].ch[1]; } } splay(now,BST[0].ch[1]); return BST[now].v; }
13.查询值为v的点在BST中的排名:rank();
int rank(int v){ int ans=0,now=BST[0].ch[1]; while(1){ if(BST[now].v==v){ int tmpp=ans+BST[BST[now].ch[0]].sum+1; splay(now,BST[0].ch[1]); return tmpp; } if(now==0) return 0; if(v<BST[now].v){ now=BST[now].ch[0]; } else{ ans+=BST[BST[now].ch[0]].sum+BST[now].lazy; now=BST[now].ch[1]; } } return 0; }
一定要注意,atrank()和rank()都需要在结束时进行splay操作,把操作的点转到根节点(为了保证时间复杂度是nlogn)
14.查询值为v的数的前驱(所有小于v的数中的最大值)
int upper(int v){ int now=BST[0].ch[1]; int result=999999999; while(now){ if(BST[now].v>v&&BST[now].v<result) result=BST[now].v; if(v<BST[now].v) now=BST[now].ch[0]; else{ now=BST[now].ch[1]; } } return result; }
15.查询值为v的数的后记(所有大于v的数中的最小值)
int lower(int v){ int now=BST[0].ch[1]; int result=-999999999; while(now){ if(BST[now].v<v&&BST[now].v>result) result=BST[now].v; if(v>BST[now].v) now=BST[now].ch[1]; else now=BST[now].ch[0]; } return result; }
然后就是愉悦的完整代码时间啦!(传送门:普通平衡树)
#include <bits/stdc++.h> #define inc(i,a,b) for(register int i=a;i<=b;i++) using namespace std; class Splay{ public: class node{ public: int v; int father; int ch[3]; int sum,lazy; }BST[100010]; int n,points; void updata(int x){BST[x].sum=(BST[BST[x].ch[0]].sum+BST[BST[x].ch[1]].sum+BST[x].lazy);} int identify(int x){return BST[BST[x].father].ch[0]==x?0:1;} void connect(int x,int fa,int judge){BST[x].father=fa;BST[fa].ch[judge]=x;} void rotate(int x){ int y=BST[x].father,faroot=BST[y].father,fajudge=identify(y),yjudge=identify(x),son=BST[x].ch[yjudge^1]; connect(son,y,yjudge); connect(y,x,yjudge^1); connect(x,faroot,fajudge); updata(y); updata(x); } void splay(int from,int to){ to=BST[to].father; while(BST[from].father!=to){ int fa=BST[from].father; if(BST[fa].father==to) rotate(from); else if(identify(from)==identify(fa)){ rotate(fa); rotate(from); } else{ rotate(from); rotate(from); } } } int crepoint(int v,int fa){ ++n;BST[n].v=v;BST[n].father=fa;BST[n].sum=BST[n].lazy=1; return n; } void destroy(int x){ BST[x].v=0; BST[x].father=0; BST[x].ch[0]=0; BST[x].ch[1]=0; BST[x].sum=0; BST[x].lazy=0; if(x==n) --n; } int find(int goal){ int now=BST[0].ch[1]; while(1){ if(BST[now].v==goal){ splay(now,BST[0].ch[1]); return now; } int nxt=goal<BST[now].v?0:1; if(!BST[now].ch[nxt]) return 0; now=BST[now].ch[nxt]; } } int build(int v){ points++; if(n==0){ BST[0].ch[1]=1; crepoint(v,0); } else{ int now=BST[0].ch[1]; while(1){ BST[now].sum++; if(v==BST[now].v){ BST[now].lazy++; return now; } int nxt=v<BST[now].v?0:1; if(!BST[now].ch[nxt]){ crepoint(v,now); BST[now].ch[nxt]=n; return n; } now=BST[now].ch[nxt]; } } return 0; } void push(int v){int now=build(v);splay(now,BST[0].ch[1]);} void pop(int v){ int deal=find(v); if(!deal) return; points--; if(BST[deal].lazy>1){ BST[deal].lazy--; BST[deal].sum--; return; } if(!BST[deal].ch[0]){ BST[0].ch[1]=BST[deal].ch[1]; BST[BST[0].ch[1]].father=0; } else{ int lef=BST[deal].ch[0]; while(BST[lef].ch[1]) lef=BST[lef].ch[1]; splay(lef,BST[deal].ch[0]); int rig=BST[deal].ch[1]; connect(rig,lef,1); connect(lef,0,1); updata(lef); } destroy(deal); } int rank(int v){ int ans=0,now=BST[0].ch[1]; while(1){ if(BST[now].v==v){ int tmpp=ans+BST[BST[now].ch[0]].sum+1; splay(now,BST[0].ch[1]); return tmpp; } if(now==0) return 0; if(v<BST[now].v){ now=BST[now].ch[0]; } else{ ans+=BST[BST[now].ch[0]].sum+BST[now].lazy; now=BST[now].ch[1]; } } return 0; } int atrank(int x){ if(x>points) return -999999999; int now=BST[0].ch[1]; while(1){ int tmp=BST[now].sum-BST[BST[now].ch[1]].sum; if(x>BST[BST[now].ch[0]].sum&&x<=tmp) break; if(x<tmp){ now=BST[now].ch[0]; } else{ x=x-tmp; now=BST[now].ch[1]; } } splay(now,BST[0].ch[1]); return BST[now].v; } int upper(int v){ int now=BST[0].ch[1]; int result=999999999; while(now){ if(BST[now].v>v&&BST[now].v<result) result=BST[now].v; if(v<BST[now].v) now=BST[now].ch[0]; else{ now=BST[now].ch[1]; } } return result; } int lower(int v){ int now=BST[0].ch[1]; int result=-999999999; while(now){ if(BST[now].v<v&&BST[now].v>result) result=BST[now].v; if(v>BST[now].v) now=BST[now].ch[1]; else now=BST[now].ch[0]; } return result; } }Stree; int main() { //freopen("splay.in","r",stdin); //freopen("my.out","w",stdout); int n; cin>>n; Stree.push(999999999); inc(i,1,n){ int type,x; scanf("%d%d",&type,&x); if(type==1){ Stree.push(x); } else if(type==2){ Stree.pop(x); } else if(type==3){ printf("%d\n",Stree.rank(x)); } else if(type==4){ printf("%d\n",Stree.atrank(x)); } else if(type==5){ printf("%d\n",Stree.lower(x)); } else{ printf("%d\n",Stree.upper(x)); } } }