【luogu P8352】小 N 的独立集(DP套DP)(性质)

小 N 的独立集

题目链接:luogu P8352

题目大意

给你一棵树,然后每个点可以有 1~k 的点权。
然后对于每一种点权,形成的 n^k 棵树,最大权独立集的点权和为 x 的树有多少种。
对于每个 x 输出答案。

思路

笑死第一次做 DP 套 DP 推起来贼乱被爆杀。
(知道 DP 套 DP 的不难就知道)

里层 DP:\(f_{i,0/1}\) 当前点是 \(i\)\(i\) 没选 / 选了。
然后对于一个点 \(u\) 的儿子 \(v\) 合并子树:
\(f_{u,0}=\max\{f_{v,1},f_{v,0}\}\)
\(f_{u,1}=f_{v,0}+a_u\)

然后外层就设 \(f_{u,i,j}\)\(u\)\(f_{u,0}\)\(f_{u,1}\) 分别是 \(i,j\) 的方案数。
然后你就枚举一下类似背包的合并即可。
然后发现会 T,\(O(n^4k^4)\)。(合并 \(O(n^3)\)

考虑找点性质优化:
因为每次选了最多加 \(k\),我们可以从 DP 的式子发现如果 \(f_{v,0}\geqslant f_{v,1}\),那一定是只需要 \(f_{v,0}\),那 \(f_{v,1}\) 的值就没有意义了。
而且如果是小于,差不会超过 \(k\),那我们可以换个表示方法:\(f_{u,i,j}\)\(u\)\(f_{u,0},f_{u,1}\) 分别是 \(i,i+j\) 的方案数,那 \(j\) 就是 \(k\) 的级别的了。
然后转移就是 \(O(n^2k^4)\)(因为你按大小合并从 LCA 的角度计数是 \(n^2\) 的)

至于转移细节,枚举 \(i,ii,j,jj\) 分别表示 \(f_{u,i,ii}\)\(f_{v,j,jj}\)
那转移到的就是 \(f_{u,i+j+jj,\max(0,ii-jj)}\)(因为你这里是要 \(i\) 不选儿子 \(j\) 选不选都可以),然后第三维也不难理解。
(第三维可以通过 \(\max(i+ii+j,i+j+jj)-(i+j+jj)\) 另一种方法,似乎更加利于理解)

代码

#include<cstdio>
#include<vector>
#include<cstring>
#include<iostream>
#define ll long long
#define mo 1000000007 

using namespace std;

const int N = 1005;
int n, K, sz[N];
vector <int> G[N];
ll f[N][N * 5][6], tmp[N * 5][6];
//1~5:正常,f1>f0,此时j记录f0
//f0不选,f1选 

void dfs(int now, int father) {
	sz[now] = K;
	for (int i = 1; i <= K; i++) f[now][0][i] = 1;
	for (int to = 0; to < G[now].size(); to++) { int x = G[now][to];
		if (x == father) continue; dfs(x, now);
		for (int i = 0; i <= sz[now]; i++) for (int ii = 0; ii <= K; ii++) if (f[now][i][ii])
			for (int j = 0; j <= sz[x]; j++) for (int jj = 0; jj <= K; jj++) if (f[x][j][jj]) {
				(tmp[i + j + jj][max(ii - jj, 0)] += f[now][i][ii] * f[x][j][jj] % mo) %= mo;
			}
		
		sz[now] += sz[x];
		for (int i = 0; i <= sz[now]; i++) for (int j = 0; j <= K; j++) f[now][i][j] = tmp[i][j], tmp[i][j] = 0; 
	}
}

int main() {
//	freopen("nset.in", "r", stdin);
//	freopen("nset.out", "w", stdout);
	
	scanf("%d %d", &n, &K);
	for (int i = 1; i < n; i++) {
		int x, y; scanf("%d %d", &x, &y);
		G[x].push_back(y); G[y].push_back(x);
	}
	
	dfs(1, 0);
	
	for (int i = 1; i <= K * n; i++) {
		ll re = 0;
		for (int j = 0; j <= i && j <= K; j++)
			(re += f[1][i - j][j]) %= mo;
		printf("%lld\n", re);
	}
	
	return 0;
}
posted @ 2022-07-28 20:32  あおいSakura  阅读(106)  评论(0编辑  收藏  举报