P7163 [COCI2020-2021#2] Svjetlo
题意
给你一棵点权是 \(0/1\) 的树,你可以从任意一点开始,走到任意一点结束,每到达一个点,都要翻转当前的点权。给定初始的点权,求使得整棵树的点权都变成 \(1\) 的最短路径长度。
Solution
乍一看以为是个换根。。。看题解发现自己 naive 了。
对于求树上最优路径的问题,可以考虑两端是否在一棵子树中。我们令 \(dp_{i,0/1/2,0/1}\) 表示在 \(i\) 的子树中有整个路径的 \(j\) 个端点,并且走完之后 \(x\) 是 \(0/1\) 的点权,此时让整个子树都变成 \(1\) 的最短路径。
为了方便转移,这里的路径采取左闭右开,也就是说在 \(0,2\) 状态中,我们只算入进入根的,不算入出根的。
对于没有端点在子树内的,说明它是从外头进来,然后在子树里捣鼓一圈后又出去。这可以直接从儿子的状态转移过来。然后考虑从当前根走入儿子,然后又出来,回到根,这样根和儿子的状态都变了。那如果想儿子状态翻转,最优的策略就是再从根走到儿子,然后回到根。即:
对于只有一个端点在子树内的,说明它是从外头进来然后不出去了。这也就意味着,当前根的某一个子树中有一个是含有一个端点的。那么合并一个子树的时候,可能这个端点是在新的子树中的,或者是在原来的子树中的,取最小值即可。即:
对于两个端点都在子树内的,需要更多的讨论。合并两树 \(i,s\) 的时候,有:
- \(i\) 中两个端点,\(s\) 中没有端点;
- \(s\) 中两个端点,\(i\) 中没有端点;
- \(i,s\) 中各一个端点。
对于上面的我们各写出转移是这样的:
-
\[dp'_{i,2,0}=\min(dp_{i,2,1}+dp_{s,0,0}+2,dp_{i,2,0}+dp_{s,0,1}+4)\\dp'_{i,2,1}=\min(dp_{i,2,0}+dp_{s,0,0}+2,dp_{i,2,1}+dp_{s,0,1}+4) \]
-
\[dp'_{i,2,0}=\min(dp_{i,0,1}+dp_{s,2,0}+2,dp_{i,0,0}+dp_{s,2,1}+4)\\dp'_{i,2,1}=\min(dp_{i,0,0}+dp_{s,2,0}+2,dp_{i,0,1}+dp_{s,2,1}+4) \]
-
\[dp'_{i,2,0}=\min(dp_{i,1,0}+dp_{s,1,1},dp_{i,1,1}+dp_{s,1,0}+2)\\dp'_{i,2,1}=\min(dp_{i,1,1}+dp_{s,1,1},dp_{i,1,0}+dp_{s,1,0}+2) \]
然后每类取 \(\min\) 就可以了。
对于每个节点,其子树内的某一个端点可能在当前子树的根上,但是我们在上面并没有计入。对于子树中有一个端点,并且它在根的情况,有:
对于子树中有两个端点,并且有在根的情况,有:
边界条件:\(dp_{i,0,c}=0\)。最终答案就是 \(dp_{1,2,1}\)。
细节:
要从初值为 \(0\) 的点开始遍历,并且如果一个子树中都是 \(1\),那么不用进入这个子树,否则会多算。
Code
#include<bits/stdc++.h>
#define ll long long
#define inf (1<<30)
#define INF (1ll<<60)
#define pb emplace_back
#define pii pair<int,int>
#define mkp make_pair
#define fi first
#define se second
#define all(a) (a).begin(),(a).end()
#define siz(a) (int)(a).size()
#define clr(a) memset(a,0,sizeof(a))
#define rep(i,j,k) for(int i=(j);i<=(k);i++)
#define per(i,j,k) for(int i=(j);i>=(k);i--)
#define pt(a) cerr<<#a<<'='<<a<<' '
#define pts(a) cerr<<#a<<'='<<a<<'\n'
//#define int long long
using namespace std;
const int MAXN=5e5+10;
int c[MAXN],dp[MAXN][3][2],tmp[3][2],flag[MAXN];
vector<int> e[MAXN];
void pdfs(int x,int fa){
flag[x]=c[x];
for(int s:e[x]){
if(s==fa) continue;
pdfs(s,x);
flag[x]&=flag[s];
}
}
void dfs(int x,int fa){
dp[x][0][c[x]]=0;
for(int s:e[x]){
if(s==fa) continue;
if(flag[s]) continue;
dfs(s,x);
rep(i,0,2) rep(j,0,1) tmp[i][j]=dp[x][i][j];
memset(dp[x],0x3f,sizeof(dp[x]));//Important
//In case trans from itself
dp[x][0][0]=min(tmp[0][1]+dp[s][0][0]+2,tmp[0][0]+dp[s][0][1]+4);
dp[x][0][1]=min(tmp[0][0]+dp[s][0][0]+2,tmp[0][1]+dp[s][0][1]+4);
dp[x][1][0]=min(min(tmp[0][1]+dp[s][1][1]+1,tmp[0][0]+dp[s][1][0]+3),min(tmp[1][1]+dp[s][0][0]+2,tmp[1][0]+dp[s][0][1]+4));
dp[x][1][1]=min(min(tmp[0][0]+dp[s][1][1]+1,tmp[0][1]+dp[s][1][0]+3),min(tmp[1][0]+dp[s][0][0]+2,tmp[1][1]+dp[s][0][1]+4));
dp[x][2][0]=min(dp[x][2][0],min(tmp[2][1]+dp[s][0][0]+2,tmp[2][0]+dp[s][0][1]+4));
dp[x][2][1]=min(dp[x][2][1],min(tmp[2][0]+dp[s][0][0]+2,tmp[2][1]+dp[s][0][1]+4));
dp[x][2][0]=min(dp[x][2][0],min(tmp[0][1]+dp[s][2][0]+2,tmp[0][0]+dp[s][2][1]+4));
dp[x][2][1]=min(dp[x][2][1],min(tmp[0][0]+dp[s][2][0]+2,tmp[0][1]+dp[s][2][1]+4));
dp[x][2][0]=min(dp[x][2][0],min(tmp[1][0]+dp[s][1][1],tmp[1][1]+dp[s][1][0]+2));
dp[x][2][1]=min(dp[x][2][1],min(tmp[1][1]+dp[s][1][1],tmp[1][0]+dp[s][1][0]+2));
}
dp[x][1][0]=min(dp[x][1][0],dp[x][0][1]+1);
dp[x][1][1]=min(dp[x][1][1],dp[x][0][0]+1);
dp[x][2][0]=min(dp[x][2][0],dp[x][1][0]);
dp[x][2][1]=min(dp[x][2][1],dp[x][1][1]);
}
void solve(){
int n;cin>>n;
rep(i,1,n){
char ch;cin>>ch;
c[i]=ch-'0';
}
memset(dp,0x3f,sizeof(dp));
rep(i,2,n){
int u,v;cin>>u>>v;
e[u].pb(v);e[v].pb(u);
}
rep(i,1,n) if(c[i]==0){
pdfs(i,0);dfs(i,0);
cout<<dp[i][2][1]<<'\n';
return;
}
}
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0);cout.tie(0);
// int T;for(cin>>T;T--;)
solve();
return 0;
}