HDU5293 Tree Chain Problem 题解
给定一棵树和 \(m\) 条路径,每条路径有权值。要求从中选若干条结点不相交的路径使得权值最大。
\(n,m\le 10^5\)。
对于树上路径的 DP 问题,常常把路径的贡献/限制放到它的 LCA 处考虑。
令 \(dp[u]\) 为 \(u\) 的子树内选完全在子树内的路径,结点不相交的最大权值是多少。
令 \(sum[u]=\sum_{v\in son(u)}dp[v]\)。
若 \(u\) 处不选路径,\(dp[u]\leftarrow sum[u]\)。
若 \(u\) 处选路径,枚举路径 \(p\),\(dp[u]\leftarrow w(p)+\sum_{v\in p}sum[v]-\sum_{v\in p,v\neq u}dp[v]\)。
注意到 \(\sum_{v\in p}sum[v]\) 和 \(\sum_{v\in p,v\neq u}dp[v]\) 可以看作树上路径求和的问题,因此需要一个数据结构支持路径求和、单点修改。
使用 dfn 序 + BIT 即可完成。
然而 HDU 卡常。
点击查看代码
#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 5, inf = 0x3f3f3f3f;
int n, m;
int fa[N], sz[N];
vector<int> e[N];
struct Path {
int u, v, w;
} p[N];
vector<int> v[N];
int dfn[N], cur = 0;
int st[20][N], lg[N] = {}, pw[20];
void dfs(int x, int pr) {
fa[x] = pr;
sz[x] = 1;
dfn[x] = ++cur;
st[0][dfn[x]] = pr;
for (auto i: e[x])
if (i != pr) {
dfs(i, x);
sz[x] += sz[i];
}
}
int sml(int x, int y) {
return dfn[x] < dfn[y] ? x : y;
}
void init_lca() {
pw[0] = 1;
for (int i = 1; i < 20; i++)
pw[i] = pw[i - 1] * 2;
for (int i = 2; i <= n; i++)
lg[i] = lg[i / 2] + 1;
for (int i = 1; i < 20; i++)
for (int j = 1; j + pw[i] - 1 <= n; j++)
st[i][j] = sml(st[i - 1][j], st[i - 1][j + pw[i - 1]]);
}
int lca(int u, int v) {
if (u == v)
return u;
u = dfn[u], v = dfn[v];
if (u > v)
swap(u, v);
int s = lg[v - u];
return sml(st[s][u + 1], st[s][v - pw[s] + 1]);
}
struct BIT {
int tr[N];
void clear() {
for (int i = 1; i <= n; i++)
tr[i] = 0;
}
int lowbit(int x) {
return x & -x;
}
void mdf(int x, int v) {
for (int i = x; i <= n; i += lowbit(i))
tr[i] += v;
}
int qry(int x) {
int ret = 0;
for (int i = x; i; i -= lowbit(i))
ret += tr[i];
return ret;
}
} t1, t2;
int dp[N], sum[N];
int qry(BIT &t, int u, int v) {
return t.qry(dfn[u]) + t.qry(dfn[v]) - t.qry(dfn[lca(u, v)])
- (fa[lca(u, v)] == 0 ? 0 : t.qry(dfn[fa[lca(u, v)]]));
}
void srh(int x, int pr) {
dp[x] = sum[x] = 0;
for (auto i: e[x])
if (i != pr) {
srh(i, x);
sum[x] += dp[i];
}
dp[x] = sum[x];
t2.mdf(dfn[x], +sum[x]);
t2.mdf(dfn[x] + sz[x], -sum[x]);
for (auto i: v[x]) {
dp[x] = max(dp[x], p[i].w + qry(t2, p[i].u, p[i].v) - qry(t1, p[i].u, p[i].v));
}
t1.mdf(dfn[x], +dp[x]);
t1.mdf(dfn[x] + sz[x], -dp[x]);
}
void slv() {
scanf("%d%d", &n, &m);
cur = 0;
t1.clear();
t2.clear();
for (int i = 1; i <= n; i++) {
e[i].clear();
v[i].clear();
dp[i] = sum[i] = 0;
}
for (int i = 1, u, v; i < n; i++) {
scanf("%d%d", &u, &v);
e[u].push_back(v);
e[v].push_back(u);
}
dfs(1, 0);
init_lca();
for (int i = 1; i <= m; i++) {
scanf("%d%d%d", &p[i].u, &p[i].v, &p[i].w);
v[lca(p[i].u, p[i].v)].push_back(i);
}
srh(1, 0);
printf("%d\n", dp[1]);
}
int main() {
int T;
cin >> T;
while (T--)
slv();
return 0;
}