下笔春蚕食叶声。

LGP6326 Shopping 点分治+dp

题意

一个 \(n\) 个点的树,每个点上有 \(d_i\) 个物品,每件体积为 \(c_i\),价值是 \(w_i\)

在树上选择一些物品(每个点可以选多个),使得他们可以组成一个连通块,求能获得的最大价值。

题解

考虑直接树形dp。设 \(dp_{i,j}\) 为在 \(i\) 子树内使用 \(j\) 的体积,并且选了 \(i\) ,能够得到的最大价值。时间复杂度是 \(O(nm^2)\) 的。

瓶颈在于背包是 \(O(m^2)\) 的。

那么我们考虑如何在可接受的复杂度内计算 \(dp_{i,j}\)


\(dp_{i,j}\) 表示的是 \(i\) 子树内一个包括了 \(i\) 的连通块,其实就是要求选某个物品的话,必须至少选一个他父亲节点的物品

一类有依赖的树形背包的dp方法

而这题是多重背包,二进制分组/单调队列一下,求解一次 \(dp_{i,j}\),复杂度会是 \(O(nm\log d)/O(nm)\)

考虑使用点分治,就会变成 \(O(nm\log n\log d)/O(nm\log n)\)

下面是二进制分组版本,因为不想写单调队列,也不太会写。。。

很容易写错,,而且跑得比下面那种慢多了。。。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define mkp make_pair
#define pb push_back
#define PII pair<int, int>
#define PLL pair<ll, ll>
#define ls(x) ((x) << 1)
#define rs(x) ((x) << 1 | 1)
#define fi first
#define se second
const int N = 510, M = 4010, inf = 0x3f3f3f3f;
int n, m, siz, rt, ans, c[N], d[N], w[N], id[N];
int e, to[N << 1], nxt[N << 1], hd[N];
bool vis[N]; int sz[N], mxsz[N], tim, dp[N][M], f[M];
void add(int u, int v) {
	to[++e] = v; nxt[e] = hd[u]; hd[u] = e;
}
void findrt(int u, int fa) {
	sz[u] = 1; mxsz[u] = 0; 
	for(int i = hd[u]; i; i = nxt[i]) {
		int v = to[i]; if(v == fa || vis[v]) continue;
		findrt(v, u); sz[u] += sz[v];
		if(sz[v] > mxsz[u]) mxsz[u] = sz[v];
	}
	mxsz[u] = max(mxsz[u], siz - sz[u]);
	if(mxsz[u] < mxsz[rt]) rt = u;
}
void init() {
	ans = 0; mxsz[0] = inf;
	for(int i = 1; i <= n; i++)
		hd[i] = vis[i] = 0;
	for(int i = 1; i <= e; i++)
		to[i] = nxt[i] = 0;
	e = 0;
}
void dfs(int u, int fa) {
	sz[u] = 1; id[++tim] = u;
	for(int i= hd[u]; i; i = nxt[i]) {
		int v = to[i]; if(vis[v] || v == fa) continue;
		dfs(v, u); sz[u] += sz[v];
	}
	return;
}
void calc(int u) {
	tim = 0; dfs(u, 0);
	dp[tim + 1][0] = 0; for(int i = 1; i <= m; i++) dp[tim + 1][i] = -inf;
	for(int i = tim; i >= 1; i--) {
		int p = id[i], t = d[p] - 1;
		for(int k = m; k >= 0; k--)
			f[k] = (k >= c[p]) ? (dp[i + 1][k - c[p]] + w[p]) : -inf;
		for(int j = 1; j <= t; t -= j, j <<= 1) {
			for(int k = m; k >= j * c[p]; k--)
				f[k] = max(f[k], f[k - j * c[p]] + j * w[p]);
		}
		if(t) {
			for(int k = m; k >= t * c[p]; k--)
				f[k] = max(f[k], f[k - t * c[p]] + t * w[p]);
		}
		for(int k = 0; k <= m; k++)
			dp[i][k] = max(dp[i + sz[p]][k], f[k]);
	}
	for(int i = 0; i <= m; i++)
		ans = max(ans, dp[1][i]);
}
void solve(int u) {
	calc(u);
	vis[u] = 1;
	for(int i = hd[u]; i; i = nxt[i]) {
		int v = to[i]; if(vis[v]) continue;
		siz = sz[v]; rt = 0; findrt(v, 0);
		solve(rt);
	}
	return;
}
int main(){
	int T; scanf("%d", &T);
	while(T--) {
		init();
		scanf("%d%d", &n, &m);
		for(int i = 1; i <= n; i++) scanf("%d", &w[i]);
		for(int i = 1; i <= n; i++) scanf("%d", &c[i]);
		for(int i = 1; i <= n; i++) scanf("%d", &d[i]);
		for(int i = 1, x, y; i < n; i++)
			scanf("%d%d", &x, &y), add(x, y), add(y, x);
		rt = 0; siz = n; findrt(1, 0);
		solve(rt);
		printf("%d\n", ans);
	}
	return 0;
}

还有一种dp的方法,可以在dfs时就直接dp。

\(dp_{i,j}\) 改为表示dfs完 \(i\) 点及其子树,此时所访问过的所有点在用了 \(j\) 的体积的情况下的最大价值。

\(f_{j}\) 表示现在dfs到现在的点 \(i\),且选了 \(u\),所访问过的所有点的在用了 \(j\) 的体积的情况下的最大价值。计算后变成 和 \(dp_{i,j}\) 一样的定义。

那么这个 \(f_j\) 往下dfs下去的时候,一定会强制选当前点的父亲,否则就不可能出现当前点了,然后再多重背包一下就行了。

我说的巨大多不清楚,其实代码很清楚。。。但我觉得实际上实现起来复杂的一批(恼)

其实本质上和前一种方法是差不多的,状态都是当前dfs已经访问的所有点/dfs序中还没访问的那些点

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define mkp make_pair
#define pb push_back
#define PII pair<int, int>
#define PLL pair<ll, ll>
#define ls(x) ((x) << 1)
#define rs(x) ((x) << 1 | 1)
#define fi first
#define se second
const int N = 510, M = 4010, inf = 0x3f3f3f3f;
int n, m, siz, rt, ans, c[N], d[N], w[N];
int e, to[N << 1], nxt[N << 1], hd[N];
bool vis[N]; int sz[N], mxsz[N], f[M], dp[N][M];
void add(int u, int v) {
	to[++e] = v; nxt[e] = hd[u]; hd[u] = e;
}
void findrt(int u, int fa) {
	sz[u] = 1; mxsz[u] = 0; 
	for(int i = hd[u]; i; i = nxt[i]) {
		int v = to[i]; if(v == fa || vis[v]) continue;
		findrt(v, u); sz[u] += sz[v];
		if(sz[v] > mxsz[u]) mxsz[u] = sz[v];
	}
	mxsz[u] = max(mxsz[u], siz - sz[u]);
	if(mxsz[u] < mxsz[rt]) rt = u;
}
void init() {
	ans = 0; mxsz[0] = inf;
	for(int i = 1; i <= n; i++)
		hd[i] = vis[i] = 0;
	for(int i = 1; i <= e; i++)
		to[i] = nxt[i] = 0;
	e = 0;
}
void dfs(int u, int fa, int cst) {
	if(cst > m) return;
	for(int i = 0; i <= m; i++)
		dp[u][i] = f[i];
		
	for(int i = m; i >= 0; i--) {
		if(i >= cst + c[u])
			f[i] = f[i - c[u]] + w[u];
		else f[i] = -inf;
	}
	int t = d[u] - 1;
	for(int i = 1; i <= t; t -= i, i <<= 1)
		for(int j = m; j >= i * c[u]; j--)
			f[j] = max(f[j], f[j - i * c[u]] + i * w[u]);
	if(t) {
		for(int j = m; j >= t * c[u]; j--)
			f[j] = max(f[j], f[j - t * c[u]] + t * w[u]);
	}
	
	for(int i= hd[u]; i; i = nxt[i]) {
		int v = to[i]; if(vis[v] || v == fa) continue;
		dfs(v, u, cst + c[u]);
	}
	
	for(int i = 0; i <= m; i++)
		dp[u][i] = f[i] = max(dp[u][i], f[i]);
	return;
}
void calc(int u) {
	f[0] = 0; for(int i = 1; i <= m; i++) f[i] = -inf;
	dfs(u, 0, 0);
	for(int i = 0; i <= m; i++)
		ans = max(ans, dp[u][i]);
}
void solve(int u) {
	calc(u);
	vis[u] = 1;
	for(int i = hd[u]; i; i = nxt[i]) {
		int v = to[i]; if(vis[v]) continue;
		siz = sz[v]; rt = 0; findrt(v, 0);
		solve(rt);
	}
	return;
}
int main(){
	int T; scanf("%d", &T);
	while(T--) {
		init();
		scanf("%d%d", &n, &m);
		for(int i = 1; i <= n; i++) scanf("%d", &w[i]);
		for(int i = 1; i <= n; i++) scanf("%d", &c[i]);
		for(int i = 1; i <= n; i++) scanf("%d", &d[i]);
		for(int i = 1, x, y; i < n; i++)
			scanf("%d%d", &x, &y), add(x, y), add(y, x);
		rt = 0; siz = n; findrt(1, 0);
		solve(rt);
		printf("%d\n", ans);
	}
	return 0;
}
posted @ 2022-01-25 20:51  ACwisher  阅读(43)  评论(0编辑  收藏  举报