仓鼠的数学题——生成函数
题面
解析
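需要用到伯努利数。记 $S_k(x)=\sum_{i=0}^{x}i^k$(约定 $0^0=1$)。取 $B_1=+\frac12$ 的约定时有 $\sum_{i=1}^{x}i^k=\frac{1}{k+1}\sum_{i=0}^{k}\binom{k+1}{i}B_i x^{k+1-i}$,所以 $S_k(x)$ 与该式只差 $k=0$ 时由 $0^0$ 贡献的常数 $1$。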
把伯努利数的式子代入问题(常数项只有 $0^0=1$ 贡献的 $a_0$,单独输出):$$\begin{align*}\sum_{k=0}^nS_k(x)a_k&=a_0+\sum_{k=0}^n\frac{a_k}{k+1}\sum_{i=0}^k\binom{k+1}{i}B_i x^{k+1-i}\\&=a_0+\sum_{k=0}^n a_k\cdot k! \sum_{i=0}^k \frac{B_i}{i!} \cdot\frac{x^{k+1-i}}{(k+1-i)!} \end{align*}$$
设$A_i=a_i\cdot i!$,$C_i=\frac{B_i}{i!}$,则对 $i\ge1$,答案中$x^i$的系数为$\frac{1}{i!}\sum_{k=0}^{n}\sum_{j\ge0}[k+1-j=i]A_kC_j$
设$D_i=C_{n+1-i}$($0\le i\le n+1$),则$x^i$的系数为$\frac{1}{i!}\sum_{k=0}^{n}\sum_{j=0}^{n+1}[k+j=n+i]A_kD_j$,即 $A$ 与 $D$ 卷积结果的第 $n+i$ 项再除以 $i!$。
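用 NTT 卷积即可。其中 $C_i=\frac{B_i}{i!}$ 可以对 $\frac{e^x-1}{x}=\sum_{i\ge0}\frac{x^i}{(i+1)!}$ 做多项式求逆得到。注意求逆得到的是 $B_1=-\frac12$ 的版本,需要把 $C_1$ 改成 $+\frac12$。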
时间复杂度 $O(n\log n)$。
代码:
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
using namespace std;
typedef long long ll;

// Computes sum_{k=0}^{n} a_k * S_k(x) as a polynomial in x, using the
// Bernoulli-number form of S_k and a single NTT convolution.
// All arithmetic is modulo the NTT-friendly prime 998244353 (primitive root 3).
//
// Buffer sizing: the final convolution uses the smallest power of two that
// exceeds 2n+5. It must fit in the (maxn<<1)-element arrays, so with
// maxn = 500005 the program supports n <= 262141.
const int maxn = 500005, mod = 998244353, g = 3;

// Reads one (optionally negative) decimal integer from stdin and skips any
// non-digit characters before it. Returns 0 if EOF comes before any digit.
// `c` is an int so that EOF is detected instead of being spun on forever.
inline int read()
{
    int ret = 0, f = 1, c = getchar();
    while(c != EOF && (c < '0' || c > '9'))
    {
        if(c == '-') f = -1;
        c = getchar();
    }
    while(c >= '0' && c <= '9')
    {
        ret = ret * 10 + (c - '0');
        c = getchar();
    }
    return ret * f;
}

// Modular add / subtract. Both operands must already be in [0, mod).
int add(int x, int y) { return x + y < mod ? x + y : x + y - mod; }
int rdc(int x, int y) { return x - y < 0 ? x - y + mod : x - y; }

// Returns x^y mod `mod` by binary exponentiation (x assumed in [0, mod)).
ll qpow(ll x, int y)
{
    ll ret = 1;
    while(y)
    {
        if(y & 1) ret = ret * x % mod;
        x = x * x % mod;
        y >>= 1;
    }
    return ret;
}

// n        : highest index of the input sequence a_0..a_n
// lim, bit : current NTT length (a power of two) and log2(lim)
// rev      : bit-reversal permutation for the current lim
// a        : A_i = a_i * i!; overwritten by the convolution result
// fac, inv : i! and 1/i! for 0 <= i <= n+2
// b        : b_i = 1/(i+1)!, the coefficients of (e^x - 1)/x
// B        : series inverse of b, i.e. B_i / i!; reversed before convolving
// c        : scratch buffer for get_inv
// ginv     : g^{-1} mod p, root used by the inverse transform
int n, lim, bit, rev[maxn<<1];
ll a[maxn<<1], fac[maxn], inv[maxn], b[maxn], B[maxn<<1], c[maxn<<1], ginv;

// Precomputes factorials and inverse factorials up to n+2, and fills b.
void init()
{
    ginv = qpow(g, mod - 2);
    fac[0] = 1;
    for(int i = 1; i <= n + 2; ++i) fac[i] = i * fac[i-1] % mod;
    inv[n+2] = qpow(fac[n+2], mod - 2);
    for(int i = n + 1; i >= 0; --i)
    {
        inv[i] = inv[i+1] * (i + 1) % mod;
        b[i] = inv[i+1];                            // 1/(i+1)!
    }
}

// Chooses lim = smallest power of two strictly greater than x and builds rev.
void NTT_init(int x)
{
    lim = 1; bit = 0;
    while(lim <= x) { lim <<= 1; ++bit; }
    for(int i = 1; i < lim; ++i)
        rev[i] = (rev[i>>1] >> 1) | ((i & 1) << (bit - 1));
}

// In-place iterative NTT of length lim. y == 1: forward; y == -1: inverse
// (the inverse also divides by lim).
void NTT(ll *x, int y)
{
    for(int i = 1; i < lim; ++i) if(i < rev[i]) swap(x[i], x[rev[i]]);
    ll wn, w, u, v;
    for(int i = 1; i < lim; i <<= 1)
    {
        wn = qpow((y == 1) ? g : ginv, (mod - 1) / (i << 1));
        for(int j = 0; j < lim; j += (i << 1))
        {
            w = 1;
            for(int k = 0; k < i; ++k)
            {
                u = x[j+k];
                v = x[j+k+i] * w % mod;
                x[j+k] = add(u, v);
                x[j+k+i] = rdc(u, v);
                w = w * wn % mod;
            }
        }
    }
    if(y == -1)
    {
        ll iv = qpow(lim, mod - 2);
        for(int i = 0; i < lim; ++i) x[i] = x[i] * iv % mod;
    }
}

// Writes into x the first `len` coefficients of 1 / y(t), using Newton
// iteration x <- x * (2 - y * x) with the precision doubling each level.
// x must be zero beyond the recursively computed prefix (true for the global
// buffers, since each level clears x[len..lim) before returning).
void get_inv(ll *x, ll *y, int len)
{
    if(len == 1) { x[0] = qpow(y[0], mod - 2); return; }
    get_inv(x, y, (len + 1) >> 1);
    for(int i = 0; i < len; ++i) c[i] = y[i];
    NTT_init(len << 1);                             // large enough that y*x*x does not wrap
    NTT(x, 1); NTT(c, 1);
    for(int i = 0; i < lim; ++i)
    {
        x[i] = rdc(add(x[i], x[i]), (c[i] * x[i] % mod) * x[i] % mod);
        c[i] = 0;                                   // leave scratch clean for the next level
    }
    NTT(x, -1);
    for(int i = len; i < lim; ++i) x[i] = 0;
}

int main()
{
    n = read();
    init();
    for(int i = 0; i <= n; ++i)
    {
        // Reduce into [0, mod) first: add/rdc assume canonical residues, and a
        // negative input would otherwise propagate negative values.
        ll v = read() % mod;
        if(v < 0) v += mod;
        a[i] = v * fac[i] % mod;                    // A_i = a_i * i!
    }
    // Constant coefficient is a_0 alone; presumably it comes from the 0^0 = 1
    // term of S_0(x) = sum_{i=0}^{x} i^0 (confirm with the problem statement).
    printf("%lld ", a[0]);
    get_inv(B, b, n + 2);                           // B[i] = B_i / i!, with B_1 = -1/2
    // Faulhaber's formula for sum_{i=1}^{x} i^k uses B_1 = +1/2.
    B[1] = qpow(2, mod - 2);
    reverse(B, B + n + 2);                          // D_i = C_{n+1-i}
    NTT_init((n << 1) + 5);                         // product degree <= 2n+1, no wraparound
    NTT(a, 1); NTT(B, 1);
    for(int i = 0; i < lim; ++i) a[i] = a[i] * B[i] % mod;
    NTT(a, -1);
    // Coefficient of x^i (1 <= i <= n+1) is (A*D)_{n+i} / i!.
    for(int i = 1; i <= n + 1; ++i) printf("%lld ", a[i+n] * inv[i] % mod);
    return 0;
}