[Kattis]redblacktree(树形依赖背包,DP优化)

Source : NAIPC 2018

题意

有棵树,树上有红点和黑点,要选出一系列没有祖孙关系的节点,满足红点恰好有m个,求方案数。
\(n\le2*10^5\)
\(m\le 1000\)

题解

可以用树形背包解决。
\(f[x][j]\)表示以x为根的树里,恰好选了j个红点的方案数。
转移就大力分配红点个数就行。
时间复杂度为\(O(nm^2)\),超时了。

发现自己树形背包一直写假了,实际上树形背包的复杂度应该为\(O(nm)\)

Solution1

对于每个节点,他的子树大小是有限的,可以证明如果把上限从背包容积m改成\(siz[x]\)就可以把时间复杂度优化到\(O(nm)\)
证明链接

#include <bits/stdc++.h>
#define Mid ((l + r) >> 1)
#define lson (rt << 1)
#define rson (rt << 1 | 1)
using namespace std;
const int mod = 1e9 + 7;
int read(){
	char c; int num, f = 1;
	while(c = getchar(),!isdigit(c)) if(c == '-') f = -1; num = c - '0';
	while(c = getchar(), isdigit(c)) num = num * 10 + c - '0';
	return f * num;
}
const int N = 2e5 + 9, M = 1e3 + 9;
int f[N][M], n, m, col[N], tmp[N], siz[N];
vector<int> son[N];
void dfs(int x) {
	f[x][0] = 1; siz[x] += col[x];
	for(auto y : son[x]) {
		dfs(y); siz[x] += siz[y];
		for(int i = 0; i <= m; i++) tmp[i] = 0;
		for(int i = 0; i <= m && i <= siz[x]; i++) if(f[x][i]){
			for(int j = 0;  i + j <= m && j <= siz[y]; j++) if(f[y][j]){
				tmp[i + j] = (tmp[i + j] + 1ll * f[x][i] * f[y][j] % mod) % mod;
			}
		}
		for(int i = 0; i <= m; i++) f[x][i] = tmp[i];
	}
	if(col[x] == 1) f[x][1] = (1ll * f[x][1] + 1) % mod;
	if(col[x] == 0) f[x][0] = (1ll * f[x][0] + 1) % mod;
}
signed main()
{
	n = read(); m = read();
	for(int i = 2; i <= n; i++) 
		son[read()].push_back(i);
	for(int i = 1; i <= m; i++) 
		col[read()] = 1;
	dfs(1);
	for(int i = 0; i <= m; i++)
		printf("%d\n", f[1][i]);
	return 0;
}
/*
f[x][k]表示以x的子树,取了k个红点的方案数
f[x][a] * f[y][b] = f[x][a+b]
O(n*m^2)
*/

Solution2

树形背包还可以按照dfs序处理。
给树上节点按照后序遍历标号之后,顺序循环变成先处理子节点再处理父亲。
\(f[i][j]\)表示i之前的森林,取了j个红点方案数。
转移可以是:
当i不取时,\(f[i][j]+=f[i-1][j]\)
当i取时,他的子树就都不能取,由于子树是连续一段,按照\(siz[x]\)跳过即可。
\(f[i][j] = f[i][j] + f[i - siz[i]][j - 1](col[i]==1)\)
\(f[i][j] = f[i][j] + f[i - siz[i]][j](col[i]==0)\)

#include <bits/stdc++.h>
#define Mid ((l + r) >> 1)
#define lson (rt << 1)
#define rson (rt << 1 | 1)
using namespace std;
int read(){
	char c; int num, f = 1;
	while(c = getchar(),!isdigit(c)) if(c == '-') f = -1; num = c - '0';
	while(c = getchar(), isdigit(c)) num = num * 10 + c - '0';
	return f * num;
}
const int mod = 1e9 + 7;
const int N = 2e5 + 1009;
int n, m, f[N][1009], col[N], id[N], siz[N], cnt;
vector<int> son[N];
void dfs(int x) {
	int tmp = 0;
	for(auto y : son[x]) {
		dfs(y);
		tmp += siz[id[y]];
	}
	id[x] = ++cnt; siz[cnt] = tmp + 1;
}
signed main()
{
	n = read(); m = read();
	for(int i = 2; i <= n; i++) 
		son[read()].push_back(i);
	dfs(1);
	for(int i = 1; i <= m; i++)
		col[id[read()]] = 1;
	f[0][0] = 1;
	for(int i = 1; i <= n; i++) {
		for(int j = 0; j <= m; j++) {
			f[i][j] = (f[i][j] + f[i - 1][j]) % mod;
			if(col[i] == 1 && j) f[i][j] = (f[i][j] + f[i - siz[i]][j - 1]) % mod;
			if(col[i] == 0) f[i][j] = (f[i][j] + f[i - siz[i]][j]) % mod;
		}
	}
	for(int i = 0; i <= m; i++) 
		printf("%d\n", f[n][i]);
	return 0;
}
posted @ 2021-02-06 02:20  _onglu  阅读(62)  评论(0编辑  收藏  举报