【题解】P3711 仓鼠的数学题
poly 令人晕眩,令人晕眩的 poly.
思路
伯努利数。
首先意识到有一个拉插题也是求自然数幂和,所以答案是关于 \(n\) 的 \(k\) 次多项式。
考虑设出 \(S_{n, k} = \sum\limits_{i = 0}^{n - 1} i^k\),这里不设到 \(n\) 的原因是方便用伯努利数表示,因此最后要记得加上 \(n^k\).
考虑 \(S_{n, k}\) 的 EGF:\(S_n(x) = \sum\limits_{k = 0}^{+ \infty} \frac{x^k} {k!} \sum\limits_{i = 0}^{n - 1} i^k\).
交换求和顺序得 \(S_n(x) = \sum\limits_{i = 0}^{n - 1} \sum\limits_{k = 0}^{+ \infty} \frac{(ix)^k} {k!}\).
其封闭形式为:\(\sum\limits_{i = 0}^{n - 1} (e^x)^i\).
求和得:\(S_n(x) = \frac{e^{nx} - 1}{e^x - 1}\).
根据 【题解】P4464 [国家集训队] JZPKIL,考虑用裂项后用伯努利数表示:\(S_n(x) = \frac{e^{nx} - 1}{x} B(x)\).
整理一下得:\(\frac{S_{n, k}}{k!} = \sum\limits_{i = 0}^k \frac{n^{i + 1}}{(i + 1)!} {B_{k - i}}{(k - i)!}\).
将 \(S_{n, k}\) 代入原式的 GF 得:\(\sum\limits_{k = 0}^n a_k \sum\limits_{i = 0}^x i^k = \sum\limits_{k = 0}^n a_k (x^k + S_{x, k})\).
将上式代入原式整理得:\(\sum\limits_{k = 0}^n a_k x^k + \sum\limits_{i = 0}^n \frac{x^{i + 1}}{(i + 1)!} \sum\limits_{k = i}^n a_k k! \frac{B_{k - i}}{(k - i)!}\).
右边是差卷积的形式,反转两次 \(O(n \log n)\) 做。
快速求伯努利数可以考虑伯努利数的 EGF:\(B = \frac{x}{e^x - 1}\).
注意 \(e^x - 1\) 的常数项为 \(0\) 不能求逆,简单平移一下再展开:\(B = \frac{1}{\sum\limits_{i = 0}^{+ \infty} \frac{x^i}{(i + 1)!}}\).
就可以直接上求逆 \(O(n \log n)\) 做。
代码
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
const int sz = 1e6 + 5;
const int mod = 998244353;
const int g = 3;
int n;
int rev[sz];
ll fac[sz], invf[sz];
ll B[sz], F[sz], inv[sz], wp[sz];
ll Ft[sz], Rt[sz];
void calc_rev(int k) { for (int i = 1; i < k; i++) rev[i] = (rev[i >> 1] >> 1 | (i & 1 ? k >> 1 : 0)); }
ll qpow(ll base, ll power, ll mod)
{
ll res = 1;
while (power)
{
if (power & 1) res = res * base % mod;
base = base * base % mod;
power >>= 1;
}
return res;
}
void NTT(ll *A, int n)
{
calc_rev(n);
for (int i = 1; i < n; i++)
if (rev[i] > i) swap(A[i], A[rev[i]]);
for (int len = 2, m = 1; len <= n; m = len, len <<= 1)
{
ll wn = qpow(g, (mod - 1) / len, mod);
wp[0] = 1;
for (int i = 1; i <= len; i++) wp[i] = wp[i - 1] * wn % mod;
for (int l = 0, r = len - 1; r <= n; l += len, r += len)
{
int w = 0;
for (int p = l; p < l + m; p++, w++)
{
ll x = A[p], y = wp[w] * A[p + m] % mod;
A[p] = (x + y) % mod, A[p + m] = (x - y + mod) % mod;
}
}
}
}
void INTT(ll *A, int n)
{
NTT(A, n);
reverse(A + 1, A + n);
int inv = qpow(n, mod - 2, mod);
for (int i = 0; i < n; i++) A[i] = 1ll * A[i] * inv % mod;
}
void invp(ll *f, ll *r, int n)
{
int k = 1;
while (k < n) k <<= 1;
r[0] = qpow(f[0], mod - 2, mod);
for (int len = 2, m = 1; len <= k; m = len, len <<= 1)
{
for (int i = 0; i < len; i++) Rt[i] = r[i], Ft[i] = f[i];
NTT(Ft, len), NTT(Rt, len);
for (int i = 0; i < len; i++) Rt[i] = Rt[i] * Ft[i] % mod;
INTT(Rt, len);
for (int i = 0; i < m; i++) Rt[i] = 0; Rt[0] = 1;
for (int i = 0; i < len; i++) Ft[i] = r[i];
NTT(Ft, len), NTT(Rt, len);
for (int i = 0; i < len; i++) Rt[i] = Rt[i] * Ft[i] % mod;
INTT(Rt, len);
for (int i = m; i < len; i++) r[i] = (r[i] * 2ll - Rt[i] + mod) % mod;
}
memset(Ft, 0, k * sizeof(ll));
memset(Rt, 0, k * sizeof(ll));
for (int i = n; i < k; i++) r[i] = 0;
}
void init(int lim)
{
fac[0] = invf[0] = fac[1] = invf[1] = 1;
for (int i = 2; i <= lim; i++) fac[i] = fac[i - 1] * i % mod, invf[i] = (mod - mod / i) * invf[mod % i] % mod;
for (int i = 2; i <= lim; i++) invf[i] = invf[i - 1] * invf[i] % mod;
}
int main()
{
scanf("%d", &n);
init(n + 10);
for (int i = 0; i <= n; i++) scanf("%lld", &F[i]), F[i] = F[i] * fac[i] % mod;
// for (int i = 0; i <= n; i++) printf("%lld ", F[i]); putchar('\n');
// for (int i = 0; i <= n; i++) printf("%lld ", invf[i]); putchar('\n');
invp(invf + 1, B, n + 1);
// for (int i = 0; i <= n; i++) printf("%lld ", invf[i]); putchar('\n');
reverse(F, F + n + 1);
int k = 1;
while (k < (n + n + 2)) k <<= 1;
NTT(F, k), NTT(B, k);
for (int i = 0; i < k; i++) F[i] = F[i] * B[i] % mod;
INTT(F, k);
reverse(F, F + n + 1);
memset(B, 0, sizeof(B));
for (int i = 1; i <= n + 1; i++) B[i] = F[i - 1];
memset(F, 0, sizeof(F));
for (int i = 0; i <= n + 1; i++) F[i] = invf[i];
reverse(B, B + n + 2);
k = 1;
while (k < (n + n + 4)) k <<= 1;
NTT(F, k), NTT(B, k);
for (int i = 0; i < k; i++) F[i] = F[i] * B[i] % mod;
INTT(F, k);
reverse(F, F + n + 2);
for (int i = 0; i <= n + 1; i++) printf("%lld ", F[i] * invf[i] % mod);
return 0;
}