BZOJ 4818: [Sdoi2017]序列计数
感觉是sdoi2015序列统计的弱化版...
“至少用一个质数”的限制：用全部方案数减去只用非质数（即合数和 1）的方案数来解决
\(dp_{i, j}\) 表示已经放了 \(i\) 个数、当前和模 \(p\) 余 \(j\) 的方案数
转移方程 \(dp_{i, j} = \sum \limits_{x+y \equiv j \pmod p} dp_{i-1,x} \times c_y\)，其中 \(c_y\) 为 \([1, m]\) 中模 \(p\) 余 \(y\) 的数的个数
看成生成函数的形式，记 \(C(x) = \sum_y c_y x^y\)，所求方案数就是 \(C(x)^n\) 在模 \(x^p - 1\) 意义下的常数项系数。由于 \(p\) 很小，快速幂加 \(O(p^2)\) 暴力循环卷积即可……连 FFT 都不用敲
#include <bits/stdc++.h>
// Competitive-programming shorthand macros. Expression-like macros parenthesize
// their bodies and arguments so each expansion behaves as a single operand.
#define pb push_back
#define fi first
#define se second
// Pair of (long long, int). std-qualified because no `using namespace std` is
// visible here, so a bare `pair` would not resolve.
#define pii std::pair<ll, int>
// Child indices of node p and midpoint of [l, r] (presumably segment-tree helpers).
// Parenthesized: the old `p << 1 | 1` made `rp * 2` expand to `p << 1 | 2`.
#define lp (p << 1)
#define rp (p << 1 | 1)
#define mid ((l + r) >> 1)
#define ll long long
#define db double
// Half-open loops over [a, b): rep counts i = a .. b-1 upward, per counts b-1 .. a downward.
#define rep(i,a,b) for(int i=(a);i<(b);i++)
#define per(i,a,b) for(int i=(b)-1;i>=(a);i--)
// Adjacency-list graph over arrays sized by N (edge ids start at 2 because ccnt=1 and
// ids are pre-incremented): addd inserts a directed edge u->v, add inserts both
// directions. Edgc additionally stores a per-edge weight in c.
#define Edg int ccnt=1,head[N],to[N*2],ne[N*2];void addd(int u,int v){to[++ccnt]=v;ne[ccnt]=head[u];head[u]=ccnt;}void add(int u,int v){addd(u,v);addd(v,u);}
#define Edgc int ccnt=1,head[N],to[N*2],ne[N*2],c[N*2];void addd(int u,int v,int w){to[++ccnt]=v;ne[ccnt]=head[u];c[ccnt]=w;head[u]=ccnt;}void add(int u,int v,int w){addd(u,v,w);addd(v,u,w);}
// es(u,i,v): walk edge ids i leaving u, binding v to each edge's target.
#define es(u,i,v) for(int i=head[u],v=to[i];i;i=ne[i],v=to[i])
// Modulus for every count in this problem. 20170408 is even (hence not prime),
// so the default exponent MOD - 2 in qp() below does NOT produce a modular inverse.
const int MOD = 20170408;

// Bring x back into [0, MOD) when it lies at most one MOD outside that range.
void M(int &x) {
    if (x < 0) x += MOD;
    else if (x >= MOD) x -= MOD;
}

// Binary exponentiation: a^b reduced modulo MOD (result in [0, MOD) for a >= 0, b >= 0).
int qp(int a, int b = MOD - 2) {
    long long result = 1;
    long long base = a;
    while (b != 0) {
        if (b & 1) result = result * base % MOD;
        base = base * base % MOD;
        b >>= 1;
    }
    return static_cast<int>(result % MOD);
}

// Euclid's algorithm in recursive form: gcd(a, 0) = a, otherwise gcd(b, a mod b).
int gcd(int a, int b) { return b == 0 ? a : gcd(b, a % b); }
// Capacity bound for the sieve's marker array; init() requires m < N.
const int N = 2e7 + 7;
// Problem input read by main(): n = exponent (how many values are chosen),
// m = values range over [1, m], p = modulus of the residue classes.
int n, m, p;
// prime[1..prin] holds the primes <= m found by init(), 1-indexed.
// Presumably sized from pi(2e7) ~ 1.27e6 < N / 10 — confirm if N changes.
int prime[N / 10];
int prin;
void init() {
static bool vis[N];
rep (i, 2, m + 1) {
if (!vis[i]) prime[++prin] = i;
rep (j, 1, prin + 1) {
if (1LL * i * prime[j] > m) break;
vis[i * prime[j]] = 1;
if (i % prime[j] == 0) break;
}
}
}
void mul(int *a, int *b, int n) {
static ll c[222];
memset(c, 0, sizeof(c));
rep (i, 0, n) {
rep(j, 0, n) {
c[i + j] += 1LL * a[i] * b[j];
}
}
rep (i, 0, n) M(a[i] = (c[i] + c[i + n]) % MOD);
}
// Polynomial fast power under mul(): b <- a^k (cyclic, length n).
// Expects b[1..n-1] to be zero on entry, because only b[0] is set to 1.
// Destroys the caller's a, which is repeatedly squared in place.
void qp(int *a, int *b, int n, int k) {
    b[0] = 1;
    for (; k != 0; k >>= 1) {
        if ((k & 1) != 0) {
            mul(b, a, n);   // b *= a^(2^step)
        }
        mul(a, a, n);       // a = a^2
    }
}
int a[333], b[333];
int main() {
scanf("%d%d%d", &n, &m, &p);
init();
rep (i, 1, m + 1) a[i % p]++;
qp(a, b, p, n);
memset(a, 0, sizeof(a));
rep (i, 1, m + 1) a[i % p]++;
rep (i, 1, prin + 1) a[prime[i] % p]--;
int ans = b[0];
memset(b, 0, sizeof(b));
qp(a, b, p, n);
M(ans = ans - b[0]);
printf("%d\n", ans);
}