Codeforces 1324F 树形dp+换根
题意
给一颗无根树,其节点被染成黑白两色。对于其任意一个节点 vv,我们要求出某一个包含 vv 的联通块,使得联通块内白色节点 – 黑色节点数量最大。
思路
先用一个简单的树形dp自下而上的求出对于根rt,其最大差值为mx[rt]。
状态转移方程为:mx【x】 = Σy为x的子节点 max(mx【y】,0)+mx【x】。
可以理解为,计算 mx[x] 的时候,如果 x 的孩子节点 y 在答案中与 y 联通可以提供正向的差值贡献,那么就连上,否则丢掉。最终答案 x 的答案存储在 mx[x] 中,复杂度为 O(n)。
然后中序遍历采用换根操作。将根从 x 换到 vv.
无非就是顺序执行 mx[x] -= max(0, mx[y]),mx[y] += max(0, mx[y]),即把 y 从 x 的孩子中拿掉,再把 x 加入 y 的孩子中。
#include <iostream> #include <cmath> #include <cstdio> #include <cstring> #include <string> #include <map> #include <iomanip> #include <algorithm> #include <queue> #include <stack> #include <set> #include <vector> // #include <bits/stdc++.h> #define fastio ios_base::sync_with_stdio(0);cin.tie(0);cout.tie(0); #define sp ' ' #define endl '\n' #define FOR(i,a,b) for( int i = a;i <= b;++i) #define bug cout<<"--------------"<<endl #define P pair<int, int> #define fi first #define se second #define pb(x) push_back(x) #define ppb() pop_back() #define mp(a,b) make_pair(a,b) #define ms(v,x) memset(v,x,sizeof(v)) #define rep(i,a,b) for(int i=a;i<=b;i++) #define repd(i,a,b) for(int i=a;i>=b;i--) #define sca3(a,b,c) scanf("%d %d %d",&(a),&(b),&(c)) #define sca2(a,b) scanf("%d %d",&(a),&(b)) #define sca(a) scanf("%d",&(a)); #define sca3ll(a,b,c) scanf("%lld %lld %lld",&(a),&(b),&(c)) #define sca2ll(a,b) scanf("%lld %lld",&(a),&(b)) #define scall(a) scanf("%lld",&(a)); using namespace std; typedef long long ll; ll gcd(ll a,ll b){return b?gcd(b,a%b):a;} ll lcm(ll a,ll b){return a/gcd(a,b)*b;} ll powmod(ll a, ll b, ll mod){ll sum = 1;while (b) {if (b & 1) {sum = (sum * a) % mod;b--;}b /= 2;a = a * a % mod;}return sum;} const double Pi = acos(-1.0); const double epsilon = Pi/180.0; const int maxn = 2e5+10; int col[maxn],ans[maxn],mx[maxn]; std::vector<int> edge[maxn]; int n; void dfs(int x,int fa) { mx[x] = (col[x]==1?1:-1); for(auto y : edge[x]){ if(y == fa) continue; dfs(y,x); if(mx[y] >= 0) mx[x] += mx[y]; } } void changeroot(int x,int fa) { for(auto y : edge[x]){ if(y == fa) continue; int sx = mx[x],sy = mx[y]; if(mx[y] >= 0 ) mx[x] -= mx[y]; if(mx[x] >= 0) mx[y] += mx[x]; ans[y] = mx[y]; changeroot(y,x); mx[x] = sx;mx[y] = sy; } } int main() { // freopen("input.txt", "r", stdin); sca(n); rep(i,1,n) sca(col[i]); rep(i,1,n-1){ int x,y; sca2(x,y); edge[x].pb(y); edge[y].pb(x); } dfs(1,0);ans[1] = mx[1]; changeroot(1,0); rep(i,1,n){ printf("%d ",ans[i] ); } printf("\n"); }