题解 [2017 山东一轮集训 Day1] Sum
好题!
首先看这个题像矩阵快速幂
但是有这个和为 \(m\) 的限制很难搞出转移矩阵来
于是……
- 看起来像矩阵快速幂但无法设计转移矩阵/状态中带有求和/DP 转移支持拼接的问题:
考虑能不能倍增处理,尝试使用 \(f_i\) 得到 \(f_{2i}\)
于是对于本题:
暴力 DP 是令 \(f_{i, u, k}\) 为到第 \(i\) 位,\(\bmod p=j\),数位和为 \(k\) 的方案数
转移枚举当前位选什么
\[f_{i, (10u+v)\bmod p, k}=\sum\limits_{v=0}^9f_{i-1, u, k-v}
\]
然后倍增的转移是类似的(注意这里 DP 定义已经换了)
\[f_{i, (10^{2^i}u+v)\bmod p, k}=\sum\limits_{x=0}^k f_{i-1, u, x}f_{i-1, v, k-x}
\]
直接这样转移一次是 \(O(p^2m^2)\) 的
但是发现后面部分是个卷积,所以直接将 \(f\) DFT 了再做就可以单次 \(O(p^2m)\) 了
那么总复杂度就是 \(O(p^2m\log n)\)
点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 100010
#define ll long long
//#define int long long
char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline int read() {
int ans=0, f=1; char c=getchar();
while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
return ans*f;
}
int n, p, m;
int rev[N], now, bln, bct;
const ll mod=998244353, rt=3, phi=mod-1;
ll f[32][50][2050], g[2][50][2050], pw[35];
inline ll qpow(ll a, ll b, ll mod=::mod) {ll ans=1; for (; b; a=a*a%mod,b>>=1) if (b&1) ans=ans*a%mod; return ans;}
void ntt(ll* a, int len, int op) {
for (int i=0; i<len; ++i) if (i<rev[i]) swap(a[i], a[rev[i]]);
ll w, wn, t;
for (int i=1; i<len; i<<=1) {
wn=qpow(rt, (op*phi/(i<<1)+phi)%phi);
for (int j=0,step=i<<1; j<len; j+=step) {
w=1;
for (int k=j; k<j+i; ++k,w=w*wn%mod) {
t=w*a[k+i]%mod;
a[k+i]=(a[k]-t)%mod;
a[k]=(a[k]+t)%mod;
}
}
}
if (op==-1) {
ll inv=qpow(len, mod-2);
for (int i=0; i<len; ++i) a[i]=a[i]*inv%mod;
}
}
signed main()
{
n=read(); p=read(); m=read();
for (int i=0; i<32; ++i) pw[i]=qpow(10, 1ll<<i, p);
for (bln=1; bln<=2*m; bln<<=1,++bct);
for (int i=0; i<bln; ++i) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bct-1));
for (int i=0; i<=min(9, m); ++i) f[0][i%p][i]+=1;
for (int i=0; i<p; ++i) ntt(f[0][i], bln, 1);
for (int i=1; i<32; ++i) {
for (int u=0; u<p; ++u)
for (int v=0; v<p; ++v)
for (int k=0; k<bln; ++k)
f[i][(u*pw[i-1]+v)%p][k]=(f[i][(u*pw[i-1]+v)%p][k]+f[i-1][u][k]*f[i-1][v][k])%mod;
for (int j=0; j<p; ++j) {
ntt(f[i][j], bln, -1);
for (int k=m+1; k<bln; ++k) f[i][j][k]=0;
ntt(f[i][j], bln, 1);
}
}
g[now][0][0]=1;
for (int i=0; i<32; ++i) if (n&(1ll<<i)) {
memset(g[now^1], 0, sizeof(g[now^1]));
for (int j=0; j<p; ++j) ntt(g[now][j], bln, 1);
for (int u=0; u<p; ++u)
for (int v=0; v<p; ++v)
for (int k=0; k<bln; ++k)
g[now^1][(u+v*pw[i])%p][k]=(g[now^1][(u+v*pw[i])%p][k]+f[i][u][k]*g[now][v][k])%mod;
for (int j=0; j<p; ++j) {
ntt(g[now^1][j], bln, -1);
for (int k=m+1; k<bln; ++k) g[now^1][j][k]=0;
}
now^=1;
}
for (int i=1; i<=m; ++i) g[now][0][i]=(g[now][0][i-1]+g[now][0][i])%mod;
for (int i=0; i<=m; ++i) printf("%lld%c", (g[now][0][i]%mod+mod)%mod, " \n"[i==m]);
return 0;
}