LG8935 [JRKSJ R7] 茎【DP】
给定一棵 \(n\) 个点的根节点为 \(1\) 的有根树,现在你要对这棵树进行剪枝,每次你可以选择一个还未被剪掉的节点 \(u\) 进行操作,然后剪掉 \(u\) 的子树所有点(包括 \(u\))。当且仅当你剪掉 \(1\) 时,操作停止。
再给定 \(x,k\),求有多少种不同的操作序列满足第 \(k\) 次恰好操作的是 \(x\)。答案对 \(10^9+7\) 取模。
\(1 \le n \le 5 \times 10^3\)。
考虑 \(1 \sim x\) 的链,对于链上的每个点我们可以 DP 出 \(f_{u,i}\) 表示:考虑所有 \(u\) 的不在链上的子树,操作 \(i\) 次(不一定删完)的操作序列数。只有两条限制:
- 链上的点必须从 \(x\) 依次往上选。
- 对于链上的点 \(u\),其不在链上的子树中的所有操作都必须在它之前出现。
枚举做法可以发现按 \(1 \sim x\) 的顺序 DP 比较方便,因为只需要保证链上新插入的点位置在上一个之前。具体来说设 \(g_{u,i}\) 表示在当前操作序列中,上一个被选的链上的点之前有 \(i\) 个操作,转移考虑当前点选不选,和不在链上的子树一起塞进操作序列里就行,答案就是 \(g_{x,k-1}\)。注意 \(1\) 和 \(x\) 不能不选。容易利用前缀和优化至 \(\mathcal{O}(n^2)\)。
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef pair <int, int> pi;
#define fi first
#define se second
#define all(x) x.begin(), x.end()
constexpr int N = 5e3 + 5, mod = 1e9 + 7;
bool Mbe;
int n, k, x, c[N][N], f[N][N], g[N], tmp[N], siz[N], tot, key[N];
vector <int> e[N], v;
void add(int &x, int y) {
x = x + y >= mod ? x + y - mod : x + y;
}
void dfs(int u, int fa) {
f[u][0] = 1;
for (auto v : e[u]) {
if (v == fa) continue;
dfs(v, u);
key[u] |= key[v];
if (!key[v]) {
for (int i = 0; i <= siz[u]; i++) tmp[i] = f[u][i], f[u][i] = 0;
for (int i = 0; i <= siz[u]; i++)
for (int j = 0; j <= siz[v]; j++)
add(f[u][i + j], 1LL * tmp[i] * f[v][j] % mod * c[i + j][i] % mod);
siz[u] += siz[v];
}
}
if (!key[u]) {
for (int i = siz[u]; i >= 0; i--) add(f[u][i + 1], f[u][i]);
siz[u]++;
} else v.push_back(u);
}
void mian() {
cin >> n >> k >> x;
for (int i = 1; i < n; i++) {
int u, v;
cin >> u >> v;
e[u].push_back(v);
e[v].push_back(u);
}
for (int i = 0; i <= n; i++) {
c[i][0] = 1;
for (int j = 1; j <= i; j++)
c[i][j] = (c[i - 1][j - 1] + c[i - 1][j]) % mod;
}
key[x] = 1;
dfs(1, 0);
reverse(all(v));
g[0] = 1;
for (int u : v) {
int sum = 0;
for (int i = tot; i >= 0; i--) {
add(sum, g[i]);
g[i] = (((u == 1 || u == x) ? 0 : g[i]) + sum) % mod;
}
for (int i = 0; i <= tot; i++) tmp[i] = g[i], g[i] = 0;
for (int i = 0; i <= tot; i++)
for (int j = 0; j <= siz[u]; j++)
add(g[i + j], 1LL * tmp[i] * f[u][j] % mod * c[i + j][j] % mod);
tot += siz[u];
}
cout << g[k - 1] << "\n";
}
bool Med;
int main() {
// fprintf(stderr, "%.9lfMb\n", 1.0 * (&Mbe - &Med) / 1048576.0);
ios :: sync_with_stdio(false);
cin.tie(0), cout.tie(0);
int t = 1;
while (t--) mian();
// cerr << 1e3 * clock() / CLOCKS_PER_SEC << "ms\n";
return 0;
}