【BZOJ4818】【SDOI2017】序列计数 [矩阵乘法][DP]
序列计数
Time Limit: 30 Sec Memory Limit: 128 MB[Submit][Status][Discuss]
Description
Alice想要得到一个长度为n的序列,序列中的数都是不超过m的正整数,而且这n个数的和是p的倍数。Alice还希望,这n个数中,至少有一个数是质数。Alice想知道,有多少个序列满足她的要求。
Input
一行三个数,n,m,p。
Output
一行一个数,满足Alice的要求的序列数量,答案对20170408取模。
Sample Input
3 5 3
Sample Output
33
HINT
1<=n<=10^9,1<=m<=2×10^7,1<=p<=100
Solution
先考虑容斥,用Ans=全部的方案数 - 一个质数都没有的方案,那么我们首先想到了一个暴力DP,令 f[i][j] 表示选了前 i 个数,%p时余数为 j 的方案数。那么显然 %p 同余的可以分为一类,那么就可以用矩阵乘法来优化这个DP了。
Code
1 #include<iostream>
2 #include<string>
3 #include<algorithm>
4 #include<cstdio>
5 #include<cstring>
6 #include<cstdlib>
7 #include<cmath>
8 using namespace std;
9 typedef long long s64;
10
11 const int MaxM = 2e7+5;
12 const int ONE = 105;
13 const int MOD = 20170408;
14
15 int n,m,p;
16 int prime[1300005],p_num;
17 int Record[ONE][2],a[ONE][ONE],b[ONE][ONE];
18 bool isp[MaxM];
19
20 inline int get()
21 {
22 int res=1,Q=1; char c;
23 while( (c=getchar())<48 || c>57)
24 if(c=='-')Q=-1;
25 if(Q) res=c-48;
26 while((c=getchar())>=48 && c<=57)
27 res=res*10+c-48;
28 return res*Q;
29 }
30
31 void Getp(int MaxN)
32 {
33 isp[1] = 1;
34 for(int i=2; i<=MaxN; i++)
35 {
36 if(!isp[i])
37 prime[++p_num] = i;
38 for(int j=1; j<=p_num, i*prime[j]<=MaxN; j++)
39 {
40 isp[i * prime[j]] = 1;
41 if(i % prime[j] == 0) break;
42 }
43 }
44 }
45
46 void Mul(int a[ONE][ONE],int b[ONE][ONE],int ans[ONE][ONE])
47 {
48 int record[ONE][ONE];
49 for(int i=0;i<p;i++)
50 for(int j=0;j<p;j++)
51 {
52 record[i][j] = 0;
53 for(int k=0;k<p;k++)
54 record[i][j] = (s64)(record[i][j] + (s64)a[i][k]*b[k][j] % MOD) %MOD;
55 }
56
57 for(int i=0;i<p;i++)
58 for(int j=0;j<p;j++)
59 ans[i][j] = record[i][j];
60 }
61
62 void Quickpow(int a[ONE][ONE],int b[ONE][ONE],int t)
63 {
64 while(t)
65 {
66 if(t&1) Mul(a,b,a);
67 Mul(b,b,b);
68 t>>=1;
69 }
70 }
71
72 int Solve(int PD)
73 {
74 memset(a,0,sizeof(a));
75 memset(b,0,sizeof(b));
76
77 for(int i=0;i<p;i++)
78 for(int j=0;j<p;j++)
79 b[i][j] = Record[((i-j)%p+p)%p][PD];
80
81 for(int i=0;i<p;i++)
82 a[i][i] = 1;
83
84 Quickpow(a,b,n);
85 return a[0][0];
86 }
87
88 int main()
89 {
90 n=get(); m=get(); p=get(); Getp(m);
91 for(int i=1;i<=m;i++)
92 {
93 int x = i%p;
94 Record[x][0]++;
95 if(isp[i]) Record[x][1]++;
96 }
97
98 printf("%d",(Solve(0)-Solve(1)+MOD) % MOD);
99 }