树的难题
树的难题
题意
给出一个无根树。树有 \(N\) 个点,边有权值。每个点都有颜色,是黑色、白色、 灰色这三种颜色之一,称为一棵三色树。 可爱的 Alice 觉得,一个三色树为均衡的,当且仅当,树中不含有黑色结点 或者含有至多一个白色节点。然而,给出的三色树可能并不满足这个性质。 所以,Alice 打算删去若干条边使得形成的森林中每棵树都是均衡的,花费的代价等于删去的边的权值之和。请你计算需要花费的代价最小是多少。
思路
树形 dp,定义 \(dp_{i,0/1/2/3/4}\)。
\(0\) 表示子树内没有黑色也没有白色。
\(1\) 表示子树内没有黑色有一个白色。
\(2\) 表示子树内没有黑色有多个白色。
\(3\) 表示子树内有黑色没有白色。
\(4\) 表示子数内有黑色有一个白色。
分类讨论转移和边是否需要被删。
代码
#include <bits/stdc++.h>
#define ll long long
using namespace std;
// 0:无黑无白
// 1:无黑一白
// 2:无黑多白
// 3:有黑无白
// 4:有黑一白
// 0:黑 1:白 2:灰
const int N = 3e5 + 5;
int n, tot, a[N];
ll dp[N][2][5];
vector <pair <int,int>> E[N];
void dfs(int x, int fa) {
if (a[x] == 0) dp[x][0][0] = dp[x][0][1] = dp[x][0][2] = dp[x][1][0] = dp[x][1][1] = dp[x][1][2] = 1e18;
if (a[x] == 1) dp[x][0][0] = dp[x][0][3] = dp[x][1][0] = dp[x][1][3] = 1e18;
for (auto e : E[x]) {
int y = e.first, z = e.second;
if (y == fa) continue;
dfs(y, x);
if (a[x] == 2) dp[x][1][0] = dp[x][0][0] + min({dp[y][0][0], dp[y][0][1] + z, dp[y][0][2] + z, dp[y][0][3] + z, dp[y][0][4] + z});
if (a[x] != 0) dp[x][1][1] = min({dp[x][0][1] + dp[y][0][0], dp[x][0][0] + dp[y][0][1], dp[x][0][1] + dp[y][0][2] + z, dp[x][0][1] + dp[y][0][3] + z, dp[x][0][1] + dp[y][0][4] + z});
if (a[x] != 0) dp[x][1][2] = min({dp[x][0][2] + dp[y][0][0], dp[x][0][2] + dp[y][0][1], dp[x][0][2] + dp[y][0][2], dp[x][0][1] + dp[y][0][1], dp[x][0][1] + dp[y][0][2], dp[x][0][0] + dp[y][0][2], dp[x][0][2] + dp[y][0][3] + z, dp[x][0][2] + dp[y][0][4] + z});
if (a[x] != 1) dp[x][1][3] = min({dp[x][0][3] + dp[y][0][0], dp[x][0][3] + dp[y][0][3], dp[x][0][0] + dp[y][0][3], dp[x][0][3] + dp[y][0][1] + z, dp[x][0][3] + dp[y][0][2] + z, dp[x][0][3] + dp[y][0][4] + z});
dp[x][1][4] = min({dp[x][0][4] + dp[y][0][0], dp[x][0][3] + dp[y][0][1], dp[x][0][3] + dp[y][0][4], dp[x][0][4] + dp[y][0][3], dp[x][0][4] + dp[y][0][2] + z, dp[x][0][4] + dp[y][0][4] + z, dp[x][0][4] + dp[y][0][1] + z});
swap(dp[x][0], dp[x][1]);
}
}
void solve() {
cin >> n;
for (int i = 1; i <= n; i ++) cin >> a[i];
for (int i = 1; i <= n; i ++) E[i].clear();
for (int i = 1, u, v, w; i < n; i ++) {
cin >> u >> v >> w;
E[u].push_back({v, w});
E[v].push_back({u, w});
}
memset(dp, 0, sizeof(dp));
dfs(1, 0);
cout << min({dp[1][0][0], dp[1][0][1], dp[1][0][2], dp[1][0][3], dp[1][0][4]}) << "\n";
}
int main() {
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
int T;
cin >> T;
while (T --)
solve();
return 0;
}
本文来自博客园,作者:maniubi,转载请注明原文链接:https://www.cnblogs.com/maniubi/p/18399839,orz