题解 QOJ4815【Flower's Land】
Flower's Land - Problem - QOJ.ac
题目描述
给一棵 \(n\) 个点、点带权的树和常数 \(k\leq n\),对每个点求包含它的大小为 \(k\) 的连通块的最大权值和。\(n\leq 40000, k\leq 3000\)。
树上背包
树上背包,选一个连通块大小为 \(k\),问总权值最大?
- 非常经典的 \(f_{u,i}\) 表示 \(u\) 子树选 \(i\) 个,全局 \(O(n^2)\)。但是换根复杂度假。
- 按照 dfn 序背包,\(f_{i,j}\) 表示考虑了 \([i,n]\) 的 dfn 序的选 \(j\) 个点的背包,转移要么选儿子,要么跳过整棵子树。
- \(f_{i,j}=\max\{f_{i+1,j-1}+a_i,f_{i+siz_u,j}\}\)。
solution
点分治,设分治重心 \(root\)。欲将计算包含点 \(u\) 与 \(root\) 的最大权值连通块,将答案的连通块的点分为四类:
- \(u\to root\) 强制选;
- \(u\) 的子树(可以和下面任意一类合并,注意这部分不要重复计算);
- 先序遍历 \(\geq dfn_u+siz_u\) 的点;
- 后序遍历 \(\geq dfn_u+siz_u\) 的点(注意不是先序遍历的翻转)。
用刚才说的树形 DP 做,在某一个点上合并答案时暴力求值一个 \((\max,+)\) 的点值,一共 \(O(nk\log n)\)。如果点分治时 \(siz_u<k\),跳过,那么复杂度为 \(O(nk\log\frac{n}{k})\)。(一直做除法直到 \(<k\),需要至多 \(\log\frac{n}{k}\) 次)
code
#include <bits/stdc++.h>
using namespace std;
#ifdef LOCAL
#define debug(...) fprintf(stderr, ##__VA_ARGS__)
#else
#define endl "\n"
#define debug(...) void(0)
#endif
using LL = long long;
int n, k, a[40010], siz[40010], rnk[40010], cnt;
basic_string<int> g[40010];
bool cut[40010];
int findcen(int rt, int T) {/*{{{*/
pair<int, int> cen{(int)1e9, 0};
auto dfs = [&](auto& dfs, int u, int fa) -> void {
int smx = siz[u] = 1;
for (int v : g[u]) if (!cut[v] && v != fa) dfs(dfs, v, u), siz[u] += siz[v], smx = max(smx, siz[v]);
cen = min(cen, {max(smx, T - siz[u]), u});
};
return dfs(dfs, rt, 0), cen.second;
}/*}}}*/
int fans[40010];
void dfs(int kd, int u, int fa) {
rnk[++cnt] = u;
siz[u] = 1;
if (kd) reverse(g[u].begin(), g[u].end());
for (int v : g[u]) if (v != fa && !cut[v]) dfs(kd, v, u), siz[u] += siz[v];
if (kd) reverse(g[u].begin(), g[u].end());
}
vector<int> f[40010], tmp[40010][2];
void calc(int u, int fa, int sum, int c) {
if (++c > k) return ;
sum += a[u];
const auto& [lhs, rhs] = tmp[u];
#ifdef LOCAL
debug("calc(%d, sum=%d, c=%d)\n", u, sum, c);
debug("lhs: "); for (int x : lhs) debug("%d, ", x); debug("\n");
debug("rhs: "); for (int x : rhs) debug("%d, ", x); debug("\n");
#endif
for (int i = 0; i <= k - c && i < (int)lhs.size(); i++) {
if (k - c - i < (int)rhs.size()) fans[u] = max(fans[u], sum + lhs[i] + rhs[k - c - i]);
}
for (int v : g[u]) if (v != fa && !cut[v]) calc(v, u, sum, c);
}
void work(int rt) {
for (int kd : {0, 1}) {
dfs(kd, rt, cnt = 0);
if (cnt < k) return ;
f[cnt + 1] = {0};
for (int j = cnt; j >= 1; j--) {
f[j] = f[j + siz[rnk[j]]];
if (f[j].size() < f[j + 1].size() + 1) f[j].resize(f[j + 1].size() + 1, numeric_limits<int>::min());
for (size_t k = 0; k < f[j + 1].size(); k++) f[j][k + 1] = max(f[j][k + 1], f[j + 1][k] + a[rnk[j]]);
if (f[j].size() > k) f[j].resize(k);
}
if (!kd) for (int j = cnt; j >= 1; j--) tmp[rnk[j]][kd] = move(f[j + 1]);
else for (int j = cnt; j >= 1; j--) tmp[rnk[j]][kd] = f[j + siz[rnk[j]]];
}
debug("rt = %d\n", rt);
calc(rt, 0, 0, 0);
}
void solve(int rt) {
cut[rt] = true;
work(rt);
for (int v : g[rt]) if (!cut[v]) solve(findcen(v, siz[v]));
}
int main() {
#ifndef LOCAL
cin.tie(nullptr)->sync_with_stdio(false);
#endif
cin >> n >> k;
for (int i = 1; i <= n; i++) cin >> a[i];
for (int i = 1, u, v; i < n; i++) cin >> u >> v, g[u] += v, g[v] += u;
solve(findcen(1, n));
for (int i = 1; i <= n; i++) cout << fans[i] << " \n"[i == n];
return 0;
}
本文来自博客园,作者:caijianhong,转载请注明原文链接:https://www.cnblogs.com/caijianhong/p/18341582/solution-QOJ4815