【BZOJ】3992: [SDOI2015]序列统计 NTT+生成函数
【题意】给定一个[0,m-1]范围内的数字集合S,从中选择n个数字(可重复)构成序列。给定x,求序列所有数字乘积%m后为x的序列方案数%1004535809。1<=n<=10^9,3<=m<=8000,m为素数,1<=x<=m-1。(个人认为题意修改错误)
【算法】NTT+生成函数+离散对数+快速幂
【题解】由Πai=x(%m),可得Σlog ai=log x(%(m-1)),其中log以m的原根g为底。
所以通过将集合S和x对m取离散对数,将乘积转化为和,从而方便生成函数运算。
定义,信息为数字和,选择项为数字个数。
对于1个数字,若转化后的S中存在x,则f(x)=1,否则f(x)=0。
那么ans=f^n(x),使用以NTT为乘法运算的快速幂即可。
注意:
1.每次NTT后,将>=m-1的部分叠加到%(m-1)的位置。
2.每次dft会改变原数组,所以要提前复制一份。
3.若集合S中有数字0,无视。
#include<cstdio> #include<algorithm> #include<cmath> using namespace std; const int maxn=30010,MOD=1004535809; int a[maxn],b[maxn],tot,n,X,p,K,logs[maxn],f[maxn],ans[maxn],c[maxn]; void exgcd(int a,int b,int &x,int &y){ if(!b){x=1;y=0;} else{exgcd(b,a%b,y,x);y-=x*(a/b);} } int inv(int a){int x,y;exgcd(a,MOD,x,y);return (x%MOD+MOD)%MOD;} int power(int x,int k,int p){ int ans=1; while(k){ if(k&1)ans=1ll*ans*x%p; x=1ll*x*x%p; k>>=1; } return ans; } namespace ntt{ int o[maxn],oi[maxn]; void init(int n){ int g=3,x=power(g,(MOD-1)/n,MOD); for(int i=0;i<n;i++){ o[i]=(i==0?1:1ll*o[i-1]*x%MOD); oi[i]=inv(o[i]); } } void transform(int *a,int n,int *o){ int k=0; while((1<<k)<n)k++; for(int i=0;i<n;i++){ int t=0; for(int j=0;j<k;j++)if((1<<j)&i)t|=(1<<(k-j-1)); if(i<t)swap(a[i],a[t]); } for(int l=2;l<=n;l*=2){ int m=l>>1; for(int *p=a;p!=a+n;p+=l){ for(int i=0;i<m;i++){ int t=1ll*o[n/l*i]*p[i+m]%MOD; p[i+m]=(p[i]-t+MOD)%MOD; p[i]=(p[i]+t)%MOD; } } } } void dft(int *a,int n){transform(a,n,o);} void idft(int *a,int n){ transform(a,n,oi); int nn=inv(n); for(int i=0;i<n;i++)a[i]=1ll*a[i]*nn%MOD; } void multply(int *a,int *b,int n){ for(int i=0;i<n;i++)c[i]=b[i]; dft(a,n);dft(c,n); for(int i=0;i<n;i++)a[i]=1ll*a[i]*c[i]%MOD; idft(a,n); for(int i=p-1;i<n;i++)if(a[i])a[i%(p-1)]=(a[i%(p-1)]+a[i])%MOD,a[i]=0; } } int find(int p){ int sq=(int)(sqrt(p)+0.5),P=p-1; for(int i=2;i<=sq;i++)if(P!=1){ if(P%i==0){ b[++tot]=i; while(P%i==0)P/=i; } } if(P!=1)b[++tot]=P; for(int i=2;i<=p;i++){ bool ok=1; for(int j=1;j<=tot;j++)if(power(i,(p-1)/b[j],p)==1){ok=0;break;} if(ok)return i; } return 0; } void pre_log(){ int g=find(p),x=1; for(int i=0;i<p-1;i++){ logs[x]=i; x=1ll*x*g%p; } } void POWER(){ int N=1; while(N<p+p-2)N*=2; ntt::init(N); ans[0]=1; while(K){ if(K&1)ntt::multply(ans,f,N); ntt::multply(f,f,N); K>>=1; } } int main(){ scanf("%d%d%d%d",&K,&p,&X,&n); pre_log(); int x; for(int i=1;i<=n;i++){ scanf("%d",&x); if(!x)continue; f[logs[x]]=1; } POWER(); printf("%d",ans[logs[X]]); return 0; }