LOJ6059「2017 山东一轮集训 Day1」Sum
倍增 dp
考虑看到 \(m,p\) 很小,dp 的话转移 \(n\) 遍但是 \(n\) 很大,于是想矩阵快速幂但发现并不太行
具体的,设 \(f(i,u,k)\) 表示考虑到第 \(i\) 位,当前数字 \(\mod p=u\),当前数位和为 \(m\)
考虑转移,有:\(f(i+1,(10u+v)\mod p,k+v)=\sum_{v=0}^9 f(i,u,k)\),这一遍是 \(O(10pm)\) 的
不能矩乘就考虑倍增 dp,从 \(f(i)\) 直接转移到 \(f(2i)\),那么式子是:
\[f(2i,(u\times 10^i+v)\mod p,k)=\sum_{x=0}^k f(i,u,x)f(i,v,k-x)
\]
发现这样直接转移是 \(O(p^2m^2)\) 的,但 \(\sum\) 里是卷积的形式,直接给 \(f(i,u)\) 都 ntt 一波,就变成每次转移 \(O(p^2m)\) 了
倍增转移一共 \(O(\log n)\) 次,所以总复杂度是 \(O(p^2m\log n)\) 的
#define G 114514
#define mod 998244353
#define N 6006
long long power(long long a,long long b,long long o=mod){
long long ans=1;
while(b){
if(b&1) ans=ans*a%o;
a=a*a%o;b>>=1;
}
return ans;
}
inline void add(long long &a,long long b){a=(a+b>=mod)?(a+b-mod):(a+b);}
inline long long Mod(long long a){return a>=mod?(a-mod):a;}
int rev[N];
inline int init(int n){
int max=1;while(max<=n) max<<=1;
for(int i=0;i<max;i++) rev[i]=rev[i>>1]>>1,rev[i]|=(i&1)?(max>>1):0;
return max;
}
inline void ntt(int n,long long *a,int type){
for(int i=0;i<n;i++)if(rev[i]<i) std::swap(a[i],a[rev[i]]);
for(int h=1;h<n;h<<=1){
long long gn=power(G,(mod-1)/(h<<1)),g,o;
gn=type?gn:power(gn,mod-2);
for(int i=0;i<n;i+=h<<1){
g=1;
for(int j=i;j<i+h;j++,g=g*gn%mod){
o=g*a[j+h]%mod;
a[j+h]=Mod(a[j]-o+mod);add(a[j],o);
}
}
}
if(!type){
long long inv=power(n,mod-2);
for(int i=0;i<n;i++) a[i]=a[i]*inv%mod;
}
}
long long f[55][N],g[55][N];
inline void work(int n,int p,int m){
int len=init(m*2+2);
int pos=30;
while(!(n&(1<<pos))) pos--;
f[0][0]=1;
int i=0;
for(;~pos;pos--){
if(!i) goto ADD;
for(int u=0;u<p;u++) ntt(len,f[u],1);
for(int u=0;u<p;u++)for(int v=0;v<p;v++){
long long pp=(u*power(10,i,p)+v)%p;
for(int x=0;x<len;x++) add(g[pp][x],f[u][x]*f[v][x]%mod);
// for(int x=0;x<=m;x++)for(int k=0;k<=x;k++)
// add(g[(u*pp+v)%p][x],f[u][k]*f[v][x-k]%mod);
}
for(int u=0;u<p;u++){
ntt(len,g[u],0);
for(int x=m+1;x<len;x++) g[u][x]=0;
}
std::memcpy(f,g,sizeof f);std::memset(g,0,sizeof g);
i<<=1;
ADD: if(n&(1<<pos)){
for(int u=0;u<p;u++)for(int k=0;k<=m;k++){
for(int v=0;v<10&&k+v<=m;v++) add(g[(10*u+v)%p][k+v],f[u][k]);
}
std::memcpy(f,g,sizeof f);std::memset(g,0,sizeof g);
i++;
}
}
assert(i==n);
}
int main(){
int n=read(),p=read(),m=read();
work(n,p,m);
printf("%lld ",f[0][0]);
for(int i=1;i<=m;i++) add(f[0][i],f[0][i-1]),printf("%lld ",f[0][i]);
return 0;
}