WC2019 数树

题意略。

问题 0

显然,若原图中有 \(k\) 个连通块,则答案就是 \(y^k\)。至此问题 0 已被解决。

问题 1

考虑应用容斥原理的经典模型。令 \(G_i\) 表示钦定蓝树中有 \(i\) 条边被红树覆盖的方案数,\(F_i\) 为蓝树中恰好有 \(i\) 条边被红树覆盖的方案数,则

\[F_i=\sum\limits_{j=i}^{n-1}(-1)^{j-i}\binom{j}{i}G_j \]

为了消除组合数,将 \(y^{n-i}F_i\) 求和表示答案

\[\mathrm{ans}=\sum\limits_{i=0}^{n-1}y^{n-i}F_i=\sum\limits_{i=0}^{n-1}y^{n-i}\sum\limits_{j=i}^{n-1}(-1)^{j-i}\binom{j}{i}G_j \]

用二项式定理消除组合数可以得到

\[\mathrm{ans}=y^n\sum\limits_{i=0}^{n-1}(y^{-1}-1)^iG_i \]

事实上 \(G_i\) 也很难求,接下来将利用一个广为人知但不常用的推论绕开 \(G_i\)

根据 Prüfer 序列的一个推论,若一个图包含 \(k\) 个大小分别为 \(a_{1..k}\) 的连通块,则该图的生成树个数为 \(\dfrac{\prod\limits_{i=1}^kna_i}{n^2}\)。而对于 \(G_i\) 中一个指定的边覆盖要求,其满足要求的红树个数恰为这个图的生成树个数。于是将答案写作

\[\begin{aligned}\mathrm{ans}&=\dfrac{y^n}{n^2}\sum\limits_{i=0}^{n-1}(y^{-1}-1)^i\sum\prod\limits_{j=1}^{n-i}na_j\\&=\dfrac{(1-y)^n}{n^2}\sum\limits_{i=1}^{n}\left(\dfrac{y}{1-y}\right)^{i}\sum\prod\limits_{j=1}^{i}na_j\\&=\dfrac{(1-y)^n}{n^2}\sum\limits_{i=1}^{n}\sum\prod\limits_{j=1}^{i}\dfrac{yna_j}{1-y}\end{aligned} \]

不难发现,每个大小为 \(a\) 的连通块都能够对答案产生 \(\dfrac{yna}{1-y}\) 的乘积贡献。树形 DP 即可,复杂度 \(O(n)\)

问题 2

此时我们不再枚举单棵树,而是同时枚举两棵树。按照问题 1 的方法,不难得到

\[\mathrm{ans}=\dfrac{(1-y)^n}{n^4}\sum\limits_{i=1}^{n}\sum\prod\limits_{j=1}^{i}\dfrac{yn^2{a_j}^2}{1-y} \]

我们得到了一个这样的问题:

一个大小为 \(a\) 的树的价值为 \(ka^2\),一个森林的价值为其所有树的价值之积。求所有森林的价值之和。

这个问题可以利用指数函数的组合意义求,复杂度 \(O(n\log n)\)

typedef long long ll;
const int mod = 998244353;

int qpow(int a, int b){
	int res = 1;
	while(b){
		if(b & 1) res = (ll)res * a % mod;
		a = (ll)a * a % mod, b >>= 1;
	}
	return res;
}

int n, y, op;

void solve0(void){
	std::vector<std::pair<int, int> > e;
	int ans = n;
	for(int i = 0, x, y; i < 2 * (n - 1); ++i){
		scanf("%d%d", &x, &y);
		e.emplace_back(x, y);
	}
	std::sort(e.begin(), e.end());
	for(int i = 0, j = 0; i < (int)e.size(); i = ++j){
		while(j + 1 < (int)e.size() && e[j] == e[j + 1]) ++j;
		if(i != j) --ans;
	}
	printf("%d\n", qpow(y, ans));
}
  
int coeff;
std::vector<int> dp, un;
std::vector<std::vector<int> > G;

void dfs(int node, int f){
	dp[node] = 0, un[node] = 1;
	for(int e : G[node]) if(e != f){
		dfs(e, node);
		dp[node] = ((ll)dp[node] * un[e] + (ll)dp[e] * un[node]) % mod;
		un[node] = (ll)un[node] * un[e] % mod;
	}
	dp[node] = (dp[node] + (ll)un[node] * coeff) % mod;
	un[node] = (un[node] + dp[node]) % mod;
}

void solve1(void){
	dp.resize(n), un.resize(n), G.resize(n);
	for(int i = 1, x, y; i < n; ++i){
		scanf("%d%d", &x, &y), --x, --y;
		G[x].push_back(y), G[y].push_back(x);
	}
	if(y == 1){
		printf("%d\n", qpow(n, n - 2));
		return;
	}
	coeff = (ll)n * y % mod * qpow(1 + mod - y, mod - 2) % mod, dfs(0, -1);
	int ans = (ll)dp[0] * qpow(1 + mod - y, n) % mod * qpow(qpow(n, 2), mod - 2) % mod;
	printf("%d\n", ans);
}

void solve2(void){
	if(y == 1){
		printf("%d\n", qpow(n, 2 * n - 4));
		return;
	}
	fstdlib::poly f(n + 1);
	int k = (ll)y * qpow(n, 2) % mod * qpow(mod + 1 - y, mod - 2) % mod;
	for(int i = 1; i <= n; ++i) f[i] = (ll)qpow(i, i) % mod * k % mod;
	for(int i = 1, j = 1; i <= n; ++i) j = (ll)j * i % mod, f[i] = (ll)f[i] * qpow(j, mod - 2) % mod;
	f = fstdlib::exp(f);
	int ans = (ll)f[n] * qpow(1 + mod - y, n) % mod * qpow(qpow(n, 4), mod - 2) % mod;
	for(int i = 1; i <= n; ++i) ans = (ll)ans * i % mod;
	printf("%d\n", ans);
}

int main(){
	scanf("%d%d%d", &n, &y, &op);
	if(op == 0) solve0();
	else if(op == 1) solve1();
	else if(op == 2) solve2();
	return 0;
}
posted @ 2021-01-27 10:10  feiko  阅读(92)  评论(0编辑  收藏  举报