【bzoj3992】 SDOI2015—序列统计

http://www.lydsy.com/JudgeOnline/problem.php?id=3992 (题目链接)

题意

  集合${S}$中有若干个不超过${m}$的非负整数,问由这些数组成一个长度${n}$的序列,使序列中的数的乘积对${m}$取模正好等于${x}$,问存在多少方案。

Solution

  好神的题。算法还是要多复习,我连${NTT}$都忘记怎么写了T_T

  这还是我的第一发原根→_→。

  一个数如果有原根,那么它会有很多原根,所以如果对时间没有特殊限制,我们枚举${rt=2~~to~~inf}$,然后判断是否存在${t<m-1}$使${rt^t=1}$。虽然我并不知道为什么可以那样check。。

  我们可以很简单的列出dp方程${f_{i,j}}$表示,已经放到了第${i}$个数,它们的乘积是${j}$的方案数。转移也就很显然了:$${f[i][j]=\sum_{k=1}^{m-1}f_{i-1,j*inv[k]}}$$

  复杂度${O(nm^2)}$,于是我们就可以获得10分的高分,是不是很良心啊。

  考虑这个东西怎么优化,我们把每一个${j}$都写成${m}$的原根的几次方,然后乘就变成加辣,然后我们就可以卷积辣。

  然后你发现${n}$有${10^9}$,我们快速幂一波,然后就AC辣。

细节

  一开始没想清没注意到还是循环卷积卧槽T_T

代码

// bzoj3992
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<ctime>
#define LL long long
#define inf (1ll<<30)
#define MOD 1004535809
#define Pi acos(-1.0)
#define free(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout);
using namespace std;

const int maxn=20010;
int f[maxn],g[maxn],rev[maxn],vis[maxn];
int n,m,rt,S,X,N,L;

int power(int a,int b,int c) {
	int res=1;
	while (b) {
		if (b&1) res=(LL)res*a%c;
		b>>=1;a=(LL)a*a%c;
	}
	return res;
}
void root(int p) {
	if (p==2) {rt=1;return;}
	for (rt=2;;rt++) {
		int flag=1;
		for (int i=2;i*i<p;i++)
			if (power(rt,(p-1)/i,p)==1) {flag=0;break;}
		if (flag) break;
	}
}
namespace NTT {
	LL A[maxn],B[maxn];
	void NTT(LL *a,int f) {
		for (int i=0;i<N;i++) if (i<rev[i]) swap(a[i],a[rev[i]]);
		for (int i=1;i<N;i<<=1) {
			LL gn=power(3,(MOD-1)/(i<<1),MOD);
			for (int p=i<<1,j=0;j<N;j+=p) {
				LL g=1;
				for (int k=0;k<i;k++,(g*=gn)%=MOD) {
					LL x=a[k+j],y=g*a[k+j+i]%MOD;
					a[k+j]=(x+y)%MOD,a[k+j+i]=(x-y+MOD)%MOD;
				}
			}
		}
		if (f==-1) reverse(a+1,a+N);
	}
	void Init(int *a,int *b) {
		for (int i=0;i<N;i++) A[i]=a[i],B[i]=b[i];
		NTT(A,1);NTT(B,1);
		for (int i=0;i<N;i++) (A[i]*=B[i])%=MOD;
		NTT(A,-1);
		LL ev=power(N,MOD-2,MOD);
		for (int i=0;i<N;i++) (A[i]*=ev)%=MOD;
		for (int i=0;i<m-1;i++) a[i]=(A[i]+A[i+m-1])%MOD;
	}
}
using namespace NTT;

int main() {
	scanf("%d%d%d%d",&n,&m,&X,&S);
	root(m);
	for (int x,i=1;i<=S;i++) scanf("%d",&x),vis[x]=1;
	for (int p=1,i=0;i<m-1;i++,(p*=rt)%=m) if (vis[p]) f[i]=1;
	for (N=1,L=-1;N<(m-1)*2;N<<=1) L++;
	for (int i=0;i<N;i++) rev[i]=(rev[i>>1]>>1) | ((i&1)<<L);
	g[0]=1;
	while (n) {
		if (n&1) Init(g,f);
		n>>=1;Init(f,f);
	}
	for (int i=0,p=1;i<m-1;i++,(p*=rt)%=m)
		if (p==X) {printf("%d",g[i]);break;}
	return 0;
}

 

posted @ 2017-02-13 23:06  MashiroSky  阅读(411)  评论(0编辑  收藏  举报