bzoj 3992: [SDOI2015]序列统计【原根+生成函数+NTT+快速幂】

还是没有理解透原根……题目提示其实挺明显的,M是质数,然后1<=x<=M-1
这种计数就容易想到生成函数,但是生成函数是加法,而这里是乘法,所以要想办法变成加法
首先因为0和任何数乘都是0,和其他数规则不相符,所以不考虑(答案也没让求)
然后看原根的性质,设g是M的原根,那么\( g^i%M 0<=i<M-1 \)就是1~M-1的不重集合,所以可以把乘法变成原根指数的加法,这样就变成多项式乘法了,可以用NTT优化
然后n非常大,所以使用快速幂进行多项式乘法

#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
const int N=20005,mod=1004535809,g=3;
int n,m,x,k,s[N],d=2,id[N],lm,bt,re[N];
long long a[N],c[N],r[N];
int read()
{
	int r=0,f=1;
	char p=getchar();
	while(p>'9'||p<'0')
	{
		if(p=='-')
			f=-1;
		p=getchar();
	}
	while(p>='0'&&p<='9')
	{
		r=r*10+p-48;
		p=getchar();
	}
	return r*f;
}
void jia(long long &x,long long &y)
{
	x+=y;
	x>=mod?x-=mod:0;
}
long long ksm(long long a,long long b,int mod)
{
	long long r=1;
	while(b)
	{
		if(b&1)
			r=r*a%mod;
		a=a*a%mod;
		b>>=1;
	}
	return r;
}
void dft(long long a[],int f)
{
	for(int i=0;i<lm;i++)
		if(i<re[i])
			swap(a[i],a[re[i]]);
	for(int i=1;i<lm;i<<=1)
	{
		long long wi=ksm(g,(mod-1)/(2*i),mod);
		if(f==-1)
			wi=ksm(wi,mod-2,mod);
		for(int k=0;k<lm;k+=(i<<1))
		{
			long long w=1,x,y;
			for(int j=0;j<i;j++)
			{
				x=a[j+k],y=1ll*w*a[i+j+k]%mod;
				a[j+k]=((x+y)%mod+mod)%mod,a[i+j+k]=((x-y)%mod+mod)%mod;
				w=w*wi%mod;
			}
		}
	}
	if(f==-1)
	{
		long long inv=ksm(lm,mod-2,mod);
		for(int i=0;i<lm;i++)
			a[i]=a[i]*inv%mod;
	}
}
void ntt(long long a[],long long b[])
{
	for(int i=0;i<lm;i++)
		c[i]=b[i];
	dft(a,1);
	dft(c,1);
	for(int i=0;i<lm;i++)
		a[i]=a[i]*c[i]%mod;
	dft(a,-1);
	for(int i=m-1;i<lm;i++)
		jia(a[i%(m-1)],a[i]),a[i]=0;
	// for(int i=0;i<lm;i++)
		// cerr<<a[i]<<" ";cerr<<endl;
}
int main()
{
	n=read()-1,m=read(),x=read(),k=read();
	for(int i=1;i<=k;i++)
		s[i]=read();
	for(bool fl=0;!fl;d++)
	{
		fl=1;
		for(int i=1;i<m-1;i++)
			if(ksm(d,i,m)==1)
			{
				fl=0;
				break;
			}
		if(ksm(d,m-1,m)!=1)
			fl=0;
		if(fl)
			break;
	}
	for(int i=0;i<m-1;i++)
		id[ksm(d,i,m)]=i;//,cerr<<rl[i]<<" "<<i<<endl;
	for(int i=1;i<=k;i++)
		if(s[i])
			a[id[s[i]]]++,r[id[s[i]]]++;
	for(bt=0;(1<<bt)<=2*m;bt++);
	lm=(1<<bt);
	for(int i=0;i<lm;i++)
		re[i]=(re[i>>1]>>1)|((i&1)<<(bt-1));
	while(n)
	{
		if(n&1)
			ntt(r,a);
		ntt(a,a);
		n>>=1;
	}
	printf("%lld\n",r[id[x]]);
	return 0;
}
posted @ 2018-11-28 21:50  lokiii  阅读(193)  评论(0编辑  收藏  举报