8.29 最后一英里
题意
给一颗根为\(1\)的有根树,树上每个点的权值为\(w_i\),大小为\(a_i\)
有\(q\)个询问,给出两个参数\(x,s\)
询问在以\(x\)为根的子树中,选出若干个点,这些点的大小之和不超过\(s\),并最大化权值之和
解法
一个明显的\(O(NS^2)\)的树形背包暴力
设\(f[x][k]\)为以\(x\)为根的子树中大小和小于\(k\)的结点的最大权值和,转移也很显然
我们可以发现对于一个点\(x\),它的状态是由所有子节点的状态合并转移过来的
于是可以考虑启发式合并的小\(trick\)
把整棵树进行轻重链剖分,对于一个节点,先把它重儿子的状态复制上去,对于轻儿子暴力转移:具体来说,就是把轻儿子为根的子树中的每个节点视作一个物品,背包即可
有两种证明复杂度是\(O(NSlogN)\)的方法:
第一种:
考虑启发式合并时,每次合并以后子树大小至少变为原来的两倍,所以每个节点最多被合并\(logN\)次。合并单个节点的复杂度是\(O(S)\)的,所以总复杂度为\(O (NSlogN)\)
第二种:
考虑轻重链剖分的性质,对于任意一个节点,其到根节点的路径上有\(logN\)级别个轻路径,因此每个节点最多被合并\(logN\)次。所以复杂度得证
记住对于这类题目的模型:
树上的问题,状态由子节点状态合并转移而来,可以考虑轻重链剖分启发式合并,重儿子状态直接继承,轻儿子状态暴力转移
代码
#include <cstdio>
#include <cctype>
#include <cstring>
using namespace std;
int read();
const int N = 4e5 + 10;
int n, q;
int w[N], a[N];
int cap;
int head[N], to[N << 1], nxt[N << 1];
int sz[N], son[N];
long long f[5010][5010];
inline void add(int x, int y) {
to[++cap] = y, nxt[cap] = head[x], head[x] = cap;
to[++cap] = x, nxt[cap] = head[y], head[y] = cap;
}
inline long long max(long long x, long long y) {
return x > y ? x : y;
}
void DFS3(int x, int fa, long long *f) {
for (int i = 5000; i >= a[x]; --i) f[i] = max(f[i], f[i - a[x]] + w[x]);
for (int i = head[x]; i; i = nxt[i]) {
if (to[i] == fa) continue;
DFS3(to[i], x, f);
}
}
void DFS2(int x, int fa) {
for (int i = head[x]; i; i = nxt[i]) {
if (to[i] == fa) continue;
DFS2(to[i], x);
}
memcpy(f[x], f[son[x]], sizeof f[son[x]]);
for (int i = head[x]; i; i = nxt[i]) {
if (to[i] == fa || to[i] == son[x]) continue;
DFS3(to[i], x, f[x]);
}
for (int i = 5000; i >= a[x]; --i) f[x][i] = max(f[x][i], f[x][i - a[x]] + w[x]);
}
void DFS(int x, int fa) {
sz[x]++;
for (int i = head[x]; i; i = nxt[i]) {
if (to[i] == fa) continue;
DFS(to[i], x);
sz[x] += sz[to[i]];
if (sz[to[i]] > sz[son[x]]) son[x] = to[i];
}
}
int main() {
n = read();
int u, v;
for (int i = 1; i < n; ++i) {
u = read(), v = read();
add(u, v);
}
for (int i = 1; i <= n; ++i)
scanf("%d%d", w + i, a + i);
DFS(1, 0);
DFS2(1, 0);
q = read();
while (q--) {
u = read(), v = read();
printf("%lld\n", f[u][v]);
}
return 0;
}
int read() {
int x = 0, c = getchar();
while (!isdigit(c)) c = getchar();
while (isdigit(c)) x = x * 10 + c - 48, c = getchar();
return x;
}