[Kattis]redblacktree(树形依赖背包,DP优化)
Source : NAIPC 2018
题意
有棵树,树上有红点和黑点,要选出一系列没有祖孙关系的节点,满足红点恰好有m个,求方案数。
\(n\le2*10^5\)
\(m\le 1000\)
题解
可以用树形背包解决。
\(f[x][j]\)表示以x为根的树里,恰好选了j个红点的方案数。
转移就大力分配红点个数就行。
时间复杂度为\(O(nm^2)\),超时了。
发现自己树形背包一直写假了,实际上树形背包的复杂度应该为\(O(nm)\)。
Solution1
对于每个节点,他的子树大小是有限的,可以证明如果把上限从背包容积m改成\(siz[x]\)就可以把时间复杂度优化到\(O(nm)\)
证明链接
#include <bits/stdc++.h>
#define Mid ((l + r) >> 1)
#define lson (rt << 1)
#define rson (rt << 1 | 1)
using namespace std;
const int mod = 1e9 + 7;
int read(){
char c; int num, f = 1;
while(c = getchar(),!isdigit(c)) if(c == '-') f = -1; num = c - '0';
while(c = getchar(), isdigit(c)) num = num * 10 + c - '0';
return f * num;
}
const int N = 2e5 + 9, M = 1e3 + 9;
int f[N][M], n, m, col[N], tmp[N], siz[N];
vector<int> son[N];
void dfs(int x) {
f[x][0] = 1; siz[x] += col[x];
for(auto y : son[x]) {
dfs(y); siz[x] += siz[y];
for(int i = 0; i <= m; i++) tmp[i] = 0;
for(int i = 0; i <= m && i <= siz[x]; i++) if(f[x][i]){
for(int j = 0; i + j <= m && j <= siz[y]; j++) if(f[y][j]){
tmp[i + j] = (tmp[i + j] + 1ll * f[x][i] * f[y][j] % mod) % mod;
}
}
for(int i = 0; i <= m; i++) f[x][i] = tmp[i];
}
if(col[x] == 1) f[x][1] = (1ll * f[x][1] + 1) % mod;
if(col[x] == 0) f[x][0] = (1ll * f[x][0] + 1) % mod;
}
signed main()
{
n = read(); m = read();
for(int i = 2; i <= n; i++)
son[read()].push_back(i);
for(int i = 1; i <= m; i++)
col[read()] = 1;
dfs(1);
for(int i = 0; i <= m; i++)
printf("%d\n", f[1][i]);
return 0;
}
/*
f[x][k]表示以x的子树,取了k个红点的方案数
f[x][a] * f[y][b] = f[x][a+b]
O(n*m^2)
*/
Solution2
树形背包还可以按照dfs序处理。
给树上节点按照后序遍历标号之后,顺序循环变成先处理子节点再处理父亲。
\(f[i][j]\)表示i之前的森林,取了j个红点方案数。
转移可以是:
当i不取时,\(f[i][j]+=f[i-1][j]\)
当i取时,他的子树就都不能取,由于子树是连续一段,按照\(siz[x]\)跳过即可。
\(f[i][j] = f[i][j] + f[i - siz[i]][j - 1](col[i]==1)\)
\(f[i][j] = f[i][j] + f[i - siz[i]][j](col[i]==0)\)
#include <bits/stdc++.h>
#define Mid ((l + r) >> 1)
#define lson (rt << 1)
#define rson (rt << 1 | 1)
using namespace std;
int read(){
char c; int num, f = 1;
while(c = getchar(),!isdigit(c)) if(c == '-') f = -1; num = c - '0';
while(c = getchar(), isdigit(c)) num = num * 10 + c - '0';
return f * num;
}
const int mod = 1e9 + 7;
const int N = 2e5 + 1009;
int n, m, f[N][1009], col[N], id[N], siz[N], cnt;
vector<int> son[N];
void dfs(int x) {
int tmp = 0;
for(auto y : son[x]) {
dfs(y);
tmp += siz[id[y]];
}
id[x] = ++cnt; siz[cnt] = tmp + 1;
}
signed main()
{
n = read(); m = read();
for(int i = 2; i <= n; i++)
son[read()].push_back(i);
dfs(1);
for(int i = 1; i <= m; i++)
col[id[read()]] = 1;
f[0][0] = 1;
for(int i = 1; i <= n; i++) {
for(int j = 0; j <= m; j++) {
f[i][j] = (f[i][j] + f[i - 1][j]) % mod;
if(col[i] == 1 && j) f[i][j] = (f[i][j] + f[i - siz[i]][j - 1]) % mod;
if(col[i] == 0) f[i][j] = (f[i][j] + f[i - siz[i]][j]) % mod;
}
}
for(int i = 0; i <= m; i++)
printf("%d\n", f[n][i]);
return 0;
}