P6049 燔祭 题解
题意:
计算满足如下条件的带标号有根树数量:
- 这棵树一共有 \(n\) 个节点。
- 每个节点都有一个整数权值,且在区间 \([1,m]\) 内。
- 每个节点的权值都不大于其父节点的权值。
\(n,m \le 400\)
思路:
好题。
对于这种计数问题,肯定第一眼会想到 \(dp\),我们设 \(f_{n,m}\) 表示 \(n\) 个点的有标号树,根的权值是 \(m\) 的方案数,转移则是:
\[f_{n,m} = \sum_{k=0}^{n-1}\sum_{i_1 + i_2 + \dots + i_k = n - 1}\binom{n-1}{i_1,i_2,\dots,i_k}(\sum_{j=0}^{m}f_{i_1,j})(\sum_{j=0}^{m}f_{i_2,j})\dots(\sum_{j=0}^{m}f_{i_k,j})
\]
然后我们发现这是一个标准的用 EGF 优化的形式,不妨设 \(g_{n,m} = \sum_{j = 0}^m f_{n,j}\),\(F_m(x) = \sum_{n}\frac{f_{n,m}x^n}{n!}\) 和 \(G_m(x) = \sum_{n}\frac{g_{n,m}x^n}{n!}\)。
我们就可以得到这样的转移式:
\[F_m(x) = xe^{G_m(x)} = xe^{G_{m-1}(x) + F_m(x)}
\]
回到问题本身,我们发现答案应该是一个关于 \(m\) 的 \(n\) 次多项式,这个可以从上面的 \(dp\) 转移式中看出来,也可以通过发现答案粗略上界是 \((nm)^n\) 看出来。
于是我们只要计算出了 \(f_{n,1} \sim f_{n, n + 1}\) 就可以插值得到 \(f_{n,m}\)。
考虑到多项式 exp 有一种神奇的 \(O(n^2)\) 递推的解法:如果要求 \(B = e^{A}\),两边取导得到:\(B'= BA'\),展开系数得到:
\[(n+1)B_{n+1} = \sum_{i=0}^n(i+1)A_{i+1}B_{n-i}
\]
也就是:
\[B_{n} = \frac{1}{n}\sum_{i=1}^{n}iA_{i}B_{n-i}
\]
于是我们可以直接这样来求,考虑到 \(F = xe^{F}e^{G}\),展开系数得到:
\[F_{n} = \sum_{i=0}^{n-1}F_iG_{(n-1)-i}
\]
于是我们通过逐步递推出 exp 就可以在 \(O(n^3)\) 内解出来。
点击查看代码
#include <iostream>
#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 405;
const int mod = 998244353;
int fpow(int a, int b, int p) {
if (b == 0)
return 1;
int ans = fpow(a, b / 2, p);
ans = 1ll * ans * ans % p;
if (b % 2 == 1)
ans = 1ll * a * ans % p;
return ans;
}
int mmi(int a, int p) {
return fpow(a, p - 2, p);
}
int fac[N] = {0}, inv[N] = {0}, inv2[N] = {0};
void init(int n) {
fac[0] = 1;
for (int i = 1; i <= n; i++)
fac[i] = 1ll * i * fac[i - 1] % mod;
inv[n] = mmi(fac[n], mod);
for (int i = n - 1; i >= 0; i--)
inv[i] = 1ll * (i + 1) * inv[i + 1] % mod;
inv2[1] = 1;
for (int i = 2; i <= n; i++)
inv2[i] = 1ll * (mod - mod / i) * inv2[mod % i] % mod;
}
int n, m;
int F[N][N] = {{0}}, G[N][N] = {{0}};
int f[N][N] = {{0}}, g[N][N] = {{0}};
void upd(int i) {//计算 F[i][1 ... n] 的值
//利用 G[i] 来储存 exp G[i - 1]
g[i][0] = 1;
for (int j = 1; j <= n; j++) {
for (int k = 1; k <= j; k++)
g[i][j] = (g[i][j] + 1ll * k * G[i - 1][k] % mod * g[i][j - k] % mod) % mod;
g[i][j] = 1ll * g[i][j] * inv2[j] % mod;
}
//计算 F[i]
f[i][0] = 1;
for (int j = 1; j <= n; j++) {
for (int k = 0; k <= j - 1; k++)
F[i][j] = (F[i][j] + 1ll * f[i][k] * g[i][(j - 1) - k] % mod) % mod;
//递推出 f[i][j]
for (int k = 1; k <= j; k++)
f[i][j] = (f[i][j] + 1ll * k * F[i][k] % mod * f[i][j - k] % mod) % mod;
f[i][j] = 1ll * f[i][j] * inv2[j] % mod;
}
//计算 G[i]
for (int j = 1; j <= n; j++)
G[i][j] = (G[i - 1][j] + F[i][j]) % mod;
}
int prx[N] = {0}, suf[N] = {0};
void Lagrange() {
prx[0] = 1;
for (int i = 1; i <= n + 1; i++)
prx[i] = 1ll * prx[i - 1] * (m - i + mod) % mod;
suf[n + 2] = 1;
for (int i = n + 1; i >= 1; i--)
suf[i] = 1ll * suf[i + 1] * (m - i + mod) % mod;
int ans = 0;
for (int i = 1; i <= n + 1; i++)
if ((n + 1 - i) % 2 == 0)
ans = (ans + 1ll * G[i][n] * prx[i - 1] % mod * suf[i + 1] % mod * inv[i - 1] % mod * inv[n + 1 - i] % mod) % mod;
else
ans = (ans - 1ll * G[i][n] * prx[i - 1] % mod * suf[i + 1] % mod * inv[i - 1] % mod * inv[n + 1 - i] % mod + mod) % mod;
cout << ans << endl;
}
int main() {
cin >> n >> m;
init(n + 1);
//先递推出 F 和 G
//初始 i ^ i-1
for (int i = 1; i <= n; i++)
F[1][i] = G[1][i] = 1ll * fpow(i, i - 1, mod) * inv[i] % mod;
for (int i = 2; i <= n + 1; i++)
upd(i);
//求出 G[1, 2, ..., n + 1][n] 然后插值
for (int i = 1; i <= n + 1; i++)
G[i][n] = 1ll * fac[n] * G[i][n] % mod;
Lagrange();
return 0;
}