[SHOI2014]三叉神经树 题解
LCT
Statement
一颗节点数为 \(3n+1\) 的树,编号在 \(1 \dots n\) 的节点有且仅有三个儿子。
其余点没有儿子。所有节点值只可能为 \(0\) 或 $ 1$,编号在 \(n + 1 \dots 3n\) 的节点的值由输入确定,编号在 $ 1 \to n$ 的节点的值为三个儿子中值数量更多的那种。
\(m\) 次操作,每次会改变一个 \(n + 1\dots 3n\) 的节点的值,请每次操作结束后输出根节点的值
\(n,m\le 5\times 10^5\)
Solution
设 \(sum\) 表示一个节点三个儿子的权值和
我们可以发现这样一个性质,我们每次修改一个叶子的权值的时候,只会影响到从叶子向上的一条链,这条链不一定到根
比如我们把一个 \(0\) 改成 \(1\) ,那么对应有影响的只有从叶子的父亲开始 \(sum=1\) 的一个连续的链
对应的,我们将 \(1\) 改成 \(0\) ,只会是一条链上的所有的 \(sum-1\)
考虑 LCT。
现在的问题变成了如何找到这两个点,即最深的 \(sum\neq 1\) 和 \(sum\neq 2\) 的点,找到后,只需要把这个点 splay 上去后将他的右子树整体打一个 tag 即可
不妨直接在每一个节点上维护 \(neq[1/2]\) 表示子树中(链)最深的不为 \(1/2\) 的位置,\(pushup\) 函数大致长这样:(跟 FlashHu 学的压行)
void pushup(int rt){//因为 splay 中序遍历后才是原链,所以右子树更深
if(!(t[rt].neq[1]=t[rs].neq[1])&&!(t[rt].neq[1]=rt*(t[rt].sum!=1)))t[rt].neq[1]=t[ls].neq[1];
if(!(t[rt].neq[2]=t[rs].neq[2])&&!(t[rt].neq[2]=rt*(t[rt].sum!=2)))t[rt].neq[2]=t[ls].neq[2];
}
这样的话我们的复杂度就是 \(O(n\log n)\) ,细节在代码中
Code
#include<bits/stdc++.h>
#define ls t[rt].ch[0]
#define rs t[rt].ch[1]
#define min(a,b) ((a)<(b)?(a):(b))
#define max(a,b) ((a)>(b)?(a):(b))
#define swap(x,y) x^=y^=x^=y
using namespace std;
const int N = 2e6+5;
const int inf = 1e9;
char buf[1<<23],*p1=buf,*p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
int read(){
int s=0,w=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')w=-1;ch=getchar();}
while(isdigit(ch))s=s*10+(ch^48),ch=getchar();
return s*w;
}
struct Tree{
int f,tg,val,sum;
int ch[2],neq[3];
}t[N];
int stk[N],top;
void pushup(int rt){//因为 splay 中序遍历后才是原链,所以右子树更深
if(!(t[rt].neq[1]=t[rs].neq[1])&&!(t[rt].neq[1]=rt*(t[rt].sum!=1)))t[rt].neq[1]=t[ls].neq[1];
if(!(t[rt].neq[2]=t[rs].neq[2])&&!(t[rt].neq[2]=rt*(t[rt].sum!=2)))t[rt].neq[2]=t[ls].neq[2];
}
void add(int rt,int v){
t[rt].sum+=v,t[rt].val=t[rt].sum>1;
swap(t[rt].neq[1],t[rt].neq[2]),t[rt].tg+=v;
/*
容易理解这里可以直接交换的原因,以 v=1 为例子
+ 本来最深的不为 1 的点记录的应该是值为 2 的点(不存在值为 3 的点,不然你改毛线)
+ 本来最深的不为 2 的点记录的应该是值为 1 的点
*/
}
void pushdown(int rt){
if(!t[rt].tg)return ;
if(ls)add(ls,t[rt].tg);
if(rs)add(rs,t[rt].tg);
t[rt].tg=0;
}
bool identity(int rt){return t[t[rt].f].ch[1]==rt;}
bool check(int rt){return t[t[rt].f].ch[0]!=rt&&t[t[rt].f].ch[1]!=rt;}
void rotate(int rt){
int f=t[rt].f,ff=t[f].f,op=identity(rt),ch=t[rt].ch[op^1];
t[ch].f=f,t[f].ch[op]=ch,t[rt].ch[op^1]=f,t[rt].f=ff;
if(!check(f))t[ff].ch[identity(f)]=rt;
t[f].f=rt,pushup(f),pushup(rt);
}
void splay(int rt){
int u=rt; stk[++top]=u;
while(!check(u))stk[++top]=(u=t[u].f);
while(top)pushdown(stk[top--]);
while(!check(rt)){
int f=t[rt].f;
if(!check(f))
rotate(identity(f)==identity(rt)?f:rt);
rotate(rt);
}
}
void access(int rt){
for(int p=0;rt;p=rt,rt=t[rt].f)
splay(rt),rs=p,pushup(rt);
}
vector<int>Edge[N];
int n,m,ans;
void dfs(int u,int fath){//dfs 求出最开始的值
for(auto v:Edge[u])if(v^fath)
dfs(v,u),t[u].sum+=t[v].val;
if(u<=n)t[u].val=t[u].sum>1;
}
signed main(){
n=read();
for(int i=1;i<=n;++i)
for(int j=1,v;j<=3;++j)
v=read(),t[v].f=i,
Edge[i].push_back(v),
Edge[v].push_back(i);
for(int i=n+1;i<=3*n+1;++i)t[i].val=read();
dfs(1,0),m=read(),ans=t[1].val;
while(m--){
int pos=read(),f=t[pos].f,tg=t[pos].val?-1:1;
access(f),splay(f);//注意不是对 pos 操作
int neq=t[f].neq[t[pos].val?2:1];
if(neq){
splay(neq);
add(t[neq].ch[1],tg),pushup(t[neq].ch[1]);
t[neq].sum+=tg,t[neq].val=t[neq].sum>1,pushup(neq);//不要忘了单点修
}else ans^=1,add(f,tg),pushup(f);//否则直接修改整条链
t[pos].val^=1;
printf("%d\n",ans);
}
return 0;
}