洛谷 P6326 - Shopping(点分治+树上背包)
真的是好久没写过题解了,先水一篇再说(
首先看到树上连通块问题,一眼树形 DP,然后发现需要背包合并,\(\mathcal O(nm^2)\),寄。
我们冷静一下,发现对于背包这类结构,合并的复杂度高达容量的平方,但单点插入的复杂度却不是太高(如果使用二进制拆分 / 单调队列,则复杂度则是 \(\Theta(m\log D)\) / \(\Theta(m)\)),这就启示我们使用插入 instead of 合并。有什么结构支持合并呢?考虑点分治,在点分治过程中,我们不妨假设分治中心必选,那么我们相当于要找一个权值最大的树上连通块满足其包含根。对于这一类问题,我们考虑一个经典的“父亲传给儿子,在子树里绕一圈再传回父亲”的套路,即我们考虑从根节点出发,当我们 DFS 到 \(x\) 时,我们遍历其所有子节点 \(y\),然后将 \(x\) 的 DP 值赋给 \(y\),然后在 \(y\) 子树里扫一遍后再令 \(x\) 的 DP 值对 \(y\) 的 DP 值取 \(\max\),这就是所谓的“树上连通块套路”。
下面的代码使用二进制拆分实现,时间复杂度 \(nm\log n\log D\),当然也可以使用单调队列,时间复杂度 \(nm\log n\)。
const int MAXN = 500;
const int MAXM = 4000;
const int INF = 0x3f3f3f3f;
int n, m, w[MAXN + 5], c[MAXN + 5], d[MAXN + 5];
int hd[MAXN + 5], to[MAXN * 2 + 5], nxt[MAXN * 2 + 5], ec = 0;
void adde(int u, int v) {to[++ec] = v; nxt[ec] = hd[u]; hd[u] = ec;}
int siz[MAXN + 5], mx[MAXN + 5], cent; bool vis[MAXN + 5];
int dp[MAXN + 5][MAXM + 5], res;
void clear() {
memset(hd, 0, sizeof(hd)); ec = 0; memset(vis, 0, sizeof(vis));
mx[0] = INF; cent = res = 0; memset(dp, 0xcf, sizeof(dp));
}
void findcent(int x, int f, int tot) {
siz[x] = 1; mx[x] = 0;
for (int e = hd[x]; e; e = nxt[e]) {
int y = to[e]; if (y == f || vis[y]) continue;
findcent(y, x, tot); siz[x] += siz[y]; chkmax(mx[x], siz[y]);
}
chkmax(mx[x], tot - siz[x]);
if (mx[x] < mx[cent]) cent = x;
}
vector<int> pt;
void findpts(int x, int f) {
pt.pb(x);
for (int e = hd[x]; e; e = nxt[e]) {
int y = to[e]; if (y == f || vis[y]) continue;
findpts(y, x);
}
}
void ins(int x, int num, int cst, int val) {
if (!num) return;
int sum = 0, cur = 1;
while (sum + cur <= num) {
for (int i = m; i >= cst * cur; i--) chkmax(dp[x][i], dp[x][i - cst * cur] + val * cur);
sum += cur; cur <<= 1;
}
for (int i = m; i >= cst * (num - sum); i--) chkmax(dp[x][i], dp[x][i - cst * (num - sum)] + val * (num - sum));
}
void dfs(int x, int f) {
for (int e = hd[x]; e; e = nxt[e]) {
int y = to[e]; if (y == f || vis[y]) continue;
for (int i = c[y]; i <= m; i++) dp[y][i] = dp[x][i - c[y]] + w[y];
ins(y, d[y] - 1, c[y], w[y]); dfs(y, x);
for (int i = 0; i <= m; i++) chkmax(dp[x][i], dp[y][i]);
}
for (int i = 0; i <= m; i++) chkmax(res, dp[x][i]);
}
void divcent(int x) {
// printf("divcent %d\n", x);
pt.clear(); findpts(x, 0); vis[x] = 1;
for (int y : pt) memset(dp[y], 0xcf, sizeof(dp[y]));
dp[x][c[x]] = w[x]; ins(x, d[x] - 1, c[x], w[x]); dfs(x, 0);
for (int e = hd[x]; e; e = nxt[e]) {
int y = to[e]; if (vis[y]) continue; cent = 0;
findcent(y, x, siz[y]); divcent(cent);
}
}
void solve() {
scanf("%d%d", &n, &m); clear();
for (int i = 1; i <= n; i++) scanf("%d", &w[i]);
for (int i = 1; i <= n; i++) scanf("%d", &c[i]);
for (int i = 1; i <= n; i++) scanf("%d", &d[i]);
for (int i = 1, u, v; i < n; i++) scanf("%d%d", &u, &v), adde(u, v), adde(v, u);
findcent(1, 0, n); divcent(cent);
printf("%d\n", res);
}
int main() {
int qu; scanf("%d", &qu);
while (qu--) solve();
return 0;
}