HDU5293 Tree Chain Problem 题解

给定一棵树和 \(m\) 条路径,每条路径有权值。要求从中选若干条结点不相交的路径使得权值最大。

\(n,m\le 10^5\)


对于树上路径的 DP 问题,常常把路径的贡献/限制放到它的 LCA 处考虑。

\(dp[u]\)\(u\) 的子树内选完全在子树内的路径,结点不相交的最大权值是多少。

\(sum[u]=\sum_{v\in son(u)}dp[v]\)

\(u\) 处不选路径,\(dp[u]\leftarrow sum[u]\)

\(u\) 处选路径,枚举路径 \(p\)\(dp[u]\leftarrow w(p)+\sum_{v\in p}sum[v]-\sum_{v\in p,v\neq u}dp[v]\)

注意到 \(\sum_{v\in p}sum[v]\)\(\sum_{v\in p,v\neq u}dp[v]\) 可以看作树上路径求和的问题,因此需要一个数据结构支持路径求和、单点修改。

使用 dfn 序 + BIT 即可完成。

然而 HDU 卡常。

点击查看代码
#include <bits/stdc++.h>

using namespace std;
const int N = 1e5 + 5, inf = 0x3f3f3f3f;

int n, m;
int fa[N], sz[N];
vector<int> e[N];
struct Path {
	int u, v, w;
} p[N];
vector<int> v[N];

int dfn[N], cur = 0;
int st[20][N], lg[N] = {}, pw[20];
void dfs(int x, int pr) {
	fa[x] = pr;
	sz[x] = 1;
	dfn[x] = ++cur;
	st[0][dfn[x]] = pr;
	for (auto i: e[x])
		if (i != pr) {
			dfs(i, x);
			sz[x] += sz[i];
		}
}
int sml(int x, int y) {
	return dfn[x] < dfn[y] ? x : y;
}
void init_lca() {
	pw[0] = 1;
	for (int i = 1; i < 20; i++)
		pw[i] = pw[i - 1] * 2;
	for (int i = 2; i <= n; i++)
		lg[i] = lg[i / 2] + 1;
	for (int i = 1; i < 20; i++)
		for (int j = 1; j + pw[i] - 1 <= n; j++)
			st[i][j] = sml(st[i - 1][j], st[i - 1][j + pw[i - 1]]);
}
int lca(int u, int v) {
	if (u == v)
		return u;
	u = dfn[u], v = dfn[v];
	if (u > v)
		swap(u, v);
	int s = lg[v - u];
	return sml(st[s][u + 1], st[s][v - pw[s] + 1]);
}

struct BIT {
	int tr[N];
	void clear() {
		for (int i = 1; i <= n; i++)
			tr[i] = 0;
	} 
	int lowbit(int x) {
		return x & -x;
	} 
	void mdf(int x, int v) {
		for (int i = x; i <= n; i += lowbit(i))
			tr[i] += v;
	}
	int qry(int x) {
		int ret = 0;
		for (int i = x; i; i -= lowbit(i))
			ret += tr[i];
		return ret;
	}
} t1, t2; 
int dp[N], sum[N];

int qry(BIT &t, int u, int v) {
	return t.qry(dfn[u]) + t.qry(dfn[v]) - t.qry(dfn[lca(u, v)])
		 - (fa[lca(u, v)] == 0 ? 0 : t.qry(dfn[fa[lca(u, v)]]));
}

void srh(int x, int pr) {
	dp[x] = sum[x] = 0;
	for (auto i: e[x])
		if (i != pr) {
			srh(i, x);
			sum[x] += dp[i];
		}
	dp[x] = sum[x];
	t2.mdf(dfn[x], +sum[x]);
	t2.mdf(dfn[x] + sz[x], -sum[x]);
	for (auto i: v[x]) {
		dp[x] = max(dp[x], p[i].w + qry(t2, p[i].u, p[i].v) - qry(t1, p[i].u, p[i].v));
	}
	t1.mdf(dfn[x], +dp[x]);
	t1.mdf(dfn[x] + sz[x], -dp[x]);
}

void slv() {
	scanf("%d%d", &n, &m);
	cur = 0;
	t1.clear();
	t2.clear();
	for (int i = 1; i <= n; i++) {
		e[i].clear();
		v[i].clear();
		dp[i] = sum[i] = 0;
	}
	for (int i = 1, u, v; i < n; i++) {
		scanf("%d%d", &u, &v);
		e[u].push_back(v);
		e[v].push_back(u);
	}
	dfs(1, 0);
	init_lca();
	
	for (int i = 1; i <= m; i++) {
		scanf("%d%d%d", &p[i].u, &p[i].v, &p[i].w);
		v[lca(p[i].u, p[i].v)].push_back(i);
	}
	srh(1, 0);
	printf("%d\n", dp[1]);
}

int main() {
	int T;
	cin >> T;
	while (T--)
		slv();
    return 0;
}
posted @ 2024-11-18 21:12  FLY_lai  阅读(2)  评论(0编辑  收藏  举报