树的难题

树的难题

题意

给出一个无根树。树有 \(N\) 个点,边有权值。每个点都有颜色,是黑色、白色、 灰色这三种颜色之一,称为一棵三色树。 可爱的 Alice 觉得,一个三色树为均衡的,当且仅当,树中不含有黑色结点 或者含有至多一个白色节点。然而,给出的三色树可能并不满足这个性质。 所以,Alice 打算删去若干条边使得形成的森林中每棵树都是均衡的,花费的代价等于删去的边的权值之和。请你计算需要花费的代价最小是多少。

思路

树形 dp,定义 \(dp_{i,0/1/2/3/4}\)

\(0\) 表示子树内没有黑色也没有白色。

\(1\) 表示子树内没有黑色有一个白色。

\(2\) 表示子树内没有黑色有多个白色。

\(3\) 表示子树内有黑色没有白色。

\(4\) 表示子数内有黑色有一个白色。

分类讨论转移和边是否需要被删。

代码

#include <bits/stdc++.h>
#define ll long long
using namespace std;
// 0:无黑无白
// 1:无黑一白
// 2:无黑多白
// 3:有黑无白
// 4:有黑一白 
// 0:黑 1:白 2:灰
const int N = 3e5 + 5;
int n, tot, a[N]; 
ll dp[N][2][5];
vector <pair <int,int>> E[N];
void dfs(int x, int fa) {
	if (a[x] == 0) dp[x][0][0] = dp[x][0][1] = dp[x][0][2] = dp[x][1][0] = dp[x][1][1] = dp[x][1][2] = 1e18;
	if (a[x] == 1) dp[x][0][0] = dp[x][0][3] = dp[x][1][0] = dp[x][1][3] = 1e18;
	for (auto e : E[x]) {
		int y = e.first, z = e.second;
		if (y == fa) continue;
		dfs(y, x);
		if (a[x] == 2) dp[x][1][0] = dp[x][0][0] + min({dp[y][0][0], dp[y][0][1] + z, dp[y][0][2] + z, dp[y][0][3] + z, dp[y][0][4] + z});
		if (a[x] != 0) dp[x][1][1] = min({dp[x][0][1] + dp[y][0][0], dp[x][0][0] + dp[y][0][1], dp[x][0][1] + dp[y][0][2] + z, dp[x][0][1] + dp[y][0][3] + z, dp[x][0][1] + dp[y][0][4] + z});
		if (a[x] != 0) dp[x][1][2] = min({dp[x][0][2] + dp[y][0][0], dp[x][0][2] + dp[y][0][1], dp[x][0][2] + dp[y][0][2], dp[x][0][1] + dp[y][0][1], dp[x][0][1] + dp[y][0][2], dp[x][0][0] + dp[y][0][2], dp[x][0][2] + dp[y][0][3] + z, dp[x][0][2] + dp[y][0][4] + z});
		if (a[x] != 1) dp[x][1][3] = min({dp[x][0][3] + dp[y][0][0], dp[x][0][3] + dp[y][0][3], dp[x][0][0] + dp[y][0][3], dp[x][0][3] + dp[y][0][1] + z, dp[x][0][3] + dp[y][0][2] + z, dp[x][0][3] + dp[y][0][4] + z});
 		dp[x][1][4] = min({dp[x][0][4] + dp[y][0][0], dp[x][0][3] + dp[y][0][1], dp[x][0][3] + dp[y][0][4], dp[x][0][4] + dp[y][0][3], dp[x][0][4] + dp[y][0][2] + z, dp[x][0][4] + dp[y][0][4] + z, dp[x][0][4] + dp[y][0][1] + z});
		swap(dp[x][0], dp[x][1]);
	}
}
void solve() {
	cin >> n;
	for (int i = 1; i <= n; i ++) cin >> a[i];
	for (int i = 1; i <= n; i ++) E[i].clear();
	for (int i = 1, u, v, w; i < n; i ++) {
		cin >> u >> v >> w;
		E[u].push_back({v, w});
		E[v].push_back({u, w});
	} 
	memset(dp, 0, sizeof(dp));
	dfs(1, 0);
	cout << min({dp[1][0][0], dp[1][0][1], dp[1][0][2], dp[1][0][3], dp[1][0][4]}) << "\n";
}
int main() {
	ios::sync_with_stdio(0);
	cin.tie(0); cout.tie(0);
	int T;
	cin >> T;
	while (T --)
		solve();	 
	return 0;
}
posted @ 2024-09-06 10:59  maniubi  阅读(4)  评论(0编辑  收藏  举报