P6049 燔祭 题解

题意:

计算满足如下条件的带标号有根树数量:

  • 这棵树一共有 \(n\) 个节点。
  • 每个节点都有一个整数权值,且在区间 \([1,m]\) 内。
  • 每个节点的权值都不大于其父节点的权值。

\(n,m \le 400\)


思路:

好题。

对于这种计数问题,肯定第一眼会想到 \(dp\),我们设 \(f_{n,m}\) 表示 \(n\) 个点的有标号树,根的权值是 \(m\) 的方案数,转移则是:

\[f_{n,m} = \sum_{k=0}^{n-1}\sum_{i_1 + i_2 + \dots + i_k = n - 1}\binom{n-1}{i_1,i_2,\dots,i_k}(\sum_{j=0}^{m}f_{i_1,j})(\sum_{j=0}^{m}f_{i_2,j})\dots(\sum_{j=0}^{m}f_{i_k,j}) \]

然后我们发现这是一个标准的用 EGF 优化的形式,不妨设 \(g_{n,m} = \sum_{j = 0}^m f_{n,j}\)\(F_m(x) = \sum_{n}\frac{f_{n,m}x^n}{n!}\)\(G_m(x) = \sum_{n}\frac{g_{n,m}x^n}{n!}\)

我们就可以得到这样的转移式:

\[F_m(x) = xe^{G_m(x)} = xe^{G_{m-1}(x) + F_m(x)} \]

回到问题本身,我们发现答案应该是一个关于 \(m\)\(n\) 次多项式,这个可以从上面的 \(dp\) 转移式中看出来,也可以通过发现答案粗略上界是 \((nm)^n\) 看出来。

于是我们只要计算出了 \(f_{n,1} \sim f_{n, n + 1}\) 就可以插值得到 \(f_{n,m}\)

考虑到多项式 exp 有一种神奇的 \(O(n^2)\) 递推的解法:如果要求 \(B = e^{A}\),两边取导得到:\(B'= BA'\),展开系数得到:

\[(n+1)B_{n+1} = \sum_{i=0}^n(i+1)A_{i+1}B_{n-i} \]

也就是:

\[B_{n} = \frac{1}{n}\sum_{i=1}^{n}iA_{i}B_{n-i} \]

于是我们可以直接这样来求,考虑到 \(F = xe^{F}e^{G}\),展开系数得到:

\[F_{n} = \sum_{i=0}^{n-1}F_iG_{(n-1)-i} \]

于是我们通过逐步递推出 exp 就可以在 \(O(n^3)\) 内解出来。

点击查看代码
#include <iostream>
#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 405;
const int mod = 998244353;

int fpow(int a, int b, int p) {
	if (b == 0)
		return 1;
	int ans = fpow(a, b / 2, p);
	ans = 1ll * ans * ans % p;
	if (b % 2 == 1)
		ans = 1ll * a * ans % p;
	return ans;
}
int mmi(int a, int p) {
	return fpow(a, p - 2, p);
}

int fac[N] = {0}, inv[N] = {0}, inv2[N] = {0};

void init(int n) {
	fac[0] = 1;
	for (int i = 1; i <= n; i++)
		fac[i] = 1ll * i * fac[i - 1] % mod;
	inv[n] = mmi(fac[n], mod);
	for (int i = n - 1; i >= 0; i--)
		inv[i] = 1ll * (i + 1) * inv[i + 1] % mod;
	inv2[1] = 1;
	for (int i = 2; i <= n; i++)
		inv2[i] = 1ll * (mod - mod / i) * inv2[mod % i] % mod;
}

int n, m;

int F[N][N] = {{0}}, G[N][N] = {{0}}; 
int f[N][N] = {{0}}, g[N][N] = {{0}};

void upd(int i) {//计算 F[i][1 ... n] 的值 
	//利用 G[i] 来储存 exp G[i - 1]
	g[i][0] = 1;
	for (int j = 1; j <= n; j++) {
		for (int k = 1; k <= j; k++)
			g[i][j] = (g[i][j] + 1ll * k * G[i - 1][k] % mod * g[i][j - k] % mod) % mod;
		g[i][j] = 1ll * g[i][j] * inv2[j] % mod;
	}
	
	//计算 F[i]
	f[i][0] = 1;
	for (int j = 1; j <= n; j++) {
		for (int k = 0; k <= j - 1; k++)
			F[i][j] = (F[i][j] + 1ll * f[i][k] * g[i][(j - 1) - k] % mod) % mod;
		//递推出 f[i][j]
		for (int k = 1; k <= j; k++)
			f[i][j] = (f[i][j] + 1ll * k * F[i][k] % mod * f[i][j - k] % mod) % mod;
		f[i][j] = 1ll * f[i][j] * inv2[j] % mod;
	} 
	//计算 G[i]
	for (int j = 1; j <= n; j++)
		G[i][j] = (G[i - 1][j] + F[i][j]) % mod; 
}

int prx[N] = {0}, suf[N] = {0};
void Lagrange() {
	prx[0] = 1;
	for (int i = 1; i <= n + 1; i++)
		prx[i] = 1ll * prx[i - 1] * (m - i + mod) % mod;
	suf[n + 2] = 1;
	for (int i = n + 1; i >= 1; i--)
		suf[i] = 1ll * suf[i + 1] * (m - i + mod) % mod;
	int ans = 0;
	for (int i = 1; i <= n + 1; i++) 
		if ((n + 1 - i) % 2 == 0)
			ans = (ans + 1ll * G[i][n] * prx[i - 1] % mod * suf[i + 1] % mod * inv[i - 1] % mod * inv[n + 1 - i] % mod) % mod;
		else
			ans = (ans - 1ll * G[i][n] * prx[i - 1] % mod * suf[i + 1] % mod * inv[i - 1] % mod * inv[n + 1 - i] % mod + mod) % mod;
	cout << ans << endl;
}

int main() {
	cin >> n >> m;
	init(n + 1);
	//先递推出 F 和 G
	//初始 i ^ i-1
	for (int i = 1; i <= n; i++)
		F[1][i] = G[1][i] = 1ll * fpow(i, i - 1, mod) * inv[i] % mod;
	for (int i = 2; i <= n + 1; i++) 
		upd(i);
	//求出 G[1, 2, ..., n + 1][n] 然后插值 
	for (int i = 1; i <= n + 1; i++)
		G[i][n] = 1ll * fac[n] * G[i][n] % mod;
	Lagrange();
	return 0;
} 
posted @ 2024-05-28 21:42  rlc202204  阅读(7)  评论(0编辑  收藏  举报