文艺平衡Splay树学习笔记(2)
本blog会讲一些简单的Splay的应用,包括但不局限于
1. Splay 维护数组下标,支持区间reserve操作,解决区间问题
2. Splay 的启发式合并(按元素多少合并)
3. 线段树+Splay 大常数树套树
一、Splay维护区间下标解决区间翻转问题
思想: 对于数组的下标是不可重复的,我们使用平衡树维护下标,利用Splay的splay操作,让区间都在一棵子树内。
然后直接输出这颗子树的维护信息,由于维护的是子树信息,那么父亲的信息一定可以由两个儿子推出。
于是就可以类似于线段树的操作了(用儿子信息维护父亲信息)。
[LuoGu 模板] 文艺平衡树 : 维护区间翻转,输出最终区间形态。
Solution :
考虑splay函数的实质,就是把节点x转到目标节点,并且让Splay树保持BST的性质。
那么显然,对于维护序列下标,其维护的值便是元素对于数组的位置。(一个节点左边儿子子树的位置都小于该节点)
对于一个区间进行翻转,那么事实上就是 把l和r下标交换,l+1和r-1下标交换...
那么如何确定这个区间呢,不妨把[l-1]转到根节点,[r+1]转到根节点的右儿子,那么对于根节点右儿子的左儿子所在子树
就可以描述为[l,r]区间。
然而,对于Splay本身来说是不能出现下标为0的情况的,否则会死循环,在(1)中已经讲的非常明确了,所以我们仍然加入一个哨兵节点,对于所有区间均+1
在代码中的实现是这样的。
需要注意的是,特殊的,对于维护位置(下标)是不会出现重复元素的,那么所有节点cnt的值都是1,就省略了。
最后这颗Splay的中序遍历就是最终区间。
# include <bits/stdc++.h> # define inf (0x3f3f3f3f) using namespace std; const int N=2e5+10; inline int read() { int X=0,w=0; char c=0; while(c<'0'||c>'9') {w|=c=='-';c=getchar();} while(c>='0'&&c<='9') X=(X<<3)+(X<<1)+(c^48),c=getchar(); return w?-X:X; } void write(int x) { if (x<0) putchar('-'),x=-x; if (x>9) write(x/10); putchar(x%10+'0'); } void writeln(int x){ write(x);putchar('\n'); } struct Splay{ # define ls(x) (t[x][0]) # define rs(x) (t[x][1]) int t[N][2],size[N],val[N],par[N]; bool rev[N]; int root,tot; Splay(){ root=0,tot=0;insert(-inf);insert(inf);} int check (int x) { return rs(par[x])==x;} void pushup(int x) { size[x]=size[ls(x)]+size[rs(x)]+1;} void rotate(int x) { int y=par[x],k=check(x); t[y][k]=t[x][k^1]; par[t[x][k^1]]=y; t[x][k^1]=y; t[par[y]][check(y)]=x; par[x]=par[y]; par[y]=x; pushup(y); pushup(x); } void splay(int x,int goal=0) { while (par[x]!=goal) { int y=par[x],z=par[y]; if (z!=goal) rotate(check(x)==check(y)?y:x); rotate(x); } if (!goal) root=x; } void down(int x) { if (rev[x]) { swap(ls(x),rs(x)); rev[ls(x)]^=1; rev[rs(x)]^=1; rev[x]=0; } } void insert(int x){ int cur=root,p=0; while (cur && val[cur]!=x) p=cur,cur=t[cur][x>val[cur]]; cur=++tot; if (p) t[p][x>val[p]]=cur; t[cur][0]=t[cur][1]=0; val[cur]=x; par[cur]=p; size[cur]=1; splay(cur); } int kth_id(int k) { int cur=root; while (true) { down(cur); if (t[cur][0] && k<=size[ls(cur)]) cur=ls(cur); else if (k>size[ls(cur)]+1) { k-=size[ls(cur)]+1; cur=rs(cur); } else return cur; } } void reserve(int l,int r) { int x=kth_id(l),y=kth_id(r+2); splay(x); splay(y,x); rev[ls(y)]^=1; } void print(int x) { down(x); if (ls(x)) print(ls(x)); if (val[x]!=-inf&&val[x]!=inf) printf("%d ",val[x]); if (rs(x)) print(rs(x)); } }tr; int main() { int n=read(),m=read(); for (int i=1;i<=n;i++) tr.insert(i); for (int i=1;i<=m;i++) { int l=read(),r=read(); tr.reserve(l,r); } tr.print(tr.root); return 0; }
用Splay维护3个操作
1 L R V : 区间[L,R]每个元素+V
2 L R :将区间[L,R]翻转
3 L R :输出区间[L,R]的最小值
Hint : 元素初始值都为 0
Solution:
显然我们可以使用上述方法维护一棵平衡树,还是维护下标,我们发现splay或者rotate一下,树维护下标的性质是不会改变的。
所以仍然可以在节点维护一个加法标记,当翻转前下传即可。
# include<bits/stdc++.h> # define inf (0x3f3f3f3f) using namespace std; const int N=1e5+10; int n,m,a[N]; struct Splay{ # define ls(x) (t[x][0]) # define rs(x) (t[x][1]) int t[N][2],par[N],size[N],val[N],mx[N],add[N]; int root,tot; bool rev[N]; int check(int x){return rs(par[x])==x; } void up(int x) { size[x]=size[ls(x)]+size[rs(x)]+1; mx[x]=max(max(mx[ls(x)],mx[rs(x)]),val[x]); } void down(int x) { int &l=ls(x),&r=rs(x); if (add[x]) { if (l) add[l]+=add[x],val[l]+=add[x],mx[l]+=add[x]; if (r) add[r]+=add[x],val[r]+=add[x],mx[r]+=add[x]; } if (rev[x]) rev[l]^=1,rev[r]^=1,swap(l,r); rev[x]=add[x]=0; } void rotate(int x) { int y=par[x],k=check(x); down(x); down(y); t[y][k]=t[x][k^1]; par[t[x][k^1]]=y; t[x][k^1]=y; t[par[y]][check(y)]=x; par[x]=par[y];par[y]=x; up(y); up(x); } void splay(int x,int goal=0) { while (par[x]!=goal) { int y=par[x],z=par[y]; if (z!=goal) rotate((check(x)==check(y))?y:x); rotate(x); } if (!goal) root=x; } int build(int l,int r,int fa) { if(l>r) return 0; int x=++tot,mid=(l+r)>>1; par[x]=fa,val[x]=mx[x]=a[mid]; ls(x)=build(l,mid-1,x); rs(x)=build(mid+1,r,x); up(x); return x; } int kth_id(int k) { int cur=root; while (cur) { down(cur); if (t[cur][0] && k<=size[ls(cur)]) cur=ls(cur); else if (k>size[ls(cur)]+1) { k-=size[ls(cur)]+1; cur=rs(cur); } else return cur; } } int kth_val(int k) { int p=kth_id(k); return val[p]; } void reserve(int l,int r) { int x=kth_id(l),y=kth_id(r+2); splay(x); splay(y,x); rev[ls(y)]^=1; } int GetMax(int l,int r) { int x=kth_id(l),y=kth_id(r+2); splay(x); splay(y,x); return mx[ls(y)]; } int Add(int l,int r,int d) { int x=kth_id(l),y=kth_id(r+2); splay(x); splay(y,x); val[ls(y)]+=d; add[ls(y)]+=d; mx[ls(y)]+=d; } }tr; int main() { scanf("%d%d",&n,&m); memset(a,0,sizeof(a)); tr.mx[0]=a[1]=a[n+2]=-inf; tr.root=tr.build(1,n+2,0); while (m--) { int op,l,r,d; scanf("%d%d%d",&op,&l,&r); if (op==1) scanf("%d",&d),tr.Add(l,r,d); else if (op==2) tr.reserve(l,r); else if (op==3) printf("%d\n",tr.GetMax(l,r)); } return 0; }
二、Splay的启发式合并(按元素多少合并)
P3224 [HNOI2012]永无乡
维护两个操作:
一开始所有元素都独立。
1 . 合并x,y
2 . 在节点x连通块内求节点排名为k的元素编号。
Sol : 首先,对于每个元素都维护一棵平衡树,由于我们是数组模拟链表,所以在"指针池"中,只需要记录根是谁。
就可以维护一棵平衡树森林了,我们只要求在平衡树森林中总节点个数不超过"指针池"大小即可。
当合并节点x,y的时候,若以x所在节点的平衡树和y所在节点平衡树合并的时候,按节点个数合并即可,尽可能保证元素多的Splay树不改变,
然后把元素少的Splay树遍历依次,把所有拥有的元素插入到元素多的Splay树中
为了处理方便,我们令第i棵子树根的编号为father(i)
# include <bits/stdc++.h> using namespace std; const int N=3e5+10; map<int,int>Hash; int t[N][2],par[N],root[N],size[N],cnt[N],val[N]; int n,m,tot,f[N]; # define ls(x) t[x][0] # define rs(x) t[x][1] inline int read() { int X=0,w=0; char c=0; while(c<'0'||c>'9') {w|=c=='-';c=getchar();} while(c>='0'&&c<='9') X=(X<<3)+(X<<1)+(c^48),c=getchar(); return w?-X:X; } inline void write(int x) { if (x<0) putchar('-'),x=-x; if (x>9) write(x/10); putchar('0'+x%10); } int check(int x) { return rs(par[x])==x; } void up(int x){ size[x]=size[ls(x)]+size[rs(x)]+cnt[x]; } void rotate(int x){ int y=par[x],k=check(x); t[y][k]=t[x][k^1]; par[t[x][k^1]]=y; t[x][k^1]=y; t[par[y]][check(y)]=x; par[x]=par[y]; par[y]=x; up(y); up(x); } void splay(int x,int goal) { while (par[x]!=goal) { int y=par[x],z=par[y]; if (z!=goal) { if (check(x)==check(y)) rotate(y); else rotate(x); } rotate(x); } if (goal<=n) root[goal]=x; } void insert(int rt,int x) { int cur=root[rt],p=rt; while (cur && val[cur]!=x) p=cur,cur=t[cur][x>val[cur]]; if (cur) cnt[cur]++; else { cur=++tot; if (p>n) t[p][x>val[p]]=cur; ls(cur)=rs(cur)=0; par[cur]=p; val[cur]=x; size[cur]=cnt[cur]=1; } splay(cur,rt); } int kth(int rt,int k) { int cur=root[rt]; if (k>size[cur]||k<0) return -1; while (true) { if (t[cur][0]&&k<=size[ls(cur)]) cur=ls(cur); else if (k>size[ls(cur)]+cnt[cur]) { k-=size[ls(cur)]+cnt[cur]; cur=rs(cur); } else return val[cur]; } } int father(int x) { if (f[x]==x) return x; return f[x]=father(f[x]); } void dfs(int tr1,int tr2) { if (ls(tr1)) dfs(ls(tr1),tr2); if (rs(tr1)) dfs(rs(tr1),tr2); insert(tr2,val[tr1]); } void merge(int u,int v) { int fx=father(u),fy=father(v); if (size[root[fx]]>size[root[fy]]) swap(fx,fy); f[fx]=fy; dfs(root[fx],fy); } int main() { n=read();m=read(); for (int i=1;i<=n;i++) f[i]=i,root[i]=i+n; for (int i=1;i<=n;i++){ int x=read(); Hash[x]=i; val[i+n]=x; cnt[i+n]=size[i+n]=1; par[i+n]=i; } tot=n<<1; for (int i=1;i<=m;i++) { int u=read(),v=read(); merge(u,v); } int Q=read(); char op[3]; while (Q--) { scanf("%s",op);int x=read(),y=read(); if (op[0]=='B') merge(x,y); else { int ans=kth(father(x),y); if (ans==-1) write(-1),putchar('\n'); else write(Hash[ans]),putchar('\n'); } } return 0; }
三. 线段树+Splay 大常数树套树
【模板】二逼平衡树(树套树)
// luogu-judger-enable-o2 // luogu-judger-enable-o2 // luogu-judger-enable-o2 # pragma GCC optimize(3) # include <cstdio> # define il inline # define Rint register int using namespace std; const int inf=2147483647; const int N=5e4+5; int a[N]; int n,m,Lim; il int max(int a,int b) { return (a<b)?b:a; } il int min(int a,int b) { return (a>b)?b:a; } il int read() { int X=0,w=0; char c=0; while(c<'0'||c>'9') {w|=c=='-';c=getchar();} while(c>='0'&&c<='9') X=(X<<3)+(X<<1)+(c^48),c=getchar(); return w?-X:X; } il void write(int x) { if (x<0) putchar('-'),x=-x; if (x>9) write(x/10); putchar('0'+x%10); } il void writeln(int x){write(x);putchar('\n');} struct SegmentTree{ int l,r,root; }tr[N<<2]; struct Splay{ # define ls(x) (t[x][0]) # define rs(x) (t[x][1]) int t[N*50][2],cnt[N*50],val[N*50],size[N*50],par[N*50]; int tot; Splay() { tot=0; } il int check(int x) { return rs(par[x])==x; } il void up(int x){ size[x]=size[ls(x)]+size[rs(x)]+cnt[x]; } il void rotate(int x){ int y=par[x],k=check(x); t[y][k]=t[x][k^1]; par[t[x][k^1]]=y; t[x][k^1]=y; t[par[y]][check(y)]=x; par[x]=par[y]; par[y]=x; up(y); up(x); } il void splay(int node,int x,int goal) { while (par[x]!=goal) { int y=par[x],z=par[y]; if (z!=goal) { if (check(x)==check(y)) rotate(y); else rotate(x); } rotate(x); } if (!goal) tr[node].root=x; } il void insert(int node,int x) { int cur=tr[node].root,p=0; if (!cur) { cur=++tot; tr[node].root=cur; ls(cur)=rs(cur)=0; par[cur]=0; val[cur]=x; size[cur]=cnt[cur]=1; return; } while (cur && val[cur]!=x) p=cur,cur=t[cur][x>val[cur]]; if (cur) cnt[cur]++; else { cur=++tot; if (p) t[p][x>val[p]]=cur; ls(cur)=rs(cur)=0; par[cur]=p; val[cur]=x; size[cur]=cnt[cur]=1; } splay(node,cur,0); } il void find(int node,int x){ int cur=tr[node].root; if (!cur) return; while (t[cur][x>val[cur]] && val[cur]!=x) cur=t[cur][x>val[cur]]; splay(node,cur,0); } il int pre_id(int node,int x) { find(node,x); int cur=tr[node].root; if (val[cur]<x) return cur; cur=ls(cur); while (rs(cur)) cur=rs(cur); return cur; } il int pre_val(int node,int x){ find(node,x); int cur=tr[node].root; if (val[cur]<x) return val[cur]; cur=ls(cur); while (rs(cur)) cur=rs(cur); return val[cur]; } il int suc_id(int node,int x) { find(node,x); int cur=tr[node].root; if (val[cur]>x) return cur; cur=rs(cur); while (ls(cur)) cur=ls(cur); return cur; } il int suc_val(int node,int x){ find(node,x); int cur=tr[node].root; if (val[cur]>x) return val[cur]; cur=rs(cur); while (ls(cur)) cur=ls(cur); return val[cur]; } il void erase(int node,int x){ int last=pre_id(node,x); int next=suc_id(node,x); splay(node,next,0); splay(node,last,next); int d=rs(last); if (cnt[d]>1) cnt[d]--,splay(node,d,0); else rs(last)=0; up(last); } il int rank(int node,int x) { find(node,x); int cur=tr[node].root; if (val[cur]>=x) return size[ls(cur)]-1; else return size[ls(cur)]+cnt[cur]-1; } # undef ls # undef rs }t; # define lson 2*x,l,mid # define rson 2*x+1,mid+1,r # define mid ((l+r)>>1) int Ans; il void SegBuild(int x,int l,int r) { t.insert(x,-inf); t.insert(x,inf); if (l==r) return; SegBuild(lson); SegBuild(rson); } il void SegInsert(int x,int l,int r,int pos,int val) { t.insert(x,val); if (l==r) return; if (pos<=mid) SegInsert(lson,pos,val); else SegInsert(rson,pos,val); } il void SegRank(int x,int l,int r,int val,int opl,int opr) { if (opl<=l&&r<=opr) { Ans+=t.rank(x,val); return; } if (opl<=mid) SegRank(lson,val,opl,opr); if (opr>mid) SegRank(rson,val,opl,opr); } il void SegUpdate(int x,int l,int r,int pos,int val) { t.erase(x,a[pos]); t.insert(x,val); if (l==r) { a[pos]=val; return; } if (pos<=mid) SegUpdate(lson,pos,val); else SegUpdate(rson,pos,val); } il void SegPre(int x,int l,int r,int opl,int opr,int val) { if (opl<=l&&r<=opr) { Ans=max(Ans,t.pre_val(x,val)); return; } if (opl<=mid) SegPre(lson,opl,opr,val); if (opr>mid) SegPre(rson,opl,opr,val); } il void SegSuc(int x,int l,int r,int opl,int opr,int val) { if (opl<=l&&r<=opr) { Ans=min(t.suc_val(x,val),Ans); return; } if (opl<=mid) SegSuc(lson,opl,opr,val); if (opr>mid) SegSuc(rson,opl,opr,val); } il int SegKth(int opl,int opr,int k) { int l=0,r=Lim,ans; while (l<=r) { int Mid=(l+r)>>1; Ans=0; SegRank(1,1,n,Mid,opl,opr); if (Ans+1>k) r=Mid-1; else l=Mid+1,ans=Mid; } return ans; } int main() { n=read();m=read(); SegBuild(1,1,n); for (Rint i=1;i<=n;i++) { a[i]=read(); Lim=max(Lim,a[i]); SegInsert(1,1,n,i,a[i]); } for (Rint i=1;i<=m;i++) { int op=read(),l=read(),r=read(),k; if (op==1) k=read(),Ans=0,SegRank(1,1,n,k,l,r),writeln(Ans+1); if (op==2) k=read(),writeln(SegKth(l,r,k)); if (op==3) SegUpdate(1,1,n,l,r); if (op==4) k=read(),Ans=-inf,SegPre(1,1,n,l,r,k),writeln(Ans); if (op==5) k=read(),Ans=inf,SegSuc(1,1,n,l,r,k),writeln(Ans); } return 0; }
# include <bits/stdc++.h> # define inf (2147483647) using namespace std; const int N=4e6+10; int n,m,Lim,tot; int a[N]; struct Treap_Node{ int val,key,cnt,size; int ch[2]; }t[N]; struct Treap { # define ls t[x].ch[0] # define rs t[x].ch[1] void up(int &x){t[x].size=t[x].cnt+t[ls].size+t[rs].size;} void rotate(int &x,int d) { int son=t[x].ch[d]; t[x].ch[d]=t[son].ch[d^1]; t[son].ch[d^1]=x; up(x); up(x=son); } void insert(int &x,int val) { if (!x) { x=++tot; t[x].size=t[x].cnt=1; t[x].val=val; t[x].key=rand(); return; } t[x].size++; if (t[x].val==val) { t[x].cnt++; return;} int d=t[x].val<val; insert(t[x].ch[d],val); if (t[x].key>t[t[x].ch[d]].key) rotate(x,d); } void erase(int &x,int val) { if (!x) return; if (t[x].val==val) { if (t[x].cnt>1) { t[x].cnt--; t[x].size--; return;} int d=t[ls].key>t[rs].key; if (ls==0||rs==0) x=ls+rs; else rotate(x,d),erase(x,val); } else t[x].size--,erase(t[x].ch[t[x].val<val],val); } inline int rank(int &x,int val){ if (!x) return 0; if (t[x].val==val) return t[ls].size; if (t[x].val>val) return rank(ls,val); return t[x].cnt+t[ls].size+rank(rs,val); } int find(int &rt,int k) { int x=rt; while (1) { if (k<=t[ls].size) x=ls; else if (k>t[x].cnt+t[ls].size) k-=t[x].cnt+t[ls].size,x=rs; else return t[x].val; } } inline int pre(int &x,int val) { if (!x) return -inf; if (t[x].val>=val) return pre(ls,val); return max(pre(rs,val),t[x].val); } int nex(int &x,int val) { if (!x) return inf; if (t[x].val<=val) return nex(rs,val); return min(nex(ls,val),t[x].val); } #undef ls #undef rs }tr; struct Segment_Tree{ int l,r,root; }tree[N]; # define lson (x<<1),l,mid # define rson (x<<1)+1,mid+1,r # define mid ((l+r)>>1) void SegInsert(int x,int l,int r,int pos,int val) { tr.insert(tree[x].root,val); if (l==r) return; if (pos<=mid) SegInsert(lson,pos,val); else SegInsert(rson,pos,val); } int SegRank(int x,int l,int r,int ql,int qr,int val) { if (ql<=l&&r<=qr) return tr.rank(tree[x].root,val); int ret=0; if (ql<=mid) ret+=SegRank(lson,ql,qr,val); if (qr>mid) ret+=SegRank(rson,ql,qr,val); return ret; } void SegUpdate(int x,int l,int r,int pos,int val) { tr.insert(tree[x].root,val); tr.erase(tree[x].root,a[pos]); if (l==r) { a[pos]=val; return;} if (pos<=mid) SegUpdate(lson,pos,val); else SegUpdate(rson,pos,val); } int SegPre(int x,int l,int r,int ql,int qr,int val) { if (ql<=l&&r<=qr) return tr.pre(tree[x].root,val); int ret=-inf; if (ql<=mid) ret=max(ret,SegPre(lson,ql,qr,val)); if (qr>mid) ret=max(ret,SegPre(rson,ql,qr,val)); return ret; } int SegSuc(int x,int l,int r,int ql,int qr,int val) { if (ql<=l&&r<=qr) return tr.nex(tree[x].root,val); int ret=inf; if (ql<=mid) ret=min(ret,SegSuc(lson,ql,qr,val)); if (qr>mid) ret=min(ret,SegSuc(rson,ql,qr,val)); return ret; } # undef lson # undef rson # undef mid int SegKth(int l,int r,int k) { int L=0,R=Lim+1,Ans; while (L<=R) { int Mid=(L+R)>>1; if (SegRank(1,1,n,l,r,Mid)+1>k) R=Mid-1; else Ans=Mid,L=Mid+1; } return Ans; } int main() { scanf("%d%d",&n,&m); for (int i=1;i<=n;i++) { int t; scanf("%d",&t); Lim=max(Lim,t); a[i]=t; SegInsert(1,1,n,i,t); } while (m--) { int op,l,r,k; scanf("%d",&op); if (op==1) scanf("%d%d%d",&l,&r,&k),printf("%d\n",SegRank(1,1,n,l,r,k)+1); if (op==2) scanf("%d%d%d",&l,&r,&k),printf("%d\n",SegKth(l,r,k)); if (op==3) scanf("%d%d",&l,&r),SegUpdate(1,1,n,l,r); if (op==4) scanf("%d%d%d",&l,&r,&k),printf("%d\n",SegPre(1,1,n,l,r,k)); if (op==5) scanf("%d%d%d",&l,&r,&k),printf("%d\n",SegSuc(1,1,n,l,r,k)); } return 0; }