[JSOI2018]潜入行动 (树形背包)

题目链接

题意:

外星人的母舰可以看成是一棵 n 个节点、 n−1 条边的无向树,树上的节点用 1,2,⋯,n 编号。JYY 的特工已经装备了隐形模块,可以在外星人母舰中不受限制地活动,可以神不知鬼不觉地在节点上安装监听设备。
如果在节点 u 上安装监听设备,则 JYY 能够监听与 u 直接相邻所有的节点的通信。换言之,如果在节点 u 安装监听设备,则对于树中每一条边 (u,v) ,节点 v 都会被监听。
特别注意放置在节点 u 的监听设备并不监听 u 本身的通信,这是 JYY 特别为了防止外星人察觉部署的战术。
JYY 的特工一共携带了 k 个监听设备,现在 JYY 想知道,有多少种不同的放置监听设备的方法,能够使得母舰上所有节点的通信都被监听?为了避免浪费,每个节点至多只能安装一个监听设备,且监听设备必须被用完。
\(n\leq 100000 ,k\leq 100\)


显然是树形背包DP。
但是,状态比较难设计。如果u没有被监视,则u的子节点必须至少有一个选。所以要加一维表示选不选。
而如果u被监视了,则u的子节点可以都不选。所以要加一维表示u是否被监视。
这样就好理解了。

f1[a+b][0]=(f1[a+b][0]+1ll*x0[a][0]*dp[v[i]][b][0][0])%md;
f1[a+b][1]=(f1[a+b][1]+1ll*x0[a][0]*dp[v[i]][b][0][1]+1ll*x0[a][1]*(dp[v[i]][b][0][0]+dp[v[i]][b][0][1]))%md;
f2[a+b][0]=(f2[a+b][0]+1ll*x1[a][0]*dp[v[i]][b][1][0])%md;
f2[a+b][1]=(f2[a+b][1]+1ll*x1[a][0]*dp[v[i]][b][1][1]+1ll*x1[a][1]*(dp[v[i]][b][1][1]+dp[v[i]][b][1][0]))%md;

关键是复杂度。

首先,常规树形背包是\(O(n^2)\)的。
就是每对点会在lca处贡献复杂度。
但是,这个算法,最初觉得是\(O(nk^2)\)的,实际上是\(O(nk)\)的。
证明:

  1. 根据正常树形背包的复杂度\(O(n^2)\),小于等于k的最多产生\(n/k*k^2\)的复杂度。
  2. 大于k与大于k的合并一次,被合并的就增加k,最多n/k次,最多产生\(n/k*k^2\)的复杂度。
  3. 大于k的与小于等于k的合并时,每个小于等于k的最多被合并一次,所以是\(n*s_1+n*s_2+...+n*s_m\),也是\(nk\)

还有一种理解,不知道对不对:
把树按照dfs序变为序列。
然后,在子树中枚举取x个,可以理解为取dfs序的前(后)x个。
而合并时,认为一棵子树取后x个,另一棵取前y个。\((x+y\leq k)\)。这可以合并为长x+y的区间。
这其实就是长度不大于k的子串,最多有nk个。
但是,因为有取0个的情况,所以实际做题时,大约有2的常数。但那个常数就忽略了可以。

代码

#include <stdio.h> 
#define min(a, b)(a < b ? a: b)
#define md 1000000007 
inline int read() {
	char ch;
	while ((ch = getchar()) < '0' || ch > '9');
	int rt = (ch ^ 48);
	while ((ch = getchar()) >= '0' && ch <= '9') rt = (rt << 3) + (rt << 1) + (ch ^ 48);
	return rt;
}
int dp[100002][102][2][2],f1[102][2],f2[102][2];
int x0[102][2],x1[102][2],sz[100002];
int fr[100002],ne[200002],v[200002],bs = 0,k;
void addb(int a, int b) {
	v[bs] = b;
	ne[bs] = fr[a];
	fr[a] = bs++;
}
void dfs(int u, int fu) {
	int si = 0;
	for (int i = fr[u]; i != -1; i = ne[i]) {
		if (v[i] != fu) {
			dfs(v[i], u);
			si += sz[v[i]];
		}
	}
	for (int i = 0; i <= min(k, si); i++) x0[i][0] = x1[i][0] = 0;
	x0[0][0] = x1[0][0] = 1;
	si = 0;
	for (int i = fr[u]; i != -1; i = ne[i]) {
		if (v[i] == fu) continue;
		int rt = sz[v[i]];
		for (int a = 0; a <= min(si, k); a++) {
			for (int b = 0; b <= min(rt, k - a); b++) {
				f1[a + b][0] = (f1[a + b][0] + 1ll * x0[a][0] * dp[v[i]][b][0][0]) % md;
				f1[a + b][1] = (f1[a + b][1] + 1ll * x0[a][0] * dp[v[i]][b][0][1] + 1ll * x0[a][1] * (dp[v[i]][b][0][0] + dp[v[i]][b][0][1])) % md;
				f2[a + b][0] = (f2[a + b][0] + 1ll * x1[a][0] * dp[v[i]][b][1][0]) % md;
				f2[a + b][1] = (f2[a + b][1] + 1ll * x1[a][0] * dp[v[i]][b][1][1] + 1ll * x1[a][1] * (dp[v[i]][b][1][1] + dp[v[i]][b][1][0])) % md;
			}
		}
		si += rt;
		for (int a = 0; a <= min(si, k); a++) {
			x0[a][0] = f1[a][0];
			x0[a][1] = f1[a][1];
			x1[a][0] = f2[a][0];
			x1[a][1] = f2[a][1];
			f1[a][0] = f1[a][1] = f2[a][0] = f2[a][1] = 0;
		}
	}
	for (int a = 0; a <= min(si, k); a++) {
		dp[u][a][0][0] = x0[a][1];
		dp[u][a][1][0] = (x0[a][0] + x0[a][1]) % md;
	}
	for (int a = 1; a <= min(si + 1, k); a++) {
		dp[u][a][0][1] = x1[a - 1][1];
		dp[u][a][1][1] = (x1[a - 1][0] + x1[a - 1][1]) % md;
	}
	sz[u] = si + 1;
}
int main() {
	int n;
	scanf("%d%d", &n, &k);
	for (int i = 1; i <= n; i++) fr[i] = -1;
	for (int i = 0; i < n - 1; i++) {
		int a,b;
		a = read();
		b = read();
		addb(a, b);
		addb(b, a);
	}
	dfs(1, 0);
	printf("%d", (dp[1][k][0][0] + dp[1][k][0][1]) % md);
	return 0;
}
posted @ 2019-08-16 21:36  lnzwz  阅读(207)  评论(0编辑  收藏  举报