Loading

【题解】P6667 [清华集训2016] 如何优雅地求和

orz fjy666 orz fjy666 orz fjy666

神 · fjy666 · 拉普拉斯 · 爱因斯坦大帝于 5min 内爆切了此题,膜拜!

思路

斯特林数。

注意到 \(f(k)\) 是多项式而式子中含有组合数,于是考虑到普通多项式转下降幂多项式。

不妨设 \(f(k) = \sum\limits_{i = 0}^m f_i k^i = \sum\limits_{i = 0}^m b_i k^{\underline{i}}\).

我们知道当 \(m\) 较小时可以 \(O(m^2)\) 通过 \(f_i\) 求出 \(b_i\),但是这里给出的是多项式的点值。

暂且忽略,尝试用下降幂多项式化简。

代入多项式有:\(Q(f, n, x) = \sum\limits_{k = 0}^n f(k) \binom{n}{k} x^k (1 - x)^{n - k} \sum\limits_{i = 0}^m b_i k^{\underline{i}}\).

整理成多项式的形式是:\(\sum\limits_{i = 0}^m b_i \sum\limits_{k = 0}^n k^{\underline{i}} \binom{n}{k} x^k (1 - x)^{n - k}\).

注意到下降幂和组合数有结论:\(\binom{n}{k} k^{\underline{i}} = n^{\underline{i}} \binom{n - i}{k - i}\).

于是原式等价于 \(\sum\limits_{i = 0}^m b_i \sum\limits_{k = 0}^n n^{\underline{i}} \binom{n - i}{k - i} x^k (1 - x)^{n - k}\).

整理一下得 \(\sum\limits_{i = 0}^m b_i n^{\underline{i}} \sum\limits_{k = 0}^n \binom{n - i}{k - i} x^k (1 - x)^{n - k}\).

后面的和式可以换种写法:\(\sum\limits_{i = 0}^m b_i n^{\underline{i}} \sum\limits_{k = 0}^{n - i} \binom{n - i}{k} x^{k + i} (1 - x)^{n - k - i}\).

提出来 \(x^k\) 就可以二项式定理:\(\sum\limits_{i = 0}^m b_i n^{\underline{i}} x^k \sum\limits_{k = 0}^{n - i} \binom{n - i}{k} x^i (1 - x)^{n - k - i} = \sum\limits_{i = 0}^m b_i n^{\underline{i}} x^k (1 - x + x)^{n - k} = \sum\limits_{i = 0}^m b_i n^{\underline{i}} x^k\).

于是只需要求出 \(b\).

考虑多项式 \(f\) 点值的 EGF:\(\sum\limits_{n} \frac{f(n) x^n}{n!}\).

\(f(n)\) 代入得 \(\sum\limits_{n} \frac{x^n}{n!} \sum\limits_{i = 0}^m b_i n^{\underline{i}}\).

注意到 \(n^{\underline{i}} n! = \frac{n!}{(n - i)!} \frac{1}{n!} = \frac{1}{(n - i)!}\).

于是有原式等于 \(\sum\limits_{i = 0}^m b_i \sum\limits_{n} \frac{x^n}{n!} n^{\underline{i}} = \sum\limits_{i = 0}^m b_i x^i \sum\limits_{n} \frac{x^{n - i}}{(n - i)!}\).

明显是两个多项式的卷积形式,于是 \(O(n \log n)\) 可以解决。

代码

#include <cstdio>
#include <iostream>
#include <algorithm>
using namespace std;

typedef long long ll;

const int maxm = 2e4 + 5;
const int sz = 2e5 + 5;
const int mod = 998244353;
const int g = 3;

int n, m, x;
int a[maxm], rev[sz];
ll F[sz], G[sz];
ll fac[sz], invf[sz], wp[sz];

inline int read()
{
    int res = 0, flag = 1;
    char ch = getchar();
    while ((ch < '0') || (ch > '9'))
    {
        if (ch == '-') flag = -1;
        ch = getchar();
    }
    while ((ch >= '0') && (ch <= '9')) res = res * 10 + ch - '0', ch = getchar();
    return res * flag;
}

void calc_rev(int k) { for (int i = 0; i < k; i++) rev[i] = (rev[i >> 1] >> 1 | (i & 1 ? k >> 1 : 0)); }

ll qpow(ll base, ll power, ll mod)
{
    ll res = 1;
    while (power)
    {
        if (power & 1) res = res * base % mod;
        base = base * base % mod, power >>= 1;
    }
    return res;
}

ll qpow(int x) { return (x % 2 == 0 ? 1 : mod - 1); }

void NTT(ll* A, int n)
{
    calc_rev(n);
    for (int i = 0; i < n; i++)
        if (rev[i] > i) swap(A[i], A[rev[i]]);
    for (int len = 2, m = 1; len <= n; m = len, len <<= 1)
    {
        ll wn = qpow(g, (mod - 1) / len, mod) % mod;
        wp[0] = 1;
        for (int i = 1; i <= len; i++) wp[i] = wp[i - 1] * wn % mod;
        for (int l = 0, r = len - 1; r <= n; l += len, r += len)
        {
            int w = 0;
            for (int p = l; p < l + m; p++, w++)
            {
                ll x = A[p], y = wp[w] * A[p + m] % mod;
                A[p] = (x + y) % mod, A[p + m] = (x - y + mod) % mod;
            }
        }
    }
}

void INTT(ll *A, int n)
{
    NTT(A, n), reverse(A + 1, A + n);
    int inv = qpow(n, mod - 2, mod);
    for (int i = 0; i < n; i++) A[i] = A[i] * inv % mod;
}

void times(ll *F, ll *G, int lf, int lg, int lim)
{
    int len = lf + lg - 1, k = 1;
    while (k < len) k <<= 1;
    NTT(F, k), NTT(G, k);
    for (int i = 0; i < k; i++) F[i] = F[i] * G[i] % mod;
    INTT(F, k), INTT(G, k);
    for (int i = lim; i < k; i++) F[i] = 0;
}

void init(int lim)
{
    fac[0] = invf[0] = fac[1] = invf[1] = 1;
    for (int i = 2; i <= lim; i++) fac[i] = fac[i - 1] * i % mod, invf[i] = (mod - mod / i) * invf[mod % i] % mod;
    for (int i = 2; i <= lim; i++) invf[i] = invf[i - 1] * invf[i] % mod;
}

int main()
{
    scanf("%d%d%d", &n, &m, &x);
    init(m);
    for (int i = 0; i <= m; i++) F[i] = invf[i] * qpow(i) % mod;
    for (int i = 0; i <= m; i++) G[i] = invf[i] * read() % mod;
    times(F, G, m + 1, m + 1, m + 1);
    for (int i = 0; i <= m; i++) F[i] = F[i] * fac[i] % mod;
    int cur = 1, ans = 0;
    for (int i = 0; i <= m; i++) ans = (ans + 1ll * cur * F[i] % mod * invf[i] % mod) % mod, cur = 1ll * cur * x % mod * (n - i) % mod;
    printf("%d\n", ans);
    return 0;
}
posted @ 2023-03-15 18:13  kymru  阅读(86)  评论(0编辑  收藏  举报