题解 [2017 山东一轮集训 Day1] Sum

传送门

好题!

首先看这个题像矩阵快速幂
但是有这个和为 \(m\) 的限制很难搞出转移矩阵来
于是……

  • 看起来像矩阵快速幂但无法设计转移矩阵/状态中带有求和/DP 转移支持拼接的问题:
    考虑能不能倍增处理,尝试使用 \(f_i\) 得到 \(f_{2i}\)

于是对于本题:
暴力 DP 是令 \(f_{i, u, k}\) 为到第 \(i\) 位,\(\bmod p=j\),数位和为 \(k\) 的方案数
转移枚举当前位选什么

\[f_{i, (10u+v)\bmod p, k}=\sum\limits_{v=0}^9f_{i-1, u, k-v} \]

然后倍增的转移是类似的(注意这里 DP 定义已经换了)

\[f_{i, (10^{2^i}u+v)\bmod p, k}=\sum\limits_{x=0}^k f_{i-1, u, x}f_{i-1, v, k-x} \]

直接这样转移一次是 \(O(p^2m^2)\)
但是发现后面部分是个卷积,所以直接将 \(f\) DFT 了再做就可以单次 \(O(p^2m)\)
那么总复杂度就是 \(O(p^2m\log n)\)

点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 100010
#define ll long long
//#define int long long

char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline int read() {
	int ans=0, f=1; char c=getchar();
	while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
	return ans*f;
}

int n, p, m;
int rev[N], now, bln, bct;
const ll mod=998244353, rt=3, phi=mod-1;
ll f[32][50][2050], g[2][50][2050], pw[35];
inline ll qpow(ll a, ll b, ll mod=::mod) {ll ans=1; for (; b; a=a*a%mod,b>>=1) if (b&1) ans=ans*a%mod; return ans;}

void ntt(ll* a, int len, int op) {
	for (int i=0; i<len; ++i) if (i<rev[i]) swap(a[i], a[rev[i]]);
	ll w, wn, t;
	for (int i=1; i<len; i<<=1) {
		wn=qpow(rt, (op*phi/(i<<1)+phi)%phi);
		for (int j=0,step=i<<1; j<len; j+=step) {
			w=1;
			for (int k=j; k<j+i; ++k,w=w*wn%mod) {
				t=w*a[k+i]%mod;
				a[k+i]=(a[k]-t)%mod;
				a[k]=(a[k]+t)%mod;
			}
		}
	}
	if (op==-1) {
		ll inv=qpow(len, mod-2);
		for (int i=0; i<len; ++i) a[i]=a[i]*inv%mod;
	}
}

signed main()
{
	n=read(); p=read(); m=read();
	for (int i=0; i<32; ++i) pw[i]=qpow(10, 1ll<<i, p);
	for (bln=1; bln<=2*m; bln<<=1,++bct);
	for (int i=0; i<bln; ++i) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bct-1));
	for (int i=0; i<=min(9, m); ++i) f[0][i%p][i]+=1;
	for (int i=0; i<p; ++i) ntt(f[0][i], bln, 1);
	for (int i=1; i<32; ++i) {
		for (int u=0; u<p; ++u)
			for (int v=0; v<p; ++v)
				for (int k=0; k<bln; ++k)
					f[i][(u*pw[i-1]+v)%p][k]=(f[i][(u*pw[i-1]+v)%p][k]+f[i-1][u][k]*f[i-1][v][k])%mod;
		for (int j=0; j<p; ++j) {
			ntt(f[i][j], bln, -1);
			for (int k=m+1; k<bln; ++k) f[i][j][k]=0;
			ntt(f[i][j], bln, 1);
		}
	}
	g[now][0][0]=1;
	for (int i=0; i<32; ++i) if (n&(1ll<<i)) {
		memset(g[now^1], 0, sizeof(g[now^1]));
		for (int j=0; j<p; ++j) ntt(g[now][j], bln, 1);
		for (int u=0; u<p; ++u)
			for (int v=0; v<p; ++v)
				for (int k=0; k<bln; ++k)
					g[now^1][(u+v*pw[i])%p][k]=(g[now^1][(u+v*pw[i])%p][k]+f[i][u][k]*g[now][v][k])%mod;
		for (int j=0; j<p; ++j) {
			ntt(g[now^1][j], bln, -1);
			for (int k=m+1; k<bln; ++k) g[now^1][j][k]=0;
		}
		now^=1;
	}
	for (int i=1; i<=m; ++i) g[now][0][i]=(g[now][0][i-1]+g[now][0][i])%mod;
	for (int i=0; i<=m; ++i) printf("%lld%c", (g[now][0][i]%mod+mod)%mod, " \n"[i==m]);

	return 0;
}
posted @ 2022-06-21 20:59  Administrator-09  阅读(23)  评论(0编辑  收藏  举报