[Sdoi2017]序列计数
4818: [Sdoi2017]序列计数
Time Limit: 30 Sec Memory Limit: 128 MBSubmit: 317 Solved: 210
Description
Alice想要得到一个长度为n的序列,序列中的数都是不超过m的正整数,而且这n个数的和是p的倍数。Alice还希望
,这n个数中,至少有一个数是质数。Alice想知道,有多少个序列满足她的要求。
Input
一行三个数,n,m,p。
1<=n<=10^9,1<=m<=2×10^7,1<=p<=100
Output
一行一个数,满足Alice的要求的序列数量,答案对20170408取模。
Sample Input
3 5 3
Sample Output
33
先是容斥,总答案=所有的方案数-只有合数的方案数。
可以推出暴力的DP方程:
f[i][j]代表长度为i,mod p为j的方案数
f[i][(j+k)%p]+=f[i-1][j]。
然后这样的复杂度是O(n*m*p),20分。
考虑优化,发现对于每个i的转移都是一样的,所以可以用矩阵快速幂。
设A为转移矩阵,发现每一个j都可以转移到j+k这个位置,A[j][j+k]+1。
这样暴力构矩阵复杂度为O(m*p),80分。
还可以优化,发现有很多地方可以记忆化,所以可以先预处理出1-m每个数中 p的每个剩余系的数量,然后直接加进去即可。
复杂度O(m)。100分。
1 #include <algorithm> 2 #include <iostream> 3 #include <cstdlib> 4 #include <cstring> 5 #include <cstdio> 6 #include <cmath> 7 #define maxn 20000010 8 #define mod 20170408 9 #define LL long long 10 using namespace std; 11 LL a[110][110][2],b[110][110][2],s[110][110]; 12 int su[maxn/10],he[maxn],s1[110],s2[110]; 13 LL n,m,p; 14 bool bj[maxn]; 15 void mul(int a1,int b1){ 16 for(int i=0;i<p;i++) 17 for(int j=0;j<p;j++) 18 for(int k=0;k<p;k++) 19 s[i][j]+=a[i][k][a1]*a[k][j][b1],s[i][j]%=mod; 20 for(int i=0;i<p;i++) 21 for(int j=0;j<p;j++) 22 a[i][j][a1]=s[i][j],s[i][j]=0; 23 } 24 void mul1(int a1,int b1){ 25 for(int i=0;i<p;i++) 26 for(int j=0;j<p;j++) 27 for(int k=0;k<p;k++) 28 s[i][j]+=b[i][k][a1]*b[k][j][b1],s[i][j]%=mod; 29 for(int i=0;i<p;i++) 30 for(int j=0;j<p;j++) 31 b[i][j][a1]=s[i][j],s[i][j]=0; 32 } 33 int main() 34 { 35 freopen("count.in","r",stdin); 36 freopen("count.out","w",stdout); 37 LL tot=0,tot1=0; 38 scanf("%lld%lld%lld",&n,&m,&p); 39 for(int i=1;i<=m;i++) a[0][i%p][0]++; 40 for(int i=1;i<=m;i++) s1[i%p]++; 41 for(int i=0;i<p;i++) 42 for(int j=0;j<p;j++) 43 a[j][(j+i)%p][1]+=s1[i]; 44 int mi=n-1; 45 while(mi){ 46 if(mi%2) mul(0,1); 47 mi>>=1; 48 mul(1,1); 49 } 50 bj[1]=1; 51 for(int i=2;i<=m;i++){ 52 if(!bj[i]) su[++tot]=i; 53 for(int j=1;j<=tot;j++){ 54 if(su[j]*i>m) break; 55 bj[su[j]*i]=1; 56 if(i%su[j]==0) break; 57 } 58 } 59 for(int i=1;i<=m;i++) if(bj[i])b[0][i%p][0]++; 60 for(int i=1;i<=m;i++) if(bj[i])s2[i%p]++; 61 for(int i=0;i<p;i++) 62 for(int j=0;j<p;j++) 63 b[j][(j+i)%p][1]+=s2[i]; 64 mi=n-1; 65 while(mi){ 66 if(mi%2) mul1(0,1); 67 mi>>=1; 68 mul1(1,1); 69 } 70 printf("%lld",(a[0][0][0]-b[0][0][0]+mod)%mod); 71 return 0; 72 }