【BZOJ4818】【SDOI2017】序列计数 [矩阵乘法][DP]

序列计数

Time Limit: 30 Sec  Memory Limit: 128 MB
[Submit][Status][Discuss]

Description

  Alice想要得到一个长度为n的序列,序列中的数都是不超过m的正整数,而且这n个数的和是p的倍数。Alice还希望,这n个数中,至少有一个数是质数。Alice想知道,有多少个序列满足她的要求。

Input

  一行三个数,n,m,p。

Output

  一行一个数,满足Alice的要求的序列数量,答案对20170408取模。

Sample Input

  3 5 3

Sample Output

  33

HINT

  1<=n<=10^9,1<=m<=2×10^7,1<=p<=100

Solution

  先考虑容斥,用Ans=全部的方案数 - 一个质数都没有的方案,那么我们首先想到了一个暴力DP,令 f[i][j] 表示选了前 i 个数,%p时余数为 j 的方案数。那么显然 %p 同余的可以分为一类,那么就可以用矩阵乘法来优化这个DP了。

Code

 1 #include<iostream>  
 2 #include<string>  
 3 #include<algorithm>  
 4 #include<cstdio>  
 5 #include<cstring>  
 6 #include<cstdlib>  
 7 #include<cmath>
 8 using namespace std; 
 9 typedef long long s64;
10   
11 const int MaxM = 2e7+5;
12 const int ONE = 105;
13 const int MOD = 20170408;
14  
15 int n,m,p;
16 int prime[1300005],p_num;
17 int Record[ONE][2],a[ONE][ONE],b[ONE][ONE];
18 bool isp[MaxM];
19  
20 inline int get() 
21 {
22         int res=1,Q=1;  char c;
23         while( (c=getchar())<48 || c>57)
24         if(c=='-')Q=-1;
25         if(Q) res=c-48; 
26         while((c=getchar())>=48 && c<=57) 
27         res=res*10+c-48; 
28         return res*Q; 
29 }
30  
31 void Getp(int MaxN)
32 {
33         isp[1] = 1;
34         for(int i=2; i<=MaxN; i++)
35         {
36             if(!isp[i])
37                 prime[++p_num] = i;
38             for(int j=1; j<=p_num, i*prime[j]<=MaxN; j++)
39             {
40                 isp[i * prime[j]] = 1;
41                 if(i % prime[j] == 0) break;
42             }
43         }
44 }
45  
46 void Mul(int a[ONE][ONE],int b[ONE][ONE],int ans[ONE][ONE])
47 {
48         int record[ONE][ONE]; 
49         for(int i=0;i<p;i++)
50         for(int j=0;j<p;j++)
51         {
52             record[i][j] = 0;
53             for(int k=0;k<p;k++)
54             record[i][j] = (s64)(record[i][j] + (s64)a[i][k]*b[k][j] % MOD) %MOD;
55         }
56          
57         for(int i=0;i<p;i++)
58         for(int j=0;j<p;j++)
59             ans[i][j] = record[i][j];
60 }
61  
62 void Quickpow(int a[ONE][ONE],int b[ONE][ONE],int t)
63 {
64         while(t)
65         {
66             if(t&1) Mul(a,b,a);
67             Mul(b,b,b);
68             t>>=1;
69         }
70 }
71  
72 int Solve(int PD)
73 {
74         memset(a,0,sizeof(a));
75         memset(b,0,sizeof(b));
76          
77         for(int i=0;i<p;i++)
78         for(int j=0;j<p;j++)
79             b[i][j] = Record[((i-j)%p+p)%p][PD];
80          
81         for(int i=0;i<p;i++)
82             a[i][i] = 1;
83              
84         Quickpow(a,b,n);
85         return a[0][0];
86 }
87  
88 int main()
89 {
90         n=get();    m=get();    p=get();    Getp(m);
91         for(int i=1;i<=m;i++)
92         {
93             int x = i%p;
94             Record[x][0]++;
95             if(isp[i]) Record[x][1]++;
96         }
97          
98         printf("%d",(Solve(0)-Solve(1)+MOD) % MOD);
99 }
View Code

 

posted @ 2017-04-13 17:07  BearChild  阅读(370)  评论(0编辑  收藏  举报