LOJ#6059. 「2017 山东一轮集训 Day1」Sum 题解

题目链接

考虑记 \(F_{n,m,r}\) 表示有多少个 \(n\) 位数（允许前导零），满足各位数字之和为 \(m\)，且对 \(p\) 取模的结果为 \(r\)。

不难发现，由于 \(n\) 可以按二进制拆分（类似快速幂），我们只需要支持 \(F_{n}\to F_{n+1}\) 和 \(F_{n}\to F_{2n}\) 两种转移就可以了。

\(F_{n}\to F_{n+1}\) 可以直接枚举新添加的最低位是什么，暴力转移，单次复杂度 \(\Theta(10pm)\)。

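\(F_{n}\to F_{2n}\)：先把高 \(n\) 位那一份的余数乘上 \(10^{n}\bmod p\)，再与低 \(n\) 位那一份按“余数相加、数位和卷积”的方式合并，因此需要卷积。实现时注意不要做过多的 NTT：对两份各自的 \(p\) 个多项式各做一次正变换（共 \(2p\) 次），在点值上按余数两两相乘累加，最后对 \(p\) 个结果各做一次逆变换，总共 \(3p\) 次就够了。单次复杂度 \(\Theta(p^2m+pm\log m)\)。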

总复杂度 \(\Theta (pm(10+p+\log m)\log n)\)

代码：

#include <bits/stdc++.h>
#define uint unsigned
#define LL unsigned long long
using namespace std;
// P: NTT modulus (main uses 3 as the generator for its roots of unity).
// N: capacity of every coefficient array; the transform length L must not exceed N.
const unsigned P = 998244353,N = 2048;
// Returns x^y mod P by binary exponentiation. Requires y >= 0 (a negative y never
// terminates). Used for modular inverses via x^(P-2) and for roots of unity.
// No static state: the accumulator is an ordinary local, so the function is reentrant.
inline unsigned power(unsigned x,int y){
	unsigned r = 1;
	for (; y; y >>= 1){
		if (y & 1) r = (unsigned long long)r * x % P;
		x = (unsigned long long)x * x % P;
	}
	return r;
}
// wn / iwn: forward / inverse twiddle tables, laid out as [len + t] for each power-of-two
//           length len (filled in main). R: bit-reversal permutation. L: transform length.
// iL: inverse of L mod P.
uint wn[N<<1],iwn[N<<1],R[N],L,iL;
// Computes the transform length len = smallest power of two strictly greater than n
// (at least 2). Fills R[1..len-1] with the bit-reversal permutation on log2(len) bits
// and returns len. It does not assign the global L; the caller stores the result.
inline int getR(int n){
	int len = 2, bits = 1;
	while (len <= n){
		len <<= 1;
		++bits;
	}
	for (int i = 1; i < len; ++i){
		const int lowBit = i & 1;
		R[i] = (R[i >> 1] >> 1) | (lowBit << (bits - 1));
	}
	return len;
}
inline void NTT(uint *A){
	register int i,j,k; uint v;
	for (i = 1; i < L; ++i) if (i < R[i]) swap(A[i],A[R[i]]);
	for (i = 1; i < L; i <<= 1) for (j = 0; j < L; j += i << 1) for (k = j; k < i+j; ++k)
		v = (LL)A[k+i] * wn[(i<<1)+k-j] % P,A[k+i] = (A[k]<v)?(A[k]+P-v):(A[k]-v),A[k] = (A[k]+v>=P)?(A[k]+v-P):(A[k]+v);
}
inline void iNTT(uint *A){
	register int i,j,k; uint v;
	for (i = 1; i < L; ++i) if (i < R[i]) swap(A[i],A[R[i]]);
	for (i = 1; i < L; i <<= 1) for (j = 0; j < L; j += i << 1) for (k = j; k < i+j; ++k)
		v = (LL)A[k+i] * iwn[(i<<1)+k-j] % P,A[k+i] = (A[k]<v)?(A[k]+P-v):(A[k]-v),A[k] = (A[k]+v>=P)?(A[k]+v-P):(A[k]+v);
	for (i = 0; i < L; ++i) A[i] = (LL)A[i] * iL % P;
}
// n: number length, m: maximum digit sum, p: required modulus (read in main).
int n,m,p;
// Scratch rows used by data::modify / data::extend.
uint tmp[50][N];
// Modular accumulate: x <- x + v, reduced once by P. Intended for x, v already in [0, P).
inline void upd(uint &x,uint v){
	x += v;
	if (x >= P) x -= P;
}
struct data{
	uint F[50][N],w;
	inline void init0(){ w = F[0][0] = 1; }
	inline void init1(){ w = 10 % p; for (int i = 0; i <= 9 && i <= m; ++i) ++F[i%p][i]; }
	inline void modify(int w0){
		register int i,j;
		w = w * w0 % p;
		for (i = 0; i < p; ++i) memset(tmp[i],0,m+1<<2);
		for (i = 0; i < p; ++i) for (j = 0; j <= m; ++j) upd(tmp[i*w0%p][j],F[i][j]);
		for (i = 0; i < p; ++i) memcpy(F[i],tmp[i],m+1<<2);
	}
	inline void extend(){
		register int c,i,j;
		modify(10);
		for (i = 0; i < p; ++i) memcpy(tmp[i],F[i],m+1<<2),memset(F[i],0,m+1<<2);
		for (i = 0; i < p; ++i) for (j = 0; j <= m; ++j) if (tmp[i][j])
		for (c = 0; c <= 9 && j+c <= m; ++c) upd(F[(i+c)%p][j+c],tmp[i][j]);
	}
}A,B,C; 
inline void merge(data &A,data &B,data &C){ // A.F = B.F * C.F 
	register int i,j,t,k;
	for (i = 0; i < p; ++i) NTT(B.F[i]),NTT(C.F[i]),memset(A.F[i],0,L<<2);
	for (i = 0; i < p; ++i) for (j = 0; j < p; ++j)
	for (t = (i+j)%p,k = 0; k < L; ++k) upd(A.F[t][k],(LL)B.F[i][k] * C.F[j][k] % P);
	for (i = 0; i < p; ++i) iNTT(A.F[i]),memset(A.F[i]+m+1,0,L-m-1<<2);
}
// Builds, in the global A, the DP table for digit strings of length n by binary splitting:
// table(n) = combine(table(n/2), table(n/2)), plus one extra appended digit when n is odd.
// Recursion depth is about log2(n).
inline void work(int n){
	if (n == 0){ A.init0(); return; }   // empty string; previously work(0) recursed forever
	if (n == 1){ A.init1(); return; }
	work(n/2);
	// B = low half, C = high half. Plain struct assignment copies F and w without
	// relying on memcpy running past the end of the F member.
	B = A;
	C = A;
	// Shift the high half's residues by 10^(n/2) mod p, so C.w becomes 10^(2*(n/2)) mod p.
	C.modify(A.w),merge(A,B,C),A.w = C.w;
	if (n&1) A.extend();
}
int main(){
	int i,j; uint v;
	cin >> n >> p >> m;
	L = getR(m<<1),iL = power(L,P-2);
	for (i = 2; i <= L; i <<= 1){
		int l = i+(i>>1);
		v = power(3,(P-1)/i),wn[i] = 1;
		for (j = i+1; j < l; ++j) wn[j] = (LL)wn[j-1] * v % P;
		v = power(v,P-2),iwn[i] = 1;
		for (j = i+1; j < l; ++j) iwn[j] = (LL)iwn[j-1] * v % P;
	}
	work(n);
	for (v = i = 0; i <= m; ++i) upd(v,A.F[0][i]),cout << v << (i<m?' ':'\n');
	return 0;
}
posted @ 2020-09-26 13:49  srf  阅读(173)  评论(0)  编辑  收藏  举报