SDOI2015 序列统计
Description
小C有一个集合\(S\),里面的元素都是小于\(M\)的非负整数。他用程序编写了一个数列生成器,可以生成一个长度为\(N\)的数列,
数列中的每个数都属于集合\(S\)。小C用这个生成器生成了许多这样的数列。但是小C有一个问题需要你的帮助:
给定整数\(x\),求所有可以生成出的,且满足数列中所有数的乘积\(\mod M\)的值等于\(x\)的不同的数列的有多少个。
小C认为,两个数列\(\{A_i\}\)和\(\{B_i\}\)不同,当且仅当至少存在一个整数\(i\),满足\(A_i\neq B_i\)。
另外,小C认为这个问题的答案可能很大,因此他只需要你帮助他求出答案\(\mod 1004535809\)的值就可以了。
Input
一行,四个整数,\(N、M、x、|S|\),其中\(|S|\)为集合\(S\)中元素个数。
第二行,\(|S|\)个整数,表示集合\(S\)中的所有元素。
\(1 \leq N \leq 10^9,3 \leq M \leq 8000\),M为质数
\(0 \leq x \leq M-1\),输入数据保证集合S中元素不重复\(x \in [1,m-1]\)
集合中的数$ \in [0,m-1]$
Output
一行,一个整数,表示你求出的种类数\(\mod 1004535809\)的值。
Solution
看到这题。。首先很容易列出一个DP转移方程
令\(F_{i,j}\)表示选了\(i\)个数字,当前乘积为\(j\)的种类数
我们发现它非常不优美,复杂度高达$ O (n * m^2) $
我们发现这个式子可以倍增。。于是很轻松的干掉一个n,它的复杂度变成了$ O(\log n * m^2) $
这貌似还是有点多。。考虑如何干掉一个 $ m $
咦。。这个模数貌似有点熟悉。。考虑NTT
不过这是乘法。。我们做不了NTT 。。。
考虑原根
设\(p\)为\(m\)的原根。。那么\(p\)的幂次可以表示出\([1,m)\)的所有数字————原根定义
于是DP方程变成了这样
注意。。此时\(F_{i,j}\)表示选到第\(i\)个数,大小为\(p^j\)次的方案数
再一变
我们发现这玩意长得像个卷积。。可以用NTT了
于是复杂度变成了 $ O(m \log n log m)$
Code
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
int n,m,x,lens,g[2000000],a[2000010],b[2000010];
int f[2000010],ans[2000010];
int fpow(int x,int k,int Mod)
{
int ans=1;
while (k)
{
if (k&1) ans=1LL*ans*x%Mod;
x=1LL*x*x%Mod;
k>>=1;
}
return ans;
}
namespace GetRoot //求原根
{
int prime[1000000],cnt;
bool check(int x,int p)
{
for (int i=1;i<=cnt;i++)
if (fpow(x,(p-1)/(prime[i]),p)==1) return 0;
return 1;
}
int find(int p)
{
int x=p-1;
for (int i=2;i*i<=x;i++)
{
if (x%i==0)
{
prime[++cnt]=i;
while (x%i==0) x/=i;
}
}
if (x!=1) prime[++cnt]=x;
for (int i=2;;i++)
if (check(i,p)) return i;
}
}
namespace NTT
{
const int Mod=1004535809,p=3;
int n=1;
void NTT(int *a,int inv)
{
int lim=0;
while ((1<<lim)<n) lim++;
for (int i=0;i<n;i++)
{
int t=0;
for (int j=0;j<lim;j++)
if ((i>>j) & 1) t|=1<<(lim-j-1);
if (i<t) swap(a[i],a[t]);
}
for (int l=2;l<=n;l*=2)
{
int m=l/2,p0=fpow(inv?fpow(p,Mod-2,Mod):p,(Mod-1)/l,Mod);
for (int *buf=a;buf!=a+n;buf+=l)
{
int pn=1;
for (int i=0;i<m;i++)
{
int t=1LL*pn*buf[i+m]%Mod;
buf[i+m]=(buf[i]-t+Mod)%Mod;
buf[i]=(buf[i]+t)%Mod;
pn=1LL*pn*p0%Mod;
}
}
}
}
void Union(int *a,int *c,int len)
{
while (n<2*len) n<<=1;
for (int i=0;i<n;i++) b[i]=0;
for (int i=0;i<len;i++) b[i]=c[i];
NTT(a,0);NTT(b,0);
for (int i=0;i<n;i++) a[i]=1LL*a[i]*b[i]%Mod;
NTT(a,1);
int invn=fpow(n,Mod-2,Mod);
for (int i=0;i<n;i++) a[i]=1LL*a[i]*invn%Mod;
for (int i=len-1;i<n;i++) a[i%(len-1)]=(a[i%(len-1)]+a[i])%Mod,a[i]=0;
}
}
void init()
{
int t=GetRoot::find(m);
for (int i=0,k=1;i<m-1;i++,k=1LL*k*t%m) g[k]=i;
x=g[x];
for (int i=1;i<=lens;i++)
if (a[i]) f[g[a[i]]]++; //若a[i]=0.就直接舍弃。
}
void solve() //倍增优化
{
int k=n;
ans[0]=1;
while (k)
{
if (k&1) NTT::Union(ans,f,m);
NTT::Union(f,f,m);
k>>=1;
}
printf("%d\n",ans[x]);
}
int main()
{
scanf("%d%d%d%d",&n,&m,&x,&lens);
for (int i=1;i<=lens;i++) scanf("%d",&a[i]);
init();
solve();
return 0;
}