洛谷 P4383 [八省联考 2018] 林克卡特树
原题等价于在树上选出 \(k + 1\) 条不相交链,最大化边权和。
树形 DP。设 \(f_{u, k, 0/1/2}\) 表示在 \(u\) 的子树中选了 \(k\) 条链,\(u\) 的度数为 \(0, 1, 2\) 的最大边权和。
注意到状态里缺了链退化为单个点的情况,可以把它放到 \(f_{u, k, 2}\) 中(相当于自环)。
转移时分讨一下:
- 若选 \((u, v)\)
\[f_{u,k_1,0} + f_{v,k_2,0} + w_{u, v} \to f_{u, k_1 + k_2 + 1, 1} \\
f_{u, k_1, 1} + f_{v, k_2, 0} + w_{u, v} \to f_{u, k_1 + k_2, 2} \\
f_{u, k_1, 1} + f_{v, k_2, 1} + w_{u, v} \to f_{u, k_1 + k_2 - 1, 2}
\]
- 若不选 \((u, v)\)
\[f_{u, k_1, t_1} + f_{v, k_2, t_2} \to f_{u, k_1 + k_2, t_1}
\]
答案即 \(\max\limits_{t \in \{0, 1, 2\}}\{f_{1, k + 1, t}\}\)。时间复杂度 \(\mathcal O(nk)\)。
考虑优化。
注意到后一次的可选集合是前一次的子集,所以前一次的增量一定大于后一次的(形式化地,记 \(F_i\) 表示选出 \(i\) 条不相交链的最大边权和,则有 \(F_i - F_{i-1} > F_{i+1} - F_i\)),否则前后两次选的两条链可以交换顺序,使得前一次的答案更大。即 \((i, F_i)\) 连线形成的图像是一个上凸壳。
于是可以 wqs 二分,把状态里 \(k\) 那维删去,变为 \(f_{u, 0/1/2}\),表示在 \(u\) 的子树中选若干条链,\(u\) 不在链/在链中间/在链端点上的最大边权和,并记录每个状态下至少选了多少条链。
二分每条链的附加值 \(C(C \in N)\),因为前面在尽量少选,所以当选出链的个数 \(\le k + 1\) 时更新答案。时间复杂度 \(\mathcal O(n \log nv)\)。
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
constexpr int N = 3e5 + 10;
constexpr ll inf = 4.5e18;
int n, k;
int tot, head[N];
struct Edge {int to, nxt, val;} e[N << 1];
inline void add(int u, int v, int w) {e[++tot] = Edge{v, head[u], w}, head[u] = tot;}
struct DP {
ll f; int cnt;
DP operator+(const DP &rhs) const {return {f + rhs.f, cnt + rhs.cnt};}
bool operator<(const DP &rhs) const {return f != rhs.f ? f < rhs.f : cnt > rhs.cnt;}
} f[N][3], fu[3], fvm, res, O;
bool chk(ll C) {
auto dfs = [&](auto &&self, int u, int fa) -> void {
f[u][0] = {0, 0}, f[u][1] = {-inf, 0}, f[u][2] = {C, 1};
for (int i = head[u], v; v = e[i].to, i; i = e[i].nxt) if (v != fa) {
self(self, v, u);
for (int t : {0, 1, 2}) fu[t] = f[u][t];
fvm = max(max(f[v][0], f[v][1]), f[v][2]);
f[u][0] = fu[0] + fvm;
f[u][1] = max(fu[0] + max(f[v][0] + DP{e[i].val + C, 1}, f[v][1] + DP{e[i].val, 0}), fu[1] + fvm);
f[u][2] = max(fu[1] + max(f[v][0] + DP{e[i].val, 0}, f[v][1] + DP{e[i].val - C, -1}), fu[2] + fvm);
}
};
dfs(dfs, 1, -1);
res = max(max(f[1][0], f[1][1]), f[1][2]);
return res.cnt <= k + 1;
}
int main() {
ios_base::sync_with_stdio(0); cin.tie(nullptr), cout.tie(nullptr);
cin >> n >> k;
for (int i = 1, u, v, w; i < n; i++) cin >> u >> v >> w, add(u, v, w), add(v, u, w);
ll l = -1e12, r = 1e12, mid, ans;
while (l <= r) {
mid = (l + r) >> 1;
if (chk(mid)) ans = res.f - mid * (k + 1), l = mid + 1;
else r = mid - 1;
}
cout << ans;
return 0;
}