SDOI2015 序列统计

Description

小C有一个集合\(S\),里面的元素都是小于\(M\)的非负整数。他用程序编写了一个数列生成器,可以生成一个长度为\(N\)的数列,

数列中的每个数都属于集合\(S\)。小C用这个生成器生成了许多这样的数列。但是小C有一个问题需要你的帮助:

给定整数\(x\),求所有可以生成出的,且满足数列中所有数的乘积\(\mod M\)的值等于\(x\)的不同的数列的有多少个。

小C认为,两个数列\(\{A_i\}\)\(\{B_i\}\)不同,当且仅当至少存在一个整数\(i\),满足\(A_i\neq B_i\)

另外,小C认为这个问题的答案可能很大,因此他只需要你帮助他求出答案\(\mod 1004535809\)的值就可以了。

Input

一行,四个整数,\(N、M、x、|S|\),其中\(|S|\)为集合\(S\)中元素个数。

第二行,\(|S|\)个整数,表示集合\(S\)中的所有元素。

\(1 \leq N \leq 10^9,3 \leq M \leq 8000\),M为质数

\(0 \leq x \leq M-1\),输入数据保证集合S中元素不重复\(x \in [1,m-1]\)

集合中的数$ \in [0,m-1]$

Output

一行,一个整数,表示你求出的种类数\(\mod 1004535809\)的值。

Solution

看到这题。。首先很容易列出一个DP转移方程

\(F_{i,j}\)表示选了\(i\)个数字,当前乘积为\(j\)的种类数

\[F_{i,j}=\sum_{a+b \equiv j (\mod m) } {F_{i-1,a}*F_{i-1,b}} \]

我们发现它非常不优美,复杂度高达$ O (n * m^2) $

我们发现这个式子可以倍增。。于是很轻松的干掉一个n,它的复杂度变成了$ O(\log n * m^2) $

这貌似还是有点多。。考虑如何干掉一个 $ m $

咦。。这个模数貌似有点熟悉。。考虑NTT

不过这是乘法。。我们做不了NTT 。。。

考虑原根

\(p\)\(m\)的原根。。那么\(p\)的幂次可以表示出\([1,m)\)的所有数字————原根定义

于是DP方程变成了这样

\[F_{i,j}=\sum_{g^a + g^b \equiv g^j (\mod m)}{F_{i-1,a}*F_{i-1,b}} \]

注意。。此时\(F_{i,j}\)表示选到第\(i\)个数,大小为\(p^j\)次的方案数

再一变

\[F_{i,j}=\sum_{a+b \equiv j (\mod m-1)}{F_{i-1,a}*F_{i-1,b}} \]

我们发现这玩意长得像个卷积。。可以用NTT了

于是复杂度变成了 $ O(m \log n log m)$

Code

#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
int n,m,x,lens,g[2000000],a[2000010],b[2000010];
int f[2000010],ans[2000010];
int fpow(int x,int k,int Mod)
{
	int ans=1;
	while (k)
	{
		if (k&1) ans=1LL*ans*x%Mod;
		x=1LL*x*x%Mod;
		k>>=1;
	}
	return ans;
}
namespace GetRoot //求原根
{
	int prime[1000000],cnt;
	bool check(int x,int p)
	{
		for (int i=1;i<=cnt;i++)
			if (fpow(x,(p-1)/(prime[i]),p)==1) return 0;
		return 1;
	}
	int find(int p)
	{
		int x=p-1;
		for (int i=2;i*i<=x;i++)
		{
			if (x%i==0)
			{
				prime[++cnt]=i;
				while (x%i==0) x/=i; 
			}
		}
		if (x!=1) prime[++cnt]=x;
		for (int i=2;;i++)
			if (check(i,p)) return i;
	}
}
namespace NTT
{
	const int Mod=1004535809,p=3;
	int n=1;
	void NTT(int *a,int inv)
	{
		int lim=0;
		while ((1<<lim)<n) lim++;
		for (int i=0;i<n;i++)
		{
			int t=0;
			for (int j=0;j<lim;j++)
				if ((i>>j) & 1) t|=1<<(lim-j-1);
			if (i<t) swap(a[i],a[t]);
		}
		for (int l=2;l<=n;l*=2)
		{
			int m=l/2,p0=fpow(inv?fpow(p,Mod-2,Mod):p,(Mod-1)/l,Mod);
			for (int *buf=a;buf!=a+n;buf+=l)
			{
				int pn=1;
				for (int i=0;i<m;i++)
				{
					int t=1LL*pn*buf[i+m]%Mod;
					buf[i+m]=(buf[i]-t+Mod)%Mod;
					buf[i]=(buf[i]+t)%Mod;
					pn=1LL*pn*p0%Mod;
				}
			}
		}
}
	void Union(int *a,int *c,int len)
	{
		while (n<2*len) n<<=1;
		for (int i=0;i<n;i++) b[i]=0;
		for (int i=0;i<len;i++) b[i]=c[i];
		NTT(a,0);NTT(b,0);
		for (int i=0;i<n;i++) a[i]=1LL*a[i]*b[i]%Mod;
		NTT(a,1);
		int invn=fpow(n,Mod-2,Mod);
		for (int i=0;i<n;i++) a[i]=1LL*a[i]*invn%Mod;
		for (int i=len-1;i<n;i++) a[i%(len-1)]=(a[i%(len-1)]+a[i])%Mod,a[i]=0;
	}
}
void init()
{
	int t=GetRoot::find(m);
	for (int i=0,k=1;i<m-1;i++,k=1LL*k*t%m) g[k]=i;
	x=g[x];
	for (int i=1;i<=lens;i++)
		if (a[i]) f[g[a[i]]]++;	 //若a[i]=0.就直接舍弃。
}
void solve() //倍增优化
{
	int k=n;
	ans[0]=1;
	while (k)
	{
		if (k&1) NTT::Union(ans,f,m);
		NTT::Union(f,f,m);
		k>>=1;
	}
	printf("%d\n",ans[x]);
}
int main()
{
	scanf("%d%d%d%d",&n,&m,&x,&lens);
	for (int i=1;i<=lens;i++) scanf("%d",&a[i]);
	init();
	solve();
	return 0;
}
posted @ 2019-01-12 22:55  Starryskies  阅读(149)  评论(0编辑  收藏  举报