[ARC148C] Lights Out on Tree 题解
在考场遇到了这道题,感觉很有意思。
当时直接想到的就是虚树,可惜打挂了。
后来改对了,写篇题解纪念一下。
首先看到 \(\sum M_i\le 2\times 10^5\),很容易想到虚树的数据范围。
我们设 \(dp_i,fg_i\) 表示将 \(i\) 的子树全部染白或染黑需要多少次,\(vis_i=1/2\) 表示该点为黑 \(/\) 白,则有:
\[\begin{cases}\begin{cases}
dp_i=\sum fg_j+1\\
fg_i=\sum fg_j
\end{cases}(vis_x=1)\\\\\begin{cases}
dp_i=\sum dp_j\\
fg_i=\sum dp_j+1\end{cases}(vis_x=2)\end{cases}
\]
很容易发现,假如在虚树上的父子在原树上且父亲为黑点,则有:
\[\begin{cases}
dp_i=\sum dp_j+sz_i+1\\
fg_i=\sum dp_j+sz_i
\end{cases}
\]
其中 \(sz_i\) 表示节点儿子数。加之在虚树上,儿子不一定都在,所以原先的方程改为:
\[\begin{cases}\begin{cases}
\begin{cases}
dp_i=\sum (fg_j-1)+sz_i+1\\
fg_i=\sum (fg_j-1)+sz_i
\end{cases}(fa_j=i)\\\\
\begin{cases}
dp_i=\sum dp_j+sz_i+1\\
fg_i=\sum dp_j+sz_i
\end{cases}(fa_j\ne i)
\end{cases}(vis_x=1)\\\\
\begin{cases}dp_i=\sum dp_j\\fg_i=\sum dp_j+1\end{cases}(vis_x=2)\end{cases}
\]
我用 \(\text{lca}\) 建虚树,时间复杂度 \(O(n\log n)\)。
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=4e5+5;
int n,q,rt,d,dep[N],c[N];
int id,f[N][25],dfn[N];
int dp[N],fg[N],vis[N];
vector<int>g[N],ve[N];
int cmp(int x,int y){
return dfn[x]<dfn[y];
}void dfs(int x,int fa){
dep[x]=dep[fa]+1;
f[x][0]=fa;dfn[x]=++id;
for(int i=0;i<24;i++)
f[x][i+1]=f[f[x][i]][i];
for(int i=0;i<g[x].size();i++)
dfs(g[x][i],x);
}int lca(int x,int y){
if(dep[x]<dep[y]) swap(x,y);
for(int i=24;~i;i--)
if(dep[x]-dep[y]>=(1<<i))
x=f[x][i];
if(x==y) return x;
for(int i=24;~i;i--)
if(f[x][i]!=f[y][i])
x=f[x][i],y=f[y][i];
return f[x][0];
}void build(){
if(!vis[1]) vis[c[++d]=1]=2;
int dd=d;sort(c+1,c+d+1,cmp);
for(int i=1;i<dd;i++){
int lc=lca(c[i],c[i+1]);
if(vis[lc]) continue;
vis[lc]=2;c[++d]=lc;
}sort(c+1,c+d+1,cmp);
for(int i=1;i<d;i++){
int x=c[i],y=c[i+1];
int lc=lca(x,y);
ve[lc].push_back(y);
}
}void dp_(int x){
for(int i=0;i<ve[x].size();i++){
int y=ve[x][i];dp_(y);
if(vis[x]==1){
if(f[y][0]==x)
fg[x]+=fg[y]-1,dp[x]+=fg[y]-1;
else fg[x]+=dp[y],dp[x]+=dp[y];
continue;
}dp[x]+=dp[y];fg[x]+=dp[y];
}if(vis[x]==1){
dp[x]+=g[x].size();
fg[x]+=g[x].size();
dp[x]+=1;
}else fg[x]+=1;
}int main(){
ios::sync_with_stdio(0);
cin.tie(0);cout.tie(0);
cin>>n>>q;
for(int i=2;i<=n;i++){
int fa;cin>>fa;
g[fa].push_back(i);
}dfs(1,0);
while(q--){
for(int i=1;i<=d;i++){
ve[c[i]].clear();
dp[c[i]]=fg[c[i]]=0;
vis[c[i]]=0;
}cin>>d;
for(int i=1;i<=d;i++)
cin>>c[i],vis[c[i]]=1;
build();dp_(1);
cout<<dp[1]<<"\n";
}return 0;
}