Loading

【题解】P3711 仓鼠的数学题

poly 令人晕眩,令人晕眩的 poly.

思路

伯努利数。

首先意识到有一个拉插题也是求自然数幂和,所以答案是关于 \(n\)\(k\) 次多项式。

考虑设出 \(S_{n, k} = \sum\limits_{i = 0}^{n - 1} i^k\),这里不设到 \(n\) 的原因是方便用伯努利数表示,因此最后要记得加上 \(n^k\).

考虑 \(S_{n, k}\) 的 EGF:\(S_n(x) = \sum\limits_{k = 0}^{+ \infty} \frac{x^k} {k!} \sum\limits_{i = 0}^{n - 1} i^k\).

交换求和顺序得 \(S_n(x) = \sum\limits_{i = 0}^{n - 1} \sum\limits_{k = 0}^{+ \infty} \frac{(ix)^k} {k!}\).

其封闭形式为:\(\sum\limits_{i = 0}^{n - 1} (e^x)^i\).

求和得:\(S_n(x) = \frac{e^{nx} - 1}{e^x - 1}\).

根据 【题解】P4464 [国家集训队] JZPKIL,考虑用裂项后用伯努利数表示:\(S_n(x) = \frac{e^{nx} - 1}{x} B(x)\).

整理一下得:\(\frac{S_{n, k}}{k!} = \sum\limits_{i = 0}^k \frac{n^{i + 1}}{(i + 1)!} {B_{k - i}}{(k - i)!}\).

\(S_{n, k}\) 代入原式的 GF 得:\(\sum\limits_{k = 0}^n a_k \sum\limits_{i = 0}^x i^k = \sum\limits_{k = 0}^n a_k (x^k + S_{x, k})\).

将上式代入原式整理得:\(\sum\limits_{k = 0}^n a_k x^k + \sum\limits_{i = 0}^n \frac{x^{i + 1}}{(i + 1)!} \sum\limits_{k = i}^n a_k k! \frac{B_{k - i}}{(k - i)!}\).

右边是差卷积的形式,反转两次 \(O(n \log n)\) 做。

快速求伯努利数可以考虑伯努利数的 EGF:\(B = \frac{x}{e^x - 1}\).

注意 \(e^x - 1\) 的常数项为 \(0\) 不能求逆,简单平移一下再展开:\(B = \frac{1}{\sum\limits_{i = 0}^{+ \infty} \frac{x^i}{(i + 1)!}}\).

就可以直接上求逆 \(O(n \log n)\) 做。

代码

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

typedef long long ll;

const int sz = 1e6 + 5;
const int mod = 998244353;
const int g = 3;

int n;
int rev[sz];
ll fac[sz], invf[sz];
ll B[sz], F[sz], inv[sz], wp[sz];
ll Ft[sz], Rt[sz];

void calc_rev(int k) { for (int i = 1; i < k; i++) rev[i] = (rev[i >> 1] >> 1 | (i & 1 ? k >> 1 : 0)); }

ll qpow(ll base, ll power, ll mod)
{
    ll res = 1;
    while (power)
    {
        if (power & 1) res = res * base % mod;
        base = base * base % mod;
        power >>= 1;
    }
    return res;
}

void NTT(ll *A, int n)
{
    calc_rev(n);
    for (int i = 1; i < n; i++)
        if (rev[i] > i) swap(A[i], A[rev[i]]);
    for (int len = 2, m = 1; len <= n; m = len, len <<= 1)
    {
        ll wn = qpow(g, (mod - 1) / len, mod);
        wp[0] = 1;
        for (int i = 1; i <= len; i++) wp[i] = wp[i - 1] * wn % mod;
        for (int l = 0, r = len - 1; r <= n; l += len, r += len)
        {
            int w = 0;
            for (int p = l; p < l + m; p++, w++)
            {
                ll x = A[p], y = wp[w] * A[p + m] % mod;
                A[p] = (x + y) % mod, A[p + m] = (x - y + mod) % mod;
            }
        }
    }
}

void INTT(ll *A, int n)
{
    NTT(A, n);
    reverse(A + 1, A + n);
    int inv = qpow(n, mod - 2, mod);
    for (int i = 0; i < n; i++) A[i] = 1ll * A[i] * inv % mod;
}

void invp(ll *f, ll *r, int n)
{
    int k = 1;
    while (k < n) k <<= 1;
    r[0] = qpow(f[0], mod - 2, mod);
    for (int len = 2, m = 1; len <= k; m = len, len <<= 1)
    {
        for (int i = 0; i < len; i++) Rt[i] = r[i], Ft[i] = f[i];
        NTT(Ft, len), NTT(Rt, len);
        for (int i = 0; i < len; i++) Rt[i] = Rt[i] * Ft[i] % mod;
        INTT(Rt, len);
        for (int i = 0; i < m; i++) Rt[i] = 0; Rt[0] = 1;
        for (int i = 0; i < len; i++) Ft[i] = r[i];
        NTT(Ft, len), NTT(Rt, len);
        for (int i = 0; i < len; i++) Rt[i] = Rt[i] * Ft[i] % mod;
        INTT(Rt, len);
        for (int i = m; i < len; i++) r[i] = (r[i] * 2ll - Rt[i] + mod) % mod;
    }
    memset(Ft, 0, k * sizeof(ll));
    memset(Rt, 0, k * sizeof(ll));
    for (int i = n; i < k; i++) r[i] = 0;
}

void init(int lim)
{
    fac[0] = invf[0] = fac[1] = invf[1] = 1;
    for (int i = 2; i <= lim; i++) fac[i] = fac[i - 1] * i % mod, invf[i] = (mod - mod / i) * invf[mod % i] % mod;
    for (int i = 2; i <= lim; i++) invf[i] = invf[i - 1] * invf[i] % mod;
}

int main()
{
    scanf("%d", &n);
    init(n + 10);
    for (int i = 0; i <= n; i++) scanf("%lld", &F[i]), F[i] = F[i] * fac[i] % mod;
    // for (int i = 0; i <= n; i++) printf("%lld ", F[i]); putchar('\n');
    // for (int i = 0; i <= n; i++) printf("%lld ", invf[i]); putchar('\n');
    invp(invf + 1, B, n + 1);
    // for (int i = 0; i <= n; i++) printf("%lld ", invf[i]); putchar('\n');
    reverse(F, F + n + 1);
    int k = 1;
    while (k < (n + n + 2)) k <<= 1;
    NTT(F, k), NTT(B, k);
    for (int i = 0; i < k; i++) F[i] = F[i] * B[i] % mod;
    INTT(F, k);
    reverse(F, F + n + 1);
    memset(B, 0, sizeof(B));
    for (int i = 1; i <= n + 1; i++) B[i] = F[i - 1];
    memset(F, 0, sizeof(F));
    for (int i = 0; i <= n + 1; i++) F[i] = invf[i];
    reverse(B, B + n + 2);
    k = 1;
    while (k < (n + n + 4)) k <<= 1;
    NTT(F, k), NTT(B, k);
    for (int i = 0; i < k; i++) F[i] = F[i] * B[i] % mod;
    INTT(F, k);
    reverse(F, F + n + 2);
    for (int i = 0; i <= n + 1; i++) printf("%lld ", F[i] * invf[i] % mod);
    return 0;
}
posted @ 2023-02-12 22:06  kymru  阅读(21)  评论(0编辑  收藏  举报