LG8935 [JRKSJ R7] 茎【DP】

给定一棵 \(n\) 个点的根节点为 \(1\) 的有根树,现在你要对这棵树进行剪枝,每次你可以选择一个还未被剪掉的节点 \(u\) 进行操作,然后剪掉 \(u\) 的子树所有点(包括 \(u\))。当且仅当你剪掉 \(1\) 时,操作停止。

再给定 \(x,k\),求有多少种不同的操作序列满足第 \(k\) 次恰好操作的是 \(x\)。答案对 \(10^9+7\) 取模。

\(1 \le n \le 5 \times 10^3\)


考虑 \(1 \sim x\) 的链,对于链上的每个点我们可以 DP 出 \(f_{u,i}\) 表示:考虑所有 \(u\) 的不在链上的子树,操作 \(i\) 次(不一定删完)的操作序列数。只有两条限制:

  1. 链上的点必须从 \(x\) 依次往上选。
  2. 对于链上的点 \(u\),其不在链上的子树中的所有操作都必须在它之前出现。

枚举做法可以发现按 \(1 \sim x\) 的顺序 DP 比较方便,因为只需要保证链上新插入的点位置在上一个之前。具体来说设 \(g_{u,i}\) 表示在当前操作序列中,上一个被选的链上的点之前有 \(i\) 个操作,转移考虑当前点选不选,和不在链上的子树一起塞进操作序列里就行,答案就是 \(g_{x,k-1}\)。注意 \(1\)\(x\) 不能不选。容易利用前缀和优化至 \(\mathcal{O}(n^2)\)

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef pair <int, int> pi;
#define fi first
#define se second
#define all(x) x.begin(), x.end()
constexpr int N = 5e3 + 5, mod = 1e9 + 7;
bool Mbe;
int n, k, x, c[N][N], f[N][N], g[N], tmp[N], siz[N], tot, key[N];
vector <int> e[N], v;
void add(int &x, int y) {
	x = x + y >= mod ? x + y - mod : x + y;
}
void dfs(int u, int fa) {
	f[u][0] = 1;
	for (auto v : e[u]) {
		if (v == fa) continue;
		dfs(v, u);
		key[u] |= key[v];
		if (!key[v]) {
			for (int i = 0; i <= siz[u]; i++) tmp[i] = f[u][i], f[u][i] = 0;
			for (int i = 0; i <= siz[u]; i++)
				for (int j = 0; j <= siz[v]; j++)
					add(f[u][i + j], 1LL * tmp[i] * f[v][j] % mod * c[i + j][i] % mod);
			siz[u] += siz[v];
		}
	}
	if (!key[u]) {
		for (int i = siz[u]; i >= 0; i--) add(f[u][i + 1], f[u][i]);
		siz[u]++;  
	} else v.push_back(u);
}
void mian() {
	cin >> n >> k >> x;
	for (int i = 1; i < n; i++) {
		int u, v;
		cin >> u >> v;
		e[u].push_back(v);
		e[v].push_back(u);
	}
	for (int i = 0; i <= n; i++) {
		c[i][0] = 1;
		for (int j = 1; j <= i; j++)
			c[i][j] = (c[i - 1][j - 1] + c[i - 1][j]) % mod; 
	}
	key[x] = 1;
	dfs(1, 0);
	reverse(all(v));
	g[0] = 1;
	for (int u : v) {
		int sum = 0;
		for (int i = tot; i >= 0; i--) {
			add(sum, g[i]);
			g[i] = (((u == 1 || u == x) ? 0 : g[i]) + sum) % mod;
		}
		for (int i = 0; i <= tot; i++) tmp[i] = g[i], g[i] = 0;
		for (int i = 0; i <= tot; i++)
			for (int j = 0; j <= siz[u]; j++) 
				add(g[i + j], 1LL * tmp[i] * f[u][j] % mod * c[i + j][j] % mod);
		tot += siz[u];
	}
	cout << g[k - 1] << "\n"; 
}
bool Med; 
int main() {
//	fprintf(stderr, "%.9lfMb\n", 1.0 * (&Mbe - &Med) / 1048576.0);
	ios :: sync_with_stdio(false);
	cin.tie(0), cout.tie(0); 
	int t = 1; 
	while (t--) mian();
//	cerr << 1e3 * clock() / CLOCKS_PER_SEC << "ms\n"; 
	return 0;
} 
posted @ 2024-02-27 20:02  came11ia  阅读(31)  评论(0编辑  收藏  举报