LOJ 2183 / SDOI2015 序列统计 (DP+矩阵快速幂)

题面

传送门

分析

考虑容斥原理,用总的方案数-不含质数的方案数

\(dp1[i][j]\)表示前i个数,和取模p为j的方案数,

\(dp2[i][j]\)表示前i个数,和取模p为j的方案数,且所有的数均不为质数

[1,m]中的质数可以线性筛出

\(dp1[i][j]=dp1[i-1][((j-k) \mod p+p)\mod p],j \in [0,p-1],k \in [0,m]\)

\(dp2[i][j]=dp1[i-1][((j-k) \mod p+p)\mod p],j \in [0,p-1],k \in [0,m]且不为质数\)

最终答案为\(dp1[n][0]-dp2[n][0]\)

其中k表示第i位选的数,((j-k)%p+p)%p为前i位的和,这里的减法是带模减法,是为了防止负数取模造成的问题

该算法的时间复杂度为\(O(nmp)\)

#include<iostream>
#include<cstdio>
#include<cstring>
#define maxn 105
#define mod 20170408
using namespace std;
int n,m,p; 

long long dp1[maxn][maxn],dp2[maxn][maxn];
int cnt=0;
int vis[maxn];
int prime[maxn];
void sieve(int n){
	vis[1]=1;
	for(int i=2;i<=n;i++){
		if(!vis[i]){
			prime[++cnt]=i;
		}
		for(int j=1;j<=cnt&&(long long)i*prime[j]<=(long long)n;j++){
			vis[i*prime[j]]=1;
			if(i%prime[j]==0) break;
		}
	}
}

int main(){
	scanf("%d %d %d",&n,&m,&p);
	sieve(m);
	dp1[0][0]=1;
	for(int i=1;i<=n;i++){
		for(int j=0;j<p;j++){
			for(int k=1;k<=m;k++){
				dp1[i][j]+=dp1[i-1][((j-k)%p+p)%p]; 
				dp1[i][j]%=mod;
			}
		}
	}
	dp2[0][0]=1;
	for(int i=1;i<=n;i++){
		for(int j=0;j<p;j++){
			for(int k=1;k<=m;k++){
				if(vis[k]==0) continue;
				dp2[i][j]+=dp2[i-1][((j-k)%p+p)%p];
				dp2[i][j]%=mod; 
			}
		}
	}
	printf("%lld\n",dp1[n][0]-dp2[n][0]);
}

有一个小优化,在转移的过程中我们不关心k的值,而是关心k%p的值,所以我们把[1,m]中的数按模p的余数分类

设cntm[i]表示[1,m]中的数%p余i的个数

cnth[i]表示[1,m]中的合数%p余i的个数

则上述状态转移方程可以改写为

\(dp1[i][j]=dp1[i-1][((j-k) \mod p+p)\mod p] \times cntm[k],j \in [0,p-1],k \in [0,p-1]\)

\(dp2[i][j]=dp1[i-1][((j-k) \mod p+p)\mod p] \times cnth[k] ,j \in [0,p-1],k \in [0,p-1]且不为质数\)

我们发现从i-1到i的转移是确定的,可以用矩阵快速幂优化

我们来构造转移矩阵

\(\begin{bmatrix} dp1[i][0] \\ dp1[i][1]\\ \vdots \\dp1[i][p-1]\end{bmatrix} = \begin{bmatrix} cntm[0] \ cntm[p-1] \ cntm[p-2] \ \dots \ cntm[1] \\cntm[1] \ cntm[0] \ cntm[p-1] \ \dots \ cntm[2] \\ \vdots \\ cntm[p-1] \ cntm[p-2] \ cntm[p-3] \ \dots \ cntm[0] \end{bmatrix} \times \begin{bmatrix} dp1[i-1][0] \\ dp1[i-1][1]\\ \vdots \\dp1[i-1][p-1] \end{bmatrix}\)

转移矩阵的第i行第j列为cntm[(i-j+p)%p]

同理有

\(\begin{bmatrix} dp2[i][0] \\ dp2[i][1]\\ \vdots \\dp2[i][p-1]\end{bmatrix} = \begin{bmatrix} cnth[0] \ cnth[p-1] \ cnth[p-2] \ \dots \ cnth[1] \\cnth[1] \ cnth[0] \ cnth[p-1] \ \dots \ cnth[2] \\ \vdots \\ cnth[p-1] \ cnth[p-2] \ cnth[p-3] \ \dots \ cnth[0] \end{bmatrix} \times \begin{bmatrix} dp2[i-1][0] \\ dp2[i-1][1]\\ \vdots \\dp2[i-1][p-1] \end{bmatrix}\)

转移矩阵的第i行第j列为cnth[(i-j+p)%p]

注意\(dp1[0][i]\)的初始值为cntm[i]

所以

\(\begin{bmatrix} dp1[n][0] \\ dp1[n][1]\\ \vdots \\dp1[n][p-1]\end{bmatrix} = \begin{bmatrix} cntm[0] \ cntm[p-1] \ cntm[p-2] \ \dots \ cntm[1] \\cntm[1] \ cntm[0] \ cntm[p-1] \ \dots \ cntm[2] \\ \vdots \\ cntm[p-1] \ cntm[p-2] \ cntm[p-3] \ \dots \ cntm[0] \end{bmatrix}^{n-1} \times \begin{bmatrix} cntm[0] \\ cntm[1]\\ \vdots \\cntm[p-1] \end{bmatrix}\)

\(\begin{bmatrix} dp2[n][0] \\ dp2[n][1]\\ \vdots \\dp2[n][p-1]\end{bmatrix} = \begin{bmatrix} cnth[0] \ cnth[p-1] \ cnth[p-2] \ \dots \ cnth[1] \\cnth[1] \ cnth[0] \ cnth[p-1] \ \dots \ cnth[2] \\ \vdots \\ cnth[p-1] \ cnth[p-2] \ cnth[p-3] \ \dots \ cnth[0] \end{bmatrix}^{n-1} \times \begin{bmatrix} cnth[0] \\ cnth[1]\\ \vdots \\cnth[p-1] \end{bmatrix}\)

时间复杂度为\(O(m+p^3 \log n)\)

代码

#include<iostream>
#include<cstdio>
#include<cstring>
#define maxn 105
#define maxm 20000005
#define mod 20170408
using namespace std;
int n,m,p; 

int cnt=0;
int vis[maxm];
int prime[maxm];
void sieve(int n){
	vis[1]=1;
	for(int i=2;i<=n;i++){
		if(!vis[i]){
			prime[++cnt]=i;
		}
		for(int j=1;j<=cnt&&(long long)i*prime[j]<=(long long)n;j++){
			vis[i*prime[j]]=1;
			if(i%prime[j]==0) break;
		}
	}
}

struct matrix{
	long long a[maxn][maxn];
	matrix(){
		memset(a,0,sizeof(a));
	}
	friend matrix operator * (matrix a,matrix b){
		matrix c;
		for(int i=0;i<p;i++){
			for(int j=0;j<p;j++){
				c.a[i][j]=0;
				for(int k=0;k<p;k++){
					c.a[i][j]+=a.a[i][k]*b.a[k][j]%mod;
					c.a[i][j]%=mod;
				}
			}
		}
		return c;
	}
};

matrix fast_pow(matrix x,int k){
	matrix ans;
	for(int i=0;i<p;i++){
		ans.a[i][i]=1;
	}
	while(k>0){
		if(k&1) ans=ans*x;
		x=x*x;
		k>>=1;
	}
	return ans;
}

int cntm[maxn],cnth[maxn];
matrix A,B;
int main(){
	scanf("%d %d %d",&n,&m,&p);
	sieve(m);
	for(int i=1;i<=m;i++){
		cntm[i%p]++;
	}
	for(int i=1;i<=m;i++){
		if(vis[i]) cnth[i%p]++;
	}
	for(int i=0;i<p;i++){
		for(int j=0;j<p;j++){
			A.a[i][j]=cntm[(i-j+p)%p];
			B.a[i][j]=cnth[(i-j+p)%p];
		}
	}
	long long ans1=0,ans2=0;
	A=fast_pow(A,n-1);
	B=fast_pow(B,n-1);
	for(int i=0;i<p;i++){
		ans1+=cntm[i]*A.a[0][i];
		ans1%=mod;
		ans2+=cnth[i]*B.a[0][i];
		ans2%=mod;
	}
	printf("%lld\n",(ans1-ans2+mod)%mod);
}


posted @ 2019-01-29 12:00  birchtree  阅读(193)  评论(0编辑  收藏  举报