BZOJ 3992 序列统计
Description
小C有一个集合\(S\),里面的元素都是小于\(M\)的非负整数。他用程序编写了一个数列生成器,可以生成一个长度为\(N\)的数列,数列中的每个数都属于集合\(S\)。
小C用这个生成器生成了许多这样的数列。但是小C有一个问题需要你的帮助:给定整数\(x\),求所有可以生成出的,且满足数列中所有数的乘积\(mod\;M\)的值等于\(x\)的不同的数列的有多少个。小C认为,两个数列\(\lbrace A_{i} \rbrace\)和\(\lbrace B_{i} \rbrace\)不同,当且仅当至少存在一个整数\(i\),满足\(A_{i} \ne B_{i}\)。另外,小C认为这个问题的答案可能很大,因此他只需要你帮助他求出答案\(mod\;1004535809\)的值就可以了。
Input
一行,四个整数,\(N,M,x,\mid S \mid\),其中\(\mid S \mid\)为集合\(S\)中元素个数。第二行,\(\mid S \mid\)个整数,表示集合\(S\)中的所有元素。
Output
一行,一个整数,表示你求出的种类数\(mod\;1004535809\)的值。
Sample Input
4 3 1 2
1 2
Sample Output
8
HINT
对于\(10\%\)的数据,\(1 \le N \le 1000\);
对于\(30\%\)的数据,\(3 \le M \le 100\);
对于\(60\%\)的数据,\(3 \le M \le 800\);
对于全部的数据,\(1 \le N \le 10^{9}\),\(3 \le M \le 8000\),\(M\)为质数,\(1 \le x \le M-1\),输入数据保证集合\(S\)中元素不重复。
这题让我知道了原根有什么用。
由于\(M\)是质数,所以\(M\)一定有原根\(G\)。我们只要知道\(1 \sim M-1\)在\(mod \; M\)意义下\(G\)的离散对数就可以把乘法化成了加法。
有了这个之后,我们就可以得到一个生成多项式。观察模数\(1004535809 = 2^{21} \times 479+1\),所以可以使用NTT(快速数论变换,就是把FFT的单位复根变成\(1004535809\)的原根)+快速幂。(把这个题目告诉rhl,rhl直接秒,我太弱了TAT)
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cstdlib>
using namespace std;
#define gg (3)
#define rhl (1004535809)
#define maxm (16010)
typedef long long ll;
int n,m,x,S,G,tot,factor[maxm],pos[maxm],e[25],ine[25];
inline ll qsm(ll a,ll b,int c)
{
ll ret = 1;
for (;b;b >>= 1,(a *= a) %= c) if (b & 1) (ret *= a) %= c;
return ret;
}
struct node
{
int a[maxm*2],len;
inline node() { memset(a,0,sizeof(a)); }
inline void NTT(int loglen,int len,int on)
{
for (int i = 0,j,t,p;i < len;++i)
{
for (j = 0,t = i,p = 0;j < loglen;++j,t >>= 1) p <<= 1,p |= t & 1;
if (p < i) swap(a[p],a[i]);
}
for (int s = 1,k = 2;s <= loglen;++s,k <<= 1)
{
int wn; if (on) wn = e[s]; else wn = ine[s];
for (int i = 0;i < len;i += k)
{
int w = 1;
for (int j = 0;j < (k >> 1);++j,w = (ll)wn*w%rhl)
{
int u = a[i+j],v = (ll)w*a[i+j+(k>>1)]%rhl;
a[i+j] = u+v; if (a[i+j] >= rhl) a[i+j] -= rhl;
a[i+j+(k>>1)] = u-v; if (a[i+j+(k>>1)] < 0) a[i+j+(k>>1)] += rhl;
}
}
}
if (!on)
{
int inv = qsm(len,rhl-2,rhl);
for (int i = 0;i < len;++i) a[i] = (ll)a[i]*inv%rhl;
}
}
friend inline node operator *(node x,node y)
{
int loglen = 0,len;
for (;(1<<loglen)<x.len+y.len;++loglen); len = 1<<loglen;
x.NTT(loglen,len,1); y.NTT(loglen,len,1);
for (int i = 0;i < (1<<loglen);++i) x.a[i] = (ll)x.a[i]*y.a[i]%rhl;
x.NTT(loglen,len,0);
while (len&&(len >= m||!x.a[len-1]))
{
x.a[(len-1)%(m-1)] += x.a[len-1],x.a[--len] = 0;
if (x.a[len%(m-1)] >= rhl) x.a[len%(m-1)] -= rhl;
}
x.len = len;
return x;
}
}pa;
inline bool check(int g)
{
for (int i = 1;i <= tot;++i) if (qsm(g,(m-1)/factor[i],m) == 1) return false;
return true;
}
int main()
{
freopen("3992.in","r",stdin);
freopen("3992.out","w",stdout);
scanf("%d %d %d %d",&n,&m,&x,&S);
for (int i = 2,p = m-1;p > 1;++i)
if (!(p % i))
{
factor[++tot] = i;
while (!(p % i)) p /= i;
}
for (int i = 1;i < m;++i) if (check(i)) { G = i; break; }
for (int i = 0,now = 1;i < m-1;++i,(now *= G)%=m) pos[now] = i;
for (int i = 1;i < 20;++i) e[i] = qsm(gg,(rhl-1)>>i,rhl),ine[i] = qsm(e[i],rhl-2,rhl);
for (int i = 1,a;i <= S;++i) { scanf("%d",&a); if (a) pa.a[pos[a]]++; } pa.len = m-1;
node ans; ans.a[0] = 1; ans.len = 1;
for (;n;n >>= 1,pa = pa*pa)
if (n & 1) ans = ans*pa;
printf("%d",ans.a[pos[x]]);
fclose(stdin); fclose(stdout);
return 0;
}