[JSOI2018] 潜入行动
毒瘤树形 \(dp\) 啊。。
题目大意
给定一棵树,如果在一个节点 \(u\) 放一个监听设备,那么与 \(u\) 相邻的节点 \(v\) 都能被监听,但是 \(u\) 不会被监听。现在有 \(m\) 个监听设备,问能够监听整棵树且恰好使用了 \(m\) 个监听设备的方案数。
一个节点至多放一个监听设备。
题目分析
借用大佬 @GKxx 的图:
令 \(dp[i,j,0/1,0/1]\) 表示以 \(i\) 为根的子树中一共放了 \(j\) 装置,\(i\) 点是否放了装置,\(i\) 点有没有被覆盖到的方案数。
来推一推状态转移方程:
-
当 \(i\) 没被监听且没放监听装置时,\(v(v\in son\{i\})\) 一定没有装监听设备,有 \(dp[i,j+k,0,0]=\sum (dp[i,j,0,0]\times dp[v,k,0,1])\)。
-
当 \(i\) 没被监听但放了监听装置时,\(v\)(\(v\) 的定义如上)是否被监听不重要但是一定没有放装置,否则这里就可以被监听了,有 \(dp[i,j+k,1,0]=\sum(dp[i,j,1,0]\times(dp[v,k,0,0]+dp[v,k,0,1]))\)。
-
当 \(i\) 没放装置但被监听了时,分两种情况:
\(i\) 侧的状态为 \(dp[i,j,0,1]\):\(dp[i,j+k,0,1]=\sum(dp[i,j,0,1]\times(dp[v,k,0,1]+dp[v,k,1,1]))\)。
\(i\) 侧的状态为 \(dp[i,j,0,0]\):\(dp[i,j+k,0,1]=\sum(dp[i,j,0,0]\times dp[v,k,1,1])\)。注意这里没有 \(dp[v,k,1,0]\),因为一定要让 \(v\) 被覆盖,但是 \(i\) 没放装置。
- 当 \(i\) 没放放了装置也被窃听了时:
\(i\) 侧是状态 \(dp[i,j,1,0]\):\(dp[i,j+k,1,1]=\sum(dp[i,j,1,0]\times(dp[v,k,1,0]+dp[v,k,1,1]))\)。
\(i\) 侧是状态 \(dp[i,j,1,1]\):\(dp[i,j+k,1,1]=\sum(dp[i,j,1,1]\times(dp[v,k,0,0]+dp[v,k,0,1]+dp[v,k,1,0]+dp[v,k,1,1]))\)。
代码
代码太难调了,调代码时间比思考时间还长。。
来张图了解下:
上面这张图的写法太难调了,所以换了一种写法过了。
注意全开 long long
会 \(\verb!MLE!\),只能在过程中用 long long
。
//2022/4/12
#define _CRT_SECURE_NO_WARNINGS
#include <iostream>
#include <cstdio>
#include <climits>//need "INT_MAX","INT_MIN"
#include <cstring>//need "memset"
#include <numeric>
#include <algorithm>
#define enter putchar(10)
#define debug(c,que) cerr << #c << " = " << c << que
#define cek(c) puts(c)
#define blow(arr,st,ed,w) for(register int i = (st);i <= (ed); ++ i) cout << arr[i] << w;
#define speed_up() ios::sync_with_stdio(false),cin.tie(0),cout.tie(0)
#define mst(a,k) memset(a,k,sizeof(a))
#define Abs(x) ((x) > 0 ? (x) : -(x))
#define stop return(0)
const long long mod = 1e9 + 7;
inline int MOD(long long x) {
if(x < 0) x += mod;
return x % mod;
}
namespace Newstd {
char buf[1 << 21],*p1 = buf,*p2 = buf;
inline int getc() {
return p1 == p2 && (p2 = (p1 = buf) + fread(buf,1,1 << 21,stdin),p1 == p2) ? EOF : *p1 ++;
}
inline int read() {
int ret = 0,f = 0;char ch = getc();
while (!isdigit(ch)) {
if(ch == '-') f = 1;
ch = getc();
}
while (isdigit(ch)) {
ret = (ret << 3) + (ret << 1) + ch - 48;
ch = getc();
}
return f ? -ret : ret;
}
inline void write(int x) {
if(x < 0) {
putchar('-');
x = -x;
}
if(x > 9) write(x / 10);
putchar(x % 10 + '0');
}
}
using namespace Newstd;
using namespace std;
const int N = 1e5 + 5,M = 105;
struct Gragh {
int v,nxt;
} gra[N << 1];
int head[N],siz[N],dp[N][M][2][2],tmp[M][2][2];
//dp[i,j,0/1,0/1]:以 i 为根的子树中共放了 j 个监听装置,其中 i 点放没放装置,i 点有没有被监听到的方案数
int n,m,idx;
inline void add(int u,int v) {
gra[++ idx].v = v,gra[idx].nxt = head[u],head[u] = idx;
}
inline void dfs(int now,int fath) {
siz[now] = dp[now][0][0][0] = dp[now][1][1][0] = 1;//没有覆盖是一种情况
for (register int i = head[now];i;i = gra[i].nxt) {
int v = gra[i].v;
if (v != fath) {
dfs(v,now);
for (register int j = 0;j <= min(siz[now],m); ++ j) {
tmp[j][0][0] = dp[now][j][0][0],tmp[j][0][1] = dp[now][j][0][1];
tmp[j][1][0] = dp[now][j][1][0],tmp[j][1][1] = dp[now][j][1][1];
dp[now][j][0][0] = dp[now][j][0][1] = dp[now][j][1][0] = dp[now][j][1][1] = 0;
}
for (register int j = 0;j <= min(siz[now],m); ++ j) {
for (register int k = 0;k <= min(siz[v],m - j); ++ k) {
dp[now][j + k][0][0] = MOD(dp[now][j + k][0][0] + MOD(1ll * tmp[j][0][0] * dp[v][k][0][1]));
dp[now][j + k][1][0] = MOD(dp[now][j + k][1][0] + MOD(1ll * tmp[j][1][0] * MOD(1ll * dp[v][k][0][0] + 1ll * dp[v][k][0][1])));
dp[now][j + k][0][1] = MOD(dp[now][j + k][0][1] + MOD(1ll * tmp[j][0][1] * MOD(1ll * dp[v][k][0][1] + 1ll * dp[v][k][1][1])));
dp[now][j + k][0][1] = MOD(dp[now][j + k][0][1] + MOD(1ll * tmp[j][0][0] * dp[v][k][1][1]));
dp[now][j + k][1][1] = MOD(dp[now][j + k][1][1] + MOD(1ll * tmp[j][1][0] * MOD(1ll * dp[v][k][1][1] + 1ll * dp[v][k][1][0])));
dp[now][j + k][1][1] = MOD(dp[now][j + k][1][1] + MOD(1ll * tmp[j][1][1] * (MOD(1ll * dp[v][k][0][0] + 1ll * dp[v][k][0][1]) + MOD(1ll * dp[v][k][1][0] + 1ll * dp[v][k][1][1]))));
}
}
siz[now] += siz[v];
}
}
}
int main(void) {
#ifndef ONLINE_JUDGE
freopen("in.txt","r",stdin);
#endif
scanf("%d%d",&n,&m);
for (register int i = 1;i < n; ++ i) {
int u,v;
scanf("%d%d",&u,&v);
add(u,v),add(v,u);
}
dfs(1,0);
printf("%d\n",MOD(dp[1][m][0][1] + dp[1][m][1][1]));
return 0;
}