题解 QOJ4815【Flower's Land】

Flower's Land - Problem - QOJ.ac

题目描述

给一棵 \(n\) 个点、点带权的树和常数 \(k\leq n\),对每个点求包含它的大小为 \(k\) 的连通块的最大权值和。\(n\leq 40000, k\leq 3000\)

树上背包

树上背包,选一个连通块大小为 \(k\),问总权值最大?

  1. 非常经典的 \(f_{u,i}\) 表示 \(u\) 子树选 \(i\) 个,全局 \(O(n^2)\)。但是换根复杂度假。
  2. 按照 dfn 序背包,\(f_{i,j}\) 表示考虑了 \([i,n]\) 的 dfn 序的选 \(j\) 个点的背包,转移要么选儿子,要么跳过整棵子树。
  3. \(f_{i,j}=\max\{f_{i+1,j-1}+a_i,f_{i+siz_u,j}\}\)

solution

点分治,设分治重心 \(root\)。欲将计算包含点 \(u\)\(root\) 的最大权值连通块,将答案的连通块的点分为四类:

  • \(u\to root\) 强制选;
  • \(u\) 的子树(可以和下面任意一类合并,注意这部分不要重复计算);
  • 先序遍历 \(\geq dfn_u+siz_u\) 的点;
  • 后序遍历 \(\geq dfn_u+siz_u\) 的点(注意不是先序遍历的翻转)。

用刚才说的树形 DP 做,在某一个点上合并答案时暴力求值一个 \((\max,+)\) 的点值,一共 \(O(nk\log n)\)。如果点分治时 \(siz_u<k\),跳过,那么复杂度为 \(O(nk\log\frac{n}{k})\)。(一直做除法直到 \(<k\),需要至多 \(\log\frac{n}{k}\) 次)

code

#include <bits/stdc++.h>
using namespace std;
#ifdef LOCAL
#define debug(...) fprintf(stderr, ##__VA_ARGS__)
#else
#define endl "\n"
#define debug(...) void(0)
#endif
using LL = long long;
int n, k, a[40010], siz[40010], rnk[40010], cnt;
basic_string<int> g[40010];
bool cut[40010];
int findcen(int rt, int T) {/*{{{*/
  pair<int, int> cen{(int)1e9, 0};
  auto dfs = [&](auto& dfs, int u, int fa) -> void {
    int smx = siz[u] = 1;
    for (int v : g[u]) if (!cut[v] && v != fa) dfs(dfs, v, u), siz[u] += siz[v], smx = max(smx, siz[v]);
    cen = min(cen, {max(smx, T - siz[u]), u});
  };
  return dfs(dfs, rt, 0), cen.second;
}/*}}}*/
int fans[40010];
void dfs(int kd, int u, int fa) {
  rnk[++cnt] = u;
  siz[u] = 1;
  if (kd) reverse(g[u].begin(), g[u].end());
  for (int v : g[u]) if (v != fa && !cut[v]) dfs(kd, v, u), siz[u] += siz[v];
  if (kd) reverse(g[u].begin(), g[u].end());
}
vector<int> f[40010], tmp[40010][2];
void calc(int u, int fa, int sum, int c) {
  if (++c > k) return ;
  sum += a[u];
  const auto& [lhs, rhs] = tmp[u];
#ifdef LOCAL
  debug("calc(%d, sum=%d, c=%d)\n", u, sum, c);
  debug("lhs: "); for (int x : lhs) debug("%d, ", x); debug("\n");
  debug("rhs: "); for (int x : rhs) debug("%d, ", x); debug("\n");
#endif
  for (int i = 0; i <= k - c && i < (int)lhs.size(); i++) {
    if (k - c - i < (int)rhs.size()) fans[u] = max(fans[u], sum + lhs[i] + rhs[k - c - i]);
  }
  for (int v : g[u]) if (v != fa && !cut[v]) calc(v, u, sum, c);
}
void work(int rt) {
  for (int kd : {0, 1}) {
    dfs(kd, rt, cnt = 0);
    if (cnt < k) return ;
    f[cnt + 1] = {0};
    for (int j = cnt; j >= 1; j--) {
      f[j] = f[j + siz[rnk[j]]];
      if (f[j].size() < f[j + 1].size() + 1) f[j].resize(f[j + 1].size() + 1, numeric_limits<int>::min());
      for (size_t k = 0; k < f[j + 1].size(); k++) f[j][k + 1] = max(f[j][k + 1], f[j + 1][k] + a[rnk[j]]);
      if (f[j].size() > k) f[j].resize(k);
    }
    if (!kd) for (int j = cnt; j >= 1; j--) tmp[rnk[j]][kd] = move(f[j + 1]);
    else for (int j = cnt; j >= 1; j--) tmp[rnk[j]][kd] = f[j + siz[rnk[j]]];
  }
  debug("rt = %d\n", rt);
  calc(rt, 0, 0, 0);
}
void solve(int rt) {
  cut[rt] = true;
  work(rt);
  for (int v : g[rt]) if (!cut[v]) solve(findcen(v, siz[v]));
}
int main() {
#ifndef LOCAL
  cin.tie(nullptr)->sync_with_stdio(false);  
#endif
  cin >> n >> k;
  for (int i = 1; i <= n; i++) cin >> a[i];
  for (int i = 1, u, v; i < n; i++) cin >> u >> v, g[u] += v, g[v] += u;
  solve(findcen(1, n));
  for (int i = 1; i <= n; i++) cout << fans[i] << " \n"[i == n];
  return 0;
}

posted @ 2024-08-04 11:42  caijianhong  阅读(42)  评论(0编辑  收藏  举报