P3321 [SDOI2015]序列统计
https://www.luogu.com.cn/problem/P3321
暴力 dp 的话,就是 \(f_{i,x}\) 表示填了前 \(i\) 个数,乘积为 \(x\) 的有多少种,那么 \(f_{i,x}\rightarrow f_{i+1,x\cdot S_k}\)
发现如果把后面下标里的那个乘改成加就是普通的循环卷积
而原根有性质 \(a^x\) 可以取遍 \(0,1,\cdots,p-1\) 的所有数(其中 \(a\) 为 \(p\) 的原根,\(x\in [0,p-2]\))
所以只要找到 \(m\) 的原根,把它当作底数来在膜意义下取 \(\log\) 即可变成加号
最小原根是不会大于 \(p^{0.25}\) 级别的,所以直接从小到大枚举并判断
设 \(k\) 是 \(p-1\) 的一个质因数,那么若对于任意的 \(k\),都有 \(a^{\frac{p-1}{k}}\not \equiv 1\bmod p\),那么 \(a\) 就是 \(p\) 的一个原根
那么每个 \(S_i\) 就可以为产生的那个多项式的第 \(\log S_i\) 项加一
因为指数需要模 \(\varphi(m)\) 也就是 \(m-1\),所以产生的这个多项式是 \(m-1\) 次的
那么对这个多项式做 \(n\) 次方即可
因为是循环卷积,所以需要像普通快速幂那样做一个 \(O(m\log m\log n)\) 的东西,而不能写那种先取 \(\log\) 再 \(\operatorname{exp}\) 的一 \(\log\) 做法
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<vector>
#include<assert.h>
#define getChar getchar()
inline int read(){
register int x=0,y=1;register char c=getChar;
while(c<'0'||c>'9') y&=(c!='-'),c=getChar;
while(c>='0'&&c<='9') x=x*10+(c^48),c=getChar;
return y?x:-x;
}
#define N 8006
#define G 3
#define MOD 1004535809
inline long long power(long long a,long long b,int mod=MOD){
long long ans=1;
while(b){
if(b&1) ans=ans*a%mod;
b>>=1;a=a*a%mod;
}
return ans;
}
inline int getFac(int n,int *fac){
int o=0;
for(int i=2;i*i<=n;i++)if(!(n%i)){
fac[++o]=i;
while(!(n%i)) n/=i;
}
if(n>1) fac[++o]=n;
return o;
}
inline int getRoot(int p){
static int fac[N];
int n=getFac(p-1,fac);
for(int i=2;;i++){
for(int j=1;j<=n;j++)if(power(i,(p-1)/fac[j],p)==1) goto NEX;
return i;
NEX:;
}
return -1;
}
int rev[N*4];
inline int init(int n){
int max=1;while(max<n) max<<=1;
for(int i=0;i<max;i++) rev[i]=rev[i>>1]>>1,rev[i]|=(i&1)?(max>>1):0;
return max;
}
inline void ntt(int n,long long *a,int type){
for(int i=0;i<n;i++)if(rev[i]<i) std::swap(a[i],a[rev[i]]);
for(int h=1;h<n;h<<=1){
long long gn=power(G,(MOD-1)/(h<<1)),g,o;
if(!type) gn=power(gn,MOD-2);
for(int i=0,j;i<n;i+=h<<1){
for(g=1,j=i;j<i+h;j++,g=g*gn%MOD){
o=g*a[j+h]%MOD;
a[j+h]=(a[j]-o+MOD)%MOD;a[j]=(a[j]+o)%MOD;
}
}
}
if(!type){
long long inv=power(n,MOD-2);
for(int i=0;i<n;i++) a[i]=a[i]*inv%MOD;
}
}
inline void shrink(int len,int n,long long *f){
for(int i=0;i<n;i++) f[i]=(f[i]+f[i+n])%MOD;
std::memset(f+n,0,(len-n)*sizeof f[0]);
}
inline void mul(int len,int n,long long *f,long long *g=NULL){
ntt(len,f,1);
if(g) ntt(len,g,1);
if(g) for(int i=0;i<len;i++) f[i]=f[i]*g[i]%MOD;
else for(int i=0;i<len;i++) f[i]=f[i]*f[i]%MOD;
ntt(len,f,0);shrink(len,n,f);
if(g) ntt(len,g,0);
}
inline void power(int n,long long *f,int b){
int len=init(n*2);
static long long ans[N*4];
std::memset(ans,0,len*sizeof ans[0]);ans[0]=1;
while(b){
if(b&1) mul(len,n,ans,f);
mul(len,n,f);b>>=1;
}
std::memcpy(f,ans,sizeof ans);
}
int _log[N];
int main(){
int m=read(),n=read(),x=read(),s=read();
static int a[N];
for(int i=1;i<=s;i++) a[i]=read();
int root=getRoot(n);
for(int i=0,x=1;i<n-1;i++,x=(long long)x*root%n) _log[x]=i;
static long long f[N*4];
for(int i=1;i<=s;i++)if(a[i]) f[_log[a[i]]]++;
power(n-1,f,m);
printf("%lld\n",f[_log[x]]);
return 0;
}