GMOJ7336. 整数的拆分

题目大意

给出 \(n,m,k\),定义 \(f(z)=\sum\limits_{i\ge0}i^mz^i\)。求 \([z^n]f^k(z)\)

\(n\le 10^7, m,k\le 10^3\).

解题思路

对于 \(i^m\),大致处理思路有:1. 斯特林数反演。2. \(m + 1\) 次差分后变成一个 \(\deg \le m\)的多项式。

这道题我采用了方法2。

大致思路:先差分,然后做 NTT ,最后前缀和回来。

首先求出 \(f\) 的前 \(m\) 项,暴力求差分即可。这一部分是 \(O(m^2)\) 的。设差分后的函数为 \(g\)

然后求出 \(g^k\)。这个是 \(O(mk\log mk)\)的。

然后考虑卷上一个 \(\frac 1 {(1-z)^{(m+1)k}}\) ,将差分前缀和回来。注意到我们只需要一个点值,直接 \(O(\min(n,mk))\) 计算即可。

总时间复杂度是 \(O(m^2+mk+\min(n,mk))\) 的。

#include <cstdio>
#include <cstring>
#include <algorithm>
#pragma GCC optimize ("Ofast")
#define ll long long
#define fo(i, a, b) for(int i = (a); i <= (b); ++i)
#define fd(i, a, b) for(int i = (a); i >= (b); --i)
using namespace std;
const int N = 1e7, M = 2.5e6;
const int mod = 998244353, G = 3, Gi = 332748118;
int n, m, k, inv, fac[N + 10] = {1}, ifc[N + 10];
int lim, L, a[M + 10], r[M + 10];
inline int qpow(int a, int b) {int ret = 1; for(; b; b >>= 1, a = (ll)a * a % mod)	b & 1 && (ret = (ll)ret * a % mod); return ret;}
inline void reset(int n) {
	lim = 1, L = 0;
	while(lim <= n)	lim <<= 1, ++L;
	inv = qpow(lim, mod - 2);
	for(int i = 0; i < lim; ++i)
		r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
}
inline void NTT(int *A, int op) {
	static int W[N] = {1};
	for(int i = 0; i < lim; ++i)
		if(i < r[i])	swap(A[i], A[r[i]]);
	for(int mid = 1; mid < lim; mid <<= 1) {
		int Wn = qpow(op == 1 ? G : Gi, (mod - 1) / (mid << 1));
		for(int i = 1; i < mid; ++i)	W[i] = (ll)W[i - 1] * Wn % mod;
		for(int j = 0; j < lim; j += (mid << 1))
			for(int k = 0; k < mid; ++k) {
				int x = A[j + k], y = (ll)W[k] * A[j + k + mid] % mod;
				A[j + k] = (x + y) % mod;
				A[j + k + mid] = (x - y + mod) % mod;
			}
	}
}
inline int C(int n, int m) {return n < 0 || n > m ? 0 : (ll)fac[m] * ifc[m - n] % mod * ifc[n] % mod;}
inline int b(int t) {return C(t, t + (m + 1) * k - 1);}
int main() {
	freopen("split.in", "r", stdin);
	freopen("split.out", "w", stdout);
	scanf("%d %d %d", &n, &m, &k);
	fo(i, 1, N)	fac[i] = (ll)fac[i - 1] * i % mod;
	ifc[N] = qpow(fac[N], mod - 2);
	fd(i, N, 1)	ifc[i - 1] = (ll)ifc[i] * i % mod;
	
	fo(i, 0, m)	a[i] = qpow(i, m);
	fo(i, 0, m)
		fd(j, m, 1)	a[j] = (a[j] - a[j - 1] + mod) % mod;
	reset(m * k);
	
	NTT(a, 1);
	for(int i = 0; i < lim; ++i)	a[i] = qpow(a[i], k);
	NTT(a, -1);

	int ans = 0, lim = min(m * k, n);
	fo(i, 0, lim)
		ans = (ans + (ll)b(n - i) * a[i] % mod * inv) % mod;


	printf("%d\n", ans);
	
	return 0;
}
posted @ 2021-11-01 22:12  Martin_MHT  阅读(20)  评论(0)    收藏  举报