[SHOI2014]三叉神经树 题解

LCT

Statement

一颗节点数为 \(3n+1\) 的树,编号在 \(1 \dots n\) 的节点有且仅有三个儿子。

其余点没有儿子。所有节点值只可能为 \(0\) 或 $ 1$,编号在 \(n + 1 \dots 3n\) 的节点的值由输入确定,编号在 $ 1 \to n$ 的节点的值为三个儿子中值数量更多的那种。

\(m\) 次操作,每次会改变一个 \(n + 1\dots 3n\) 的节点的值,请每次操作结束后输出根节点的值

\(n,m\le 5\times 10^5\)

Solution

\(sum\) 表示一个节点三个儿子的权值和

我们可以发现这样一个性质,我们每次修改一个叶子的权值的时候,只会影响到从叶子向上的一条链,这条链不一定到根

比如我们把一个 \(0\) 改成 \(1\) ,那么对应有影响的只有从叶子的父亲开始 \(sum=1\) 的一个连续的链

对应的,我们将 \(1\) 改成 \(0\) ,只会是一条链上的所有的 \(sum-1\)

考虑 LCT。

现在的问题变成了如何找到这两个点,即最深的 \(sum\neq 1\)\(sum\neq 2\) 的点,找到后,只需要把这个点 splay 上去后将他的右子树整体打一个 tag 即可

不妨直接在每一个节点上维护 \(neq[1/2]\) 表示子树中(链)最深的不为 \(1/2\) 的位置,\(pushup\) 函数大致长这样:(跟 FlashHu 学的压行)

void pushup(int rt){//因为 splay 中序遍历后才是原链,所以右子树更深
    if(!(t[rt].neq[1]=t[rs].neq[1])&&!(t[rt].neq[1]=rt*(t[rt].sum!=1)))t[rt].neq[1]=t[ls].neq[1];
    if(!(t[rt].neq[2]=t[rs].neq[2])&&!(t[rt].neq[2]=rt*(t[rt].sum!=2)))t[rt].neq[2]=t[ls].neq[2];
}

这样的话我们的复杂度就是 \(O(n\log n)\) ,细节在代码中

Code

#include<bits/stdc++.h>
#define ls t[rt].ch[0]
#define rs t[rt].ch[1]
#define min(a,b) ((a)<(b)?(a):(b))
#define max(a,b) ((a)>(b)?(a):(b))
#define swap(x,y) x^=y^=x^=y
using namespace std;
const int N = 2e6+5;
const int inf = 1e9;

char buf[1<<23],*p1=buf,*p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
int read(){
    int s=0,w=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')w=-1;ch=getchar();}
    while(isdigit(ch))s=s*10+(ch^48),ch=getchar();
    return s*w;
}

struct Tree{
    int f,tg,val,sum;
    int ch[2],neq[3];
}t[N];
int stk[N],top;

void pushup(int rt){//因为 splay 中序遍历后才是原链,所以右子树更深
    if(!(t[rt].neq[1]=t[rs].neq[1])&&!(t[rt].neq[1]=rt*(t[rt].sum!=1)))t[rt].neq[1]=t[ls].neq[1];
    if(!(t[rt].neq[2]=t[rs].neq[2])&&!(t[rt].neq[2]=rt*(t[rt].sum!=2)))t[rt].neq[2]=t[ls].neq[2];
}
void add(int rt,int v){
    t[rt].sum+=v,t[rt].val=t[rt].sum>1;
    swap(t[rt].neq[1],t[rt].neq[2]),t[rt].tg+=v;
    /*
    容易理解这里可以直接交换的原因,以 v=1 为例子
    + 本来最深的不为 1 的点记录的应该是值为 2 的点(不存在值为 3 的点,不然你改毛线)
    + 本来最深的不为 2 的点记录的应该是值为 1 的点
    */
}
void pushdown(int rt){
    if(!t[rt].tg)return ;
    if(ls)add(ls,t[rt].tg);
    if(rs)add(rs,t[rt].tg);
    t[rt].tg=0;
}
bool identity(int rt){return t[t[rt].f].ch[1]==rt;}
bool check(int rt){return t[t[rt].f].ch[0]!=rt&&t[t[rt].f].ch[1]!=rt;}
void rotate(int rt){
    int f=t[rt].f,ff=t[f].f,op=identity(rt),ch=t[rt].ch[op^1];
    t[ch].f=f,t[f].ch[op]=ch,t[rt].ch[op^1]=f,t[rt].f=ff;
    if(!check(f))t[ff].ch[identity(f)]=rt;
    t[f].f=rt,pushup(f),pushup(rt);
}
void splay(int rt){
    int u=rt; stk[++top]=u;
    while(!check(u))stk[++top]=(u=t[u].f);
    while(top)pushdown(stk[top--]);
    while(!check(rt)){
        int f=t[rt].f;
        if(!check(f))
            rotate(identity(f)==identity(rt)?f:rt);
        rotate(rt);
    }
}
void access(int rt){
    for(int p=0;rt;p=rt,rt=t[rt].f)
        splay(rt),rs=p,pushup(rt);
}

vector<int>Edge[N];
int n,m,ans;

void dfs(int u,int fath){//dfs 求出最开始的值
    for(auto v:Edge[u])if(v^fath)
        dfs(v,u),t[u].sum+=t[v].val;
    if(u<=n)t[u].val=t[u].sum>1;
}

signed main(){
    n=read();
    for(int i=1;i<=n;++i)
        for(int j=1,v;j<=3;++j)
            v=read(),t[v].f=i,
            Edge[i].push_back(v),
            Edge[v].push_back(i);
    for(int i=n+1;i<=3*n+1;++i)t[i].val=read();
    dfs(1,0),m=read(),ans=t[1].val;
    while(m--){
        int pos=read(),f=t[pos].f,tg=t[pos].val?-1:1;
        access(f),splay(f);//注意不是对 pos 操作
        int neq=t[f].neq[t[pos].val?2:1];
        if(neq){
            splay(neq);
            add(t[neq].ch[1],tg),pushup(t[neq].ch[1]);
            t[neq].sum+=tg,t[neq].val=t[neq].sum>1,pushup(neq);//不要忘了单点修
        }else ans^=1,add(f,tg),pushup(f);//否则直接修改整条链
        t[pos].val^=1;
        printf("%d\n",ans);
    }
    return 0;
}
posted @ 2022-02-20 20:12  _Famiglistimo  阅读(117)  评论(0编辑  收藏  举报