【题解】P5219 无聊的水题 I
思路
prufer 序列 + 卷积优化 dp.
首先考虑到令 \(a\) 为原树的 prufer 序列,则 \(\sum\limits_{i = 1}^{n - 2} [a_i = k] = \operatorname{deg}(k)\),其中 \(\operatorname{deg}(k)\) 为 \(k\) 点的度数。
所以可以转化问题:对长度为 \(n - 2\),值域为 \([1, n]\) 且众数出现次数在 \(m\) 以内的序列个数。
对于这类构造序列,并且转移需要知道值出现次数(或和等)的 dp,可以考虑钦定长度,然后依次考虑每个数对当前序列的所有贡献。
令 \(f[i][j]\) 表示对于长度为 \(i\) 的序列,考虑值域为 \([1, j]\) 时的答案。最后答案稍微容斥一下。
转移考虑枚举 \(j\) 的贡献,得 \(f[i][j] = \sum\limits_{k = 0}^{\min(i, m)} {i \choose k} f[i - k][j - 1]\).
考虑套路地拆开组合数:\(f[i][j] = \sum\limits_{k = 0}^{\min(i, m)} \frac{i!}{k! (i - k)!} f[i - k][j - 1]\).
整理得 \(\frac{f[i][j]}{i!} = \sum\limits_{k = 0}^{\min(i, m)} \frac{f[i - k][j - 1]}{(i - k)!}\).
注意到转移实际上是加法卷积进行若干层,可以考虑用卷积优化。
令 \(F_i(x) = \sum\limits_{j = 0}^m f[i][j], G(x) = \sum\limits_{i = 0}^n [i \leq m] (i!)\),那么每层的转移就是 \(F * G\).
注意到每层的 \(G\) 都是相同的,并且 \(F\) 的初始状态是零次项为 \(1\) 的零次多项式,可以考虑直接倍增求 \(G\) 的幂。
时间复杂度 \(O(n \log^2 n)\).
注意代码和题解不符(dp 两维交换顺序)
代码
#include <cstdio>
#include <algorithm>
using namespace std;
typedef long long ll;
const int maxn = 4e5 + 5;
const int mod = 998244353;
const int g = 3;
int n, m;
int rev[maxn], fac[maxn], invf[maxn];
ll G[maxn], wp[maxn], bs[maxn], pw[maxn];
ll qpow(ll base, ll power, ll mod)
{
ll res = 1;
while (power)
{
if (power & 1) res = res * base % mod;
base = base * base % mod;
power >>= 1;
}
return res;
}
void calc_rev(int k) { for (int i = 1; i < k; i++) rev[i] = (rev[i >> 1] >> 1 | (i & 1 ? k >> 1 : 0)); }
void NTT(ll *A, int n)
{
calc_rev(n);
for (int i = 1; i < n; i++)
if (rev[i] > i) swap(A[i], A[rev[i]]);
for (int len = 2, m = 1; len <= n; m = len, len <<= 1)
{
ll wn = qpow(g, (mod - 1) / len, mod);
wp[0] = 1;
for (int i = 1; i <= len; i++) wp[i] = wp[i - 1] * wn % mod;
for (int l = 0, r = len - 1; r <= n; l += len, r += len)
{
int w = 0;
for (int p = l; p < l + m; p++, w++)
{
ll x = A[p], y = wp[w] * A[p + m] % mod;
A[p] = (x + y) % mod, A[p + m] = (x - y + mod) % mod;
}
}
}
}
void INTT(ll *A, int n)
{
NTT(A, n);
reverse(A + 1, A + n);
int inv = qpow(n, mod - 2, mod);
for (int i = 0; i < n; i++) A[i] = 1ll * A[i] * inv % mod;
}
void powp(int n, int pwr, int m)
{
int k = 1;
while (k < n) k <<= 1;
for (int i = 0; i <= m; i++) bs[i] = invf[i];
for (int i = m + 1; i < k; i++) bs[i] = 0;
pw[0] = 1; for (int i = 1; i < k; i++) pw[i] = 0;
while (pwr)
{
if (pwr & 1)
{
NTT(pw, k), NTT(bs, k);
for (int i = 0; i < k; i++) pw[i] = pw[i] * bs[i] % mod;
INTT(pw, k), INTT(bs, k);
for (int i = n; i < k; i++) pw[i] = 0;
}
NTT(bs, k);
for (int i = 0; i < k; i++) bs[i] = bs[i] * bs[i] % mod;
INTT(bs, k);
for (int i = n; i < k; i++) bs[i] = 0;
pwr >>= 1;
}
}
void powp(ll *F, int n, int pwr)
{
int k = 1;
while (k <= (n << 1)) k <<= 1;
for (int i = 0; i < n; i++) bs[i] = F[i];
for (int i = n; i < k; i++) bs[i] = 0;
pw[0] = 1; for (int i = 1; i < k; i++) pw[i] = 0;
while (pwr)
{
if (pwr & 1)
{
NTT(bs, k), NTT(pw, k);
for (int i = 0; i < k; i++) pw[i] = pw[i] * bs[i] % mod;
INTT(bs, k), INTT(pw, k);
for (int i = n; i < k; i++) pw[i] = 0;
}
NTT(bs, k);
for (int i = 0; i < k; i++) bs[i] = bs[i] * bs[i] % mod;
INTT(bs, k);
for (int i = n; i < k; i++) bs[i] = 0;
pwr >>= 1;
}
}
int solve(int len, int n, int m)
{
if (m <= 0) return 0;
int k = 1;
while (k < n) k <<= 1;
for (int i = 0; i <= m; i++) G[i] = invf[i];
for (int i = m + 1; i < k; i++) G[i] = 0;
powp(G, len + 1, n);
return pw[len] * fac[len] % mod;
}
int main()
{
scanf("%d%d", &n, &m);
fac[0] = invf[0] = fac[1] = invf[1] = 1;
for (int i = 2; i <= max(n, m); i++) fac[i] = 1ll * fac[i - 1] * i % mod, invf[i] = 1ll * (mod - mod / i) * invf[mod % i] % mod;
for (int i = 1; i <= max(n, m); i++) invf[i] = 1ll * invf[i - 1] * invf[i] % mod;
printf("%lld\n", (solve(n - 2, n, m - 1) - solve(n - 2, n, m - 2) + mod) % mod);
return 0;
}