【BZOJ3992】【SDOI2015】序列统计 原根 NTT

题目大意

  有一个集合\(s\),里面的每个数都\(\geq0\)\(<m\)

  问有多少个长度为\(n\)的数列满足这个数列所有数的乘积模\(m\)\(x\)。答案模\(1004535809\)

  \(n\leq {10}^9,m\leq 8000\)\(m\)是质数。

题解

  先求出\(m\)的原根\(g\),这样\(1\)~\(m-1\)中的每个数都能被表示成\(g\)的幂。

  因为\(g^ig^j=g^{i+j}\),这样就可以把乘积转成和,问题转化为问有多少个长度为\(n\)的数列满足这个数列所有数的和模\(m-1\)\(y\)\(f_{i+1,j}=\sum f_{i,k}f_{i,j-k}\)。因为模数是NTT模数,原根为\(3\),所以可以用NTT优化。

  ln&exp好像也可以做。

  时间复杂度:\(O(m\log m\log n)\)

代码

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
int m;
ll p=1004535809;
int g;
int len=16384;
ll a[100010];
ll b[100010];
ll w1[100010];
ll w2[100010];
int rev[100010];
int c[10010];
ll fp(ll a,ll b)
{
	ll s=1;
	while(b)
	{
		if(b&1)
			s=s*a%p;
		a=a*a%p;
		b>>=1;
	}
	return s;
}
int check(int x)
{
	int s=1;
	int i;
	for(i=1;i<=m-2;i++)
	{
		s=s*x%m;
		if(s==1)
			return 0;
	}
	return 1;
}
void ntt(ll *a,int t)
{
	int i,j,k;
	ll u,v,w,wn;
	for(i=0;i<len;i++)
		if(rev[i]<i)
			swap(a[i],a[rev[i]]);
	for(i=2;i<=len;i<<=1)
	{
		wn=(t==1?w1[i]:w2[i]);
		for(j=0;j<len;j+=i)
		{
			w=1;
			for(k=j;k<j+i/2;k++)
			{
				u=a[k];
				v=a[k+i/2]*w%p;
				a[k]=(u+v)%p;
				a[k+i/2]=(u-v+p)%p;
				w=w*wn%p;
			}
		}
	}
	if(t==-1)
	{
		ll inv=fp(len,p-2);
		for(i=0;i<len;i++)
			a[i]=a[i]*inv%p;
	}
}
void fp(int n)
{
	int i;
	while(n)
	{
		ntt(a,1);
		if(n&1)
		{
			ntt(b,1);
			for(i=0;i<len;i++)
				b[i]=b[i]*a[i]%p;
			ntt(b,-1);
			for(i=m;i<len;i++)
			{
				b[i%m]=(b[i%m]+b[i])%p;
				b[i]=0;
			}
		}
		for(i=0;i<len;i++)
			a[i]=a[i]*a[i]%p;
		ntt(a,-1);
		for(i=m;i<len;i++)
		{
			a[i%m]=(a[i%m]+a[i])%p;
			a[i]=0;
		}
		n>>=1;
	}
}
int main()
{
	int n,x,u;
	scanf("%d%d%d%d",&n,&m,&x,&u);
	int i;
	for(i=1;i<m;i++)
		if(check(i))
			g=i;
	rev[0]=0;
	for(i=1;i<len;i++)
		rev[i]=(rev[i/2]>>1)|(i&1?len>>1:0);
	for(i=2;i<=len;i<<=1)
	{
		w1[i]=fp(3,(p-1)/i);
		w2[i]=fp(w1[i],p-2);
	}
	int s=1;
	for(i=0;i<=m-2;i++)
	{
		c[s]=i;
		s=s*g%m;
	}
	int v;
	memset(a,0,sizeof a);
	memset(b,0,sizeof b);
	b[0]=1;
	for(i=1;i<=u;i++)
	{
		scanf("%d",&v);
		if(!v)
			continue;
		a[c[v]]++;
	}
	m--;
	fp(n);
	printf("%lld\n",b[c[x]]);
	return 0;
}
posted @ 2018-03-05 20:24  ywwyww  阅读(200)  评论(0编辑  收藏  举报