Loading

【题解】P5219 无聊的水题 I

思路

prufer 序列 + 卷积优化 dp.

首先考虑到令 \(a\) 为原树的 prufer 序列,则 \(\sum\limits_{i = 1}^{n - 2} [a_i = k] = \operatorname{deg}(k)\),其中 \(\operatorname{deg}(k)\)\(k\) 点的度数。

所以可以转化问题:对长度为 \(n - 2\),值域为 \([1, n]\) 且众数出现次数在 \(m\) 以内的序列个数。

对于这类构造序列,并且转移需要知道值出现次数(或和等)的 dp,可以考虑钦定长度,然后依次考虑每个数对当前序列的所有贡献。

\(f[i][j]\) 表示对于长度为 \(i\) 的序列,考虑值域为 \([1, j]\) 时的答案。最后答案稍微容斥一下。

转移考虑枚举 \(j\) 的贡献,得 \(f[i][j] = \sum\limits_{k = 0}^{\min(i, m)} {i \choose k} f[i - k][j - 1]\).

考虑套路地拆开组合数:\(f[i][j] = \sum\limits_{k = 0}^{\min(i, m)} \frac{i!}{k! (i - k)!} f[i - k][j - 1]\).

整理得 \(\frac{f[i][j]}{i!} = \sum\limits_{k = 0}^{\min(i, m)} \frac{f[i - k][j - 1]}{(i - k)!}\).

注意到转移实际上是加法卷积进行若干层,可以考虑用卷积优化。

\(F_i(x) = \sum\limits_{j = 0}^m f[i][j], G(x) = \sum\limits_{i = 0}^n [i \leq m] (i!)\),那么每层的转移就是 \(F * G\).

注意到每层的 \(G\) 都是相同的,并且 \(F\) 的初始状态是零次项为 \(1\) 的零次多项式,可以考虑直接倍增求 \(G\) 的幂。

时间复杂度 \(O(n \log^2 n)\).

注意代码和题解不符(dp 两维交换顺序)

代码

#include <cstdio>
#include <algorithm>
using namespace std;

typedef long long ll;

const int maxn = 4e5 + 5;
const int mod = 998244353;
const int g = 3;

int n, m;
int rev[maxn], fac[maxn], invf[maxn];
ll G[maxn], wp[maxn], bs[maxn], pw[maxn];

ll qpow(ll base, ll power, ll mod)
{
    ll res = 1;
    while (power)
    {
        if (power & 1) res = res * base % mod;
        base = base * base % mod;
        power >>= 1;
    }
    return res;
}

void calc_rev(int k) { for (int i = 1; i < k; i++) rev[i] = (rev[i >> 1] >> 1 | (i & 1 ? k >> 1 : 0)); }

void NTT(ll *A, int n)
{
    calc_rev(n);
    for (int i = 1; i < n; i++)
        if (rev[i] > i) swap(A[i], A[rev[i]]);
    for (int len = 2, m = 1; len <= n; m = len, len <<= 1)
    {
        ll wn = qpow(g, (mod - 1) / len, mod);
        wp[0] = 1;
        for (int i = 1; i <= len; i++) wp[i] = wp[i - 1] * wn % mod;
        for (int l = 0, r = len - 1; r <= n; l += len, r += len)
        {
            int w = 0;
            for (int p = l; p < l + m; p++, w++)
            {
                ll x = A[p], y = wp[w] * A[p + m] % mod;
                A[p] = (x + y) % mod, A[p + m] = (x - y + mod) % mod;
            }
        }
    }
}

void INTT(ll *A, int n)
{
    NTT(A, n);
    reverse(A + 1, A + n);
    int inv = qpow(n, mod - 2, mod);
    for (int i = 0; i < n; i++) A[i] = 1ll * A[i] * inv % mod;
}

void powp(int n, int pwr, int m)
{
    int k = 1;
    while (k < n) k <<= 1;
    for (int i = 0; i <= m; i++) bs[i] = invf[i];
    for (int i = m + 1; i < k; i++) bs[i] = 0;
    pw[0] = 1; for (int i = 1; i < k; i++) pw[i] = 0;
    while (pwr)
    {
        if (pwr & 1)
        {
            NTT(pw, k), NTT(bs, k);
            for (int i = 0; i < k; i++) pw[i] = pw[i] * bs[i] % mod;
            INTT(pw, k), INTT(bs, k);
            for (int i = n; i < k; i++) pw[i] = 0;
        }
        NTT(bs, k);
        for (int i = 0; i < k; i++) bs[i] = bs[i] * bs[i] % mod;
        INTT(bs, k);
        for (int i = n; i < k; i++) bs[i] = 0;
        pwr >>= 1;
    }
}

void powp(ll *F, int n, int pwr)
{
    int k = 1;
    while (k <= (n << 1)) k <<= 1;
    for (int i = 0; i < n; i++) bs[i] = F[i];
    for (int i = n; i < k; i++) bs[i] = 0;
    pw[0] = 1; for (int i = 1; i < k; i++) pw[i] = 0;
    while (pwr)
    {
        if (pwr & 1)
        {
            NTT(bs, k), NTT(pw, k);
            for (int i = 0; i < k; i++) pw[i] = pw[i] * bs[i] % mod;
            INTT(bs, k), INTT(pw, k);
            for (int i = n; i < k; i++) pw[i] = 0;
        }
        NTT(bs, k);
        for (int i = 0; i < k; i++) bs[i] = bs[i] * bs[i] % mod;
        INTT(bs, k);
        for (int i = n; i < k; i++) bs[i] = 0;
        pwr >>= 1;
    }
}

int solve(int len, int n, int m)
{
    if (m <= 0) return 0;
    int k = 1;
    while (k < n) k <<= 1;
    for (int i = 0; i <= m; i++) G[i] = invf[i];
    for (int i = m + 1; i < k; i++) G[i] = 0;
    powp(G, len + 1, n);
    return pw[len] * fac[len] % mod;
}

int main()
{
    scanf("%d%d", &n, &m);
    fac[0] = invf[0] = fac[1] = invf[1] = 1;
    for (int i = 2; i <= max(n, m); i++) fac[i] = 1ll * fac[i - 1] * i % mod, invf[i] = 1ll * (mod - mod / i) * invf[mod % i] % mod;
    for (int i = 1; i <= max(n, m); i++) invf[i] = 1ll * invf[i - 1] * invf[i] % mod;
    printf("%lld\n", (solve(n - 2, n, m - 1) - solve(n - 2, n, m - 2) + mod) % mod);
    return 0;
}
posted @ 2023-02-03 18:48  kymru  阅读(23)  评论(0编辑  收藏  举报