GMOJ7336. 整数的拆分
题目大意
给出 \(n,m,k\),定义 \(f(z)=\sum\limits_{i\ge0}i^mz^i\)。求 \([z^n]f^k(z)\)。
\(n\le 10^7, m,k\le 10^3\).
解题思路
对于 \(i^m\),大致处理思路有:1. 斯特林数反演。2. \(m + 1\) 次差分后变成一个 \(\deg \le m\)的多项式。
这道题我采用了方法2。
大致思路:先差分,然后做 NTT ,最后前缀和回来。
首先求出 \(f\) 的前 \(m\) 项,暴力求差分即可。这一部分是 \(O(m^2)\) 的。设差分后的函数为 \(g\)。
然后求出 \(g^k\)。这个是 \(O(mk\log mk)\)的。
然后考虑卷上一个 \(\frac 1 {(1-z)^{(m+1)k}}\) ,将差分前缀和回来。注意到我们只需要一个点值,直接 \(O(\min(n,mk))\) 计算即可。
总时间复杂度是 \(O(m^2+mk+\min(n,mk))\) 的。
#include <cstdio>
#include <cstring>
#include <algorithm>
#pragma GCC optimize ("Ofast")
#define ll long long
#define fo(i, a, b) for(int i = (a); i <= (b); ++i)
#define fd(i, a, b) for(int i = (a); i >= (b); --i)
using namespace std;
const int N = 1e7, M = 2.5e6;
const int mod = 998244353, G = 3, Gi = 332748118;
int n, m, k, inv, fac[N + 10] = {1}, ifc[N + 10];
int lim, L, a[M + 10], r[M + 10];
inline int qpow(int a, int b) {int ret = 1; for(; b; b >>= 1, a = (ll)a * a % mod) b & 1 && (ret = (ll)ret * a % mod); return ret;}
inline void reset(int n) {
lim = 1, L = 0;
while(lim <= n) lim <<= 1, ++L;
inv = qpow(lim, mod - 2);
for(int i = 0; i < lim; ++i)
r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
}
inline void NTT(int *A, int op) {
static int W[N] = {1};
for(int i = 0; i < lim; ++i)
if(i < r[i]) swap(A[i], A[r[i]]);
for(int mid = 1; mid < lim; mid <<= 1) {
int Wn = qpow(op == 1 ? G : Gi, (mod - 1) / (mid << 1));
for(int i = 1; i < mid; ++i) W[i] = (ll)W[i - 1] * Wn % mod;
for(int j = 0; j < lim; j += (mid << 1))
for(int k = 0; k < mid; ++k) {
int x = A[j + k], y = (ll)W[k] * A[j + k + mid] % mod;
A[j + k] = (x + y) % mod;
A[j + k + mid] = (x - y + mod) % mod;
}
}
}
inline int C(int n, int m) {return n < 0 || n > m ? 0 : (ll)fac[m] * ifc[m - n] % mod * ifc[n] % mod;}
inline int b(int t) {return C(t, t + (m + 1) * k - 1);}
int main() {
freopen("split.in", "r", stdin);
freopen("split.out", "w", stdout);
scanf("%d %d %d", &n, &m, &k);
fo(i, 1, N) fac[i] = (ll)fac[i - 1] * i % mod;
ifc[N] = qpow(fac[N], mod - 2);
fd(i, N, 1) ifc[i - 1] = (ll)ifc[i] * i % mod;
fo(i, 0, m) a[i] = qpow(i, m);
fo(i, 0, m)
fd(j, m, 1) a[j] = (a[j] - a[j - 1] + mod) % mod;
reset(m * k);
NTT(a, 1);
for(int i = 0; i < lim; ++i) a[i] = qpow(a[i], k);
NTT(a, -1);
int ans = 0, lim = min(m * k, n);
fo(i, 0, lim)
ans = (ans + (ll)b(n - i) * a[i] % mod * inv) % mod;
printf("%d\n", ans);
return 0;
}