BZOJ 3159: 决战 LCT+Splay
挺厉害的一道大数据结构题.
由于 LCT 是维护树的形态的,所以说不支持翻转操作.
而在维护序列时 splay 是支持区间翻转的.
所以,我们对于 LCT 中每一个重链都维护一个 splay(这个不同于 LCT 中的 splay)
由于重链是一个序列,所以是支持序列的区间翻转的.
那么我们的翻转,链加和,链求和的操作就都在这个重链对应的 splay 上进行.
然后这里一定要注意:我们在 LCT 中维护 LCT 中每个点对应到 splay 上的编号,只有 LCT 中的 splay 的根节点对应的是正确的编号.
#include <cstdio> #include <string> #include <vector> #include <cstring> #include <algorithm> #define N 50007 #define ll long long using namespace std; namespace IO { void setIO(string s) { string in=s+".in"; string out=s+".out"; freopen(in.c_str(),"r",stdin); freopen(out.c_str(),"w",stdout); } }; namespace Splay { #define lson s[x].ch[0] #define rson s[x].ch[1] struct node { int ch[2],f,rev,size; ll add,val,sum,Min,Max; }s[N]; int sta[N]; int get(int x) { return s[s[x].f].ch[1]==x; } void mark_rev(int x) { s[x].rev^=1,swap(lson,rson);} void mark_add(int x,ll v) { s[x].add+=v; s[x].sum+=1ll*s[x].size*v; s[x].Min+=v,s[x].Max+=v,s[x].val+=v; } void pushup(int x) { s[x].sum=s[x].Min=s[x].Max=s[x].val; s[x].size=s[lson].size+s[rson].size+1; if(lson) { s[x].sum+=s[lson].sum; s[x].Min=min(s[x].Min,s[lson].Min); s[x].Max=max(s[x].Max,s[lson].Max); } if(rson) { s[x].sum+=s[rson].sum; s[x].Min=min(s[x].Min,s[rson].Min); s[x].Max=max(s[x].Max,s[rson].Max); } } void pushdown(int x) { if(s[x].rev) { if(lson) mark_rev(lson); if(rson) mark_rev(rson); s[x].rev=0; } if(s[x].add) { if(lson) mark_add(lson,s[x].add); if(rson) mark_add(rson,s[x].add); s[x].add=0; } } void rotate(int x) { int old=s[x].f,fold=s[old].f,which=get(x); s[old].ch[which]=s[x].ch[which^1]; if(s[old].ch[which]) s[s[old].ch[which]].f=old; s[x].ch[which^1]=old,s[old].f=x,s[x].f=fold; if(fold) s[fold].ch[s[fold].ch[1]==old]=x; pushup(old),pushup(x); } void splay(int x) { int fa,v=0,tmp=x; for(;tmp;tmp=s[tmp].f) sta[++v]=tmp; for(;v;--v) pushdown(sta[v]); for(;fa=s[x].f;rotate(x)) if(s[fa].f) rotate(get(fa)==get(x)?fa:x); } int get_kth(int x,int kth) { pushdown(x); if(kth<=s[lson].size) return get_kth(lson,kth); else if(s[lson].size+1==kth) return x; else return get_kth(rson,kth-s[lson].size-1); } int findrt(int x) { while(s[x].f) { x=s[x].f; } return x; } #undef lson #undef rson }; #define ls s[x].ch[0] #define rs s[x].ch[1] struct node { int ch[2],f,rev,size; }s[N]; int sta[N],rt[N]; int get(int x) { return s[s[x].f].ch[1]==x; } int Isr(int x) { return s[s[x].f].ch[0]!=x&&s[s[x].f].ch[1]!=x; } void mark(int x) { swap(ls,rs), s[x].rev^=1; } void pushup(int x) { s[x].size=s[ls].size+s[rs].size+1; } void pushdown(int x) { if(s[x].rev) { s[x].rev=0; if(ls) mark(ls); if(rs) mark(rs); } } void rotate(int x) { int old=s[x].f,fold=s[old].f,which=get(x); if(!Isr(old)) s[fold].ch[s[fold].ch[1]==old]=x; s[old].ch[which]=s[x].ch[which^1]; if(s[old].ch[which]) s[s[old].ch[which]].f=old; s[x].ch[which^1]=old,s[old].f=x,s[x].f=fold; pushup(old),pushup(x); } void splay(int x) { int u=x,v=0,fa; for(sta[++v]=u;!Isr(u);u=s[u].f) sta[++v]=s[u].f; rt[x]=rt[u]; for(;v;--v) pushdown(sta[v]); for(u=s[u].f;(fa=s[x].f)!=u;rotate(x)) if(s[fa].f!=u) rotate(get(fa)==get(x)?fa:x); } void Access(int x) { for(int y=0;x;y=x,x=s[x].f) { splay(x); if(rs) // cut { rt[x]=Splay::get_kth(rt[x],s[ls].size+1); Splay::splay(rt[x]); rt[rs]=Splay::s[rt[x]].ch[1]; Splay::s[rt[rs]].f=0; Splay::s[rt[x]].ch[1]=0; Splay::pushup(rt[x]); } if(y) // link { rt[x]=Splay::get_kth(rt[x],s[ls].size+1); Splay::splay(rt[x]); Splay::s[rt[x]].ch[1]=rt[y]; Splay::s[rt[y]].f=rt[x]; Splay::pushup(rt[x]); } rs=y; pushup(x); } } void makeroot(int x) { Access(x),splay(x),mark(x),Splay::mark_rev(rt[x]); } void split(int x,int y) { makeroot(x),Access(y),splay(y); } #undef ls #undef rs int edges; int hd[N],to[N<<1],nex[N<<1]; void add(int u,int v) { nex[++edges]=hd[u],hd[u]=edges,to[edges]=v; } void dfs(int u,int ff) { s[u].f=ff; rt[u]=u; s[u].size=1; Splay::s[rt[u]].size=1; for(int i=hd[u];i;i=nex[i]) { int v=to[i]; if(v==ff) continue; dfs(v,u); } } int main() { // IO::setIO("input"); int i,j,n,m,R; scanf("%d%d%d",&n,&m,&R); for(i=1;i<n;++i) { int x,y; scanf("%d%d",&x,&y); add(x,y),add(y,x); } dfs(1,0); for(i=1;i<=m;++i) { char op[10]; int x,y,z; scanf("%s",op+1); if(op[3]=='c') { scanf("%d%d%d",&x,&y,&z); split(x,y); Splay::mark_add(rt[y],(ll)z); } if(op[3]=='m') { scanf("%d%d",&x,&y); split(x,y); printf("%lld\n",Splay::s[rt[y]].sum); } if(op[3]=='j') { scanf("%d%d",&x,&y); split(x,y); printf("%lld\n",Splay::s[rt[y]].Max); } if(op[3]=='n') { scanf("%d%d",&x,&y); split(x,y); printf("%lld\n",Splay::s[rt[y]].Min); } if(op[3]=='v') { scanf("%d%d",&x,&y); split(x,y); Splay::mark_rev(rt[y]); } } return 0; }