【BZOJ 4818】 4818: [Sdoi2017]序列计数 (矩阵乘法、容斥计数)
4818: [Sdoi2017]序列计数
Time Limit: 30 Sec Memory Limit: 128 MB
Submit: 560 Solved: 359Description
Alice想要得到一个长度为n的序列,序列中的数都是不超过m的正整数,而且这n个数的和是p的倍数。Alice还希望,这n个数中,至少有一个数是质数。Alice想知道,有多少个序列满足她的要求。Input
一行三个数,n,m,p。1<=n<=10^9,1<=m<=2×10^7,1<=p<=100Output
一行一个数,满足Alice的要求的序列数量,答案对20170408取模。Sample Input
3 5 3Sample Output
33HINT
Source
【分析】
f[i][j]表示填i个数,mod为j的方案数。
这个转移可以用矩阵加速啦,那么n很大也没关系了。
然后要至少有一个质数,就用总方案数-没有一个质数。
然后一开始跑20s,看到别人2s,然后觉得这个循环矩阵的矩阵乘法可以O(n^2),改了就2s了。。
1 #include<cstdio> 2 #include<cstdlib> 3 #include<cstring> 4 #include<iostream> 5 #include<algorithm> 6 using namespace std; 7 #define LL long long 8 #define Maxn 20000010 9 #define Mod 20170408 10 11 int pri[Maxn],pl; 12 bool vis[Maxn]; 13 14 void init(int n) 15 { 16 for(int i=1;i<=n;i++) vis[i]=0; 17 pl=0;vis[1]=1; 18 for(int i=2;i<=n;i++) 19 { 20 if(!vis[i]) pri[++pl]=i; 21 for(int j=1;j<=pl;j++) 22 { 23 if(pri[j]*i>n) break; 24 vis[i*pri[j]]=1; 25 if(i%pri[j]==0) break; 26 } 27 } 28 } 29 30 int sm[110],p; 31 32 struct Matrix 33 { 34 int w[110][110]; 35 Matrix() {memset(w,0,sizeof(w));} 36 friend Matrix operator * (Matrix x,Matrix y) 37 { 38 Matrix ret; 39 for(int k=0;k<p;k++) 40 { 41 // for(int i=0;i<p;i++) 42 for(int j=0;j<p;j++) 43 { 44 ret.w[0][j]=(ret.w[0][j]+1LL*x.w[0][k]*y.w[k][j]%Mod)%Mod; 45 } 46 } 47 for(int i=1;i<p;i++) 48 { 49 for(int j=0;j<p;j++) ret.w[i][(j+i)%p]=ret.w[0][j]; 50 } 51 return ret; 52 } 53 friend Matrix operator ^ (Matrix x,int b) 54 { 55 Matrix ret; 56 for(int i=0;i<p;i++) ret.w[i][i]=1; 57 while(b) 58 { 59 if(b&1) ret=ret*x; 60 x=x*x; 61 b>>=1; 62 } 63 return ret; 64 } 65 }; 66 67 int main() 68 { 69 int n,m; 70 scanf("%d%d%d",&n,&m,&p); 71 init(m); 72 for(int i=0;i<p;i++) sm[p]=0; 73 Matrix nw; 74 for(int i=1;i<=m;i++) sm[i%p]++; 75 for(int i=0;i<p;i++) for(int j=0;j<p;j++) (nw.w[i][(i+j)%p]+=sm[j])%Mod; 76 nw=nw^n; 77 int ans=nw.w[0][0]; 78 79 for(int i=1;i<=m;i++) if(!vis[i]) sm[i%p]--; 80 memset(nw.w,0,sizeof(nw.w)); 81 for(int i=0;i<p;i++) for(int j=0;j<p;j++) (nw.w[i][(i+j)%p]+=sm[j])%Mod; 82 nw=nw^n; 83 ans=((ans-nw.w[0][0])%Mod+Mod)%Mod; 84 85 printf("%d\n",ans); 86 return 0; 87 }
2017-04-28 10:51:39