LOJ#6059. 「2017 山东一轮集训 Day1」Sum 题解
记 \(F_{n,m,r}\) 表示满足「各位数字之和为 \(m\),且这个数本身对 \(p\) 取模的结果为 \(r\)」的 \(n\) 位数(允许前导零)的个数。
不难发现,我们只需要支持 \(F_{n} \to F_{n+1}\) 和 \(F_{n} \to F_{2n}\) 两种转移(即对 \(n\) 倍增)就可以了。
\(F_{n} \to F_{n+1}\) 可以直接枚举下一位数字是什么,暴力转移,单次复杂度 \(\Theta(10pm)\)。
\(F_{n} \to F_{2n}\) 需要卷积:把两段长度为 \(n\) 的数拼接起来,余数这一维暴力枚举,数位和这一维用 NTT 做卷积。实现的时候注意不要用过多的 NTT,\(3p\) 次就够了(两个因子各 \(p\) 次正变换,结果 \(p\) 次逆变换),单次复杂度 \(\Theta(p^2m+pm\log m)\)。
总复杂度 \(\Theta (pm(10+p+\log m)\log n)\)
code :
#include <bits/stdc++.h>
// Shorthand type names used throughout the file: `uint` expands to unsigned,
// `LL` expands to unsigned long long (used for 64-bit intermediate products).
#define uint unsigned
#define LL unsigned long long
using namespace std;
// P: the modulus every count is reduced by (see upd / power).
// N: capacity of each per-residue coefficient row and of the bit-reversal table.
const uint P = 998244353,N = 2048;
// Modular exponentiation: returns x^y mod P using binary (square-and-multiply)
// exponentiation.
//   x - base
//   y - exponent; must be non-negative (y == 0 yields 1). A negative y would
//       never reach zero under >>= and the loop would not terminate.
inline uint power(uint x,int y){
    uint r = 1;  // plain local accumulator; no need for a function-level static
    while (y){
        if (y&1) r = (LL)r * x % P;  // 64-bit product avoids overflow before the reduction
        x = (LL)x * x % P;
        y >>= 1;
    }
    return r;
}
// Working state of the number-theoretic transform (set up by main and getR):
//   wn / iwn - forward / inverse twiddle factors; for every block length len,
//              main stores the successive powers starting at index len
//   R        - bit-reversal permutation written by getR
//   L        - transform length; iL - modular inverse of L modulo P
uint wn[N<<1],iwn[N<<1],R[N],L,iL;
// Picks the transform length for polynomials of degree up to n and fills the
// bit-reversal table R for that length.
// Returns the smallest power of two that is strictly greater than n (at least 2).
inline int getR(int n){
    int bits = 1;   // log2 of the current length
    int len = 2;    // current candidate length
    while (len <= n){
        len <<= 1;
        ++bits;
    }
    const int top = bits - 1;  // position of the highest bit of an index
    for (int idx = 1; idx < len; ++idx){
        // reverse(idx) = reverse(idx / 2) shifted down, with idx's lowest bit moved to the top
        R[idx] = (R[idx >> 1] >> 1) | ((idx & 1) << top);
    }
    return len;
}
// In-place forward number-theoretic transform of A[0..L-1] modulo P.
// Relies on the globals L (length), R (bit-reversal table) and wn (twiddle
// factors, the block of length 2*half starting at index 2*half) being set up
// beforehand. A must hold at least L elements, each already reduced below P.
inline void NTT(uint *A){
    // `register` was dropped here: the keyword was removed from the language in C++17.
    for (int i = 1; i < L; ++i) if (i < R[i]) std::swap(A[i],A[R[i]]);
    for (int half = 1; half < L; half <<= 1)
        for (int j = 0; j < L; j += half << 1)
            for (int k = j; k < half+j; ++k){
                const uint lo = A[k];
                const uint v = (LL)A[k+half] * wn[(half<<1)+k-j] % P;
                // butterfly: (lo, hi) -> (lo + v, lo - v), both kept in [0, P)
                A[k+half] = (lo<v)?(lo+P-v):(lo-v);
                A[k] = (lo+v>=P)?(lo+v-P):(lo+v);
            }
}
// In-place inverse number-theoretic transform of A[0..L-1] modulo P.
// Runs the same butterflies as the forward transform but with the inverse
// twiddle table iwn, then multiplies every entry by iL (the inverse of L mod P).
// Relies on the globals L, R, iwn and iL being set up beforehand.
inline void iNTT(uint *A){
    // `register` was dropped here: the keyword was removed from the language in C++17.
    for (int i = 1; i < L; ++i) if (i < R[i]) std::swap(A[i],A[R[i]]);
    for (int half = 1; half < L; half <<= 1)
        for (int j = 0; j < L; j += half << 1)
            for (int k = j; k < half+j; ++k){
                const uint lo = A[k];
                const uint v = (LL)A[k+half] * iwn[(half<<1)+k-j] % P;
                // butterfly: (lo, hi) -> (lo + v, lo - v), both kept in [0, P)
                A[k+half] = (lo<v)?(lo+P-v):(lo-v);
                A[k] = (lo+v>=P)?(lo+v-P):(lo+v);
            }
    // final normalisation by 1/L
    for (int i = 0; i < L; ++i) A[i] = (LL)A[i] * iL % P;
}
// Problem parameters; main reads them from standard input in the order n, p, m.
int n,m,p;
// Scratch table with the same dimensions as data::F, used by data::modify and data::extend.
uint tmp[50][N];
// Modular accumulate: x becomes x + v, with P subtracted once if the sum reaches P.
inline void upd(uint &x,uint v){
    const uint sum = x + v;
    if (sum >= P) x = sum - P;
    else x = sum;
}
struct data{
uint F[50][N],w;
inline void init0(){ w = F[0][0] = 1; }
inline void init1(){ w = 10 % p; for (int i = 0; i <= 9 && i <= m; ++i) ++F[i%p][i]; }
inline void modify(int w0){
register int i,j;
w = w * w0 % p;
for (i = 0; i < p; ++i) memset(tmp[i],0,m+1<<2);
for (i = 0; i < p; ++i) for (j = 0; j <= m; ++j) upd(tmp[i*w0%p][j],F[i][j]);
for (i = 0; i < p; ++i) memcpy(F[i],tmp[i],m+1<<2);
}
inline void extend(){
register int c,i,j;
modify(10);
for (i = 0; i < p; ++i) memcpy(tmp[i],F[i],m+1<<2),memset(F[i],0,m+1<<2);
for (i = 0; i < p; ++i) for (j = 0; j <= m; ++j) if (tmp[i][j])
for (c = 0; c <= 9 && j+c <= m; ++c) upd(F[(i+c)%p][j+c],tmp[i][j]);
}
}A,B,C;
inline void merge(data &A,data &B,data &C){ // A.F = B.F * C.F
register int i,j,t,k;
for (i = 0; i < p; ++i) NTT(B.F[i]),NTT(C.F[i]),memset(A.F[i],0,L<<2);
for (i = 0; i < p; ++i) for (j = 0; j < p; ++j)
for (t = (i+j)%p,k = 0; k < L; ++k) upd(A.F[t][k],(LL)B.F[i][k] * C.F[j][k] % P);
for (i = 0; i < p; ++i) iNTT(A.F[i]),memset(A.F[i]+m+1,0,L-m-1<<2);
}
// Builds in the global A the state for all n-digit strings by doubling:
// the state for n/2 digits is combined with a shifted copy of itself, then one
// more digit is appended when n is odd. B and C are used as scratch.
// A is expected to be all zero on entry (it is a global object).
inline void work(int n){
    // n <= 0 used to recurse without end; treat it as the empty string instead.
    if (n <= 0){ A.init0(); return; }
    if (n == 1){ A.init1(); return; }
    work(n/2);
    // Plain struct copies replace the former memcpy of sizeof(data) bytes
    // through a pointer to the first row.
    B = A;
    C = A;
    C.modify(A.w);   // C becomes the high half: values scaled by 10^(n/2) mod p
    merge(A,B,C);    // A.F = low half (B) combined with high half (C)
    A.w = C.w;       // C.w now holds 10^(2*(n/2)) mod p
    if (n&1) A.extend();
}
int main(){
int i,j; uint v;
cin >> n >> p >> m;
L = getR(m<<1),iL = power(L,P-2);
for (i = 2; i <= L; i <<= 1){
int l = i+(i>>1);
v = power(3,(P-1)/i),wn[i] = 1;
for (j = i+1; j < l; ++j) wn[j] = (LL)wn[j-1] * v % P;
v = power(v,P-2),iwn[i] = 1;
for (j = i+1; j < l; ++j) iwn[j] = (LL)iwn[j-1] * v % P;
}
work(n);
for (v = i = 0; i <= m; ++i) upd(v,A.F[0][i]),cout << v << (i<m?' ':'\n');
return 0;
}