【BZOJ】3992: [SDOI2015]序列统计 NTT+生成函数

【题意】给定一个[0,m-1]范围内的数字集合S,从中选择n个数字(可重复)构成序列。给定x,求序列所有数字乘积%m后为x的序列方案数%1004535809。1<=n<=10^9,3<=m<=8000,m为素数,1<=x<=m-1。(个人认为题意修改错误)

【算法】NTT+生成函数+离散对数+快速幂

【题解】由Πai=x(%m),可得Σlog ai=log x(%(m-1)),其中log以m的原根g为底。

所以通过将集合S和x对m取离散对数,将乘积转化为和,从而方便生成函数运算。

定义,信息为数字和,选择项为数字个数。

对于1个数字,若转化后的S中存在x,则f(x)=1,否则f(x)=0。

那么ans=f^n(x),使用以NTT为乘法运算的快速幂即可。

注意:

1.每次NTT后,将>=m-1的部分叠加到%(m-1)的位置。

2.每次dft会改变原数组,所以要提前复制一份。

3.若集合S中有数字0,无视。

#include<cstdio>
#include<algorithm>
#include<cmath>
using namespace std;
const int maxn=30010,MOD=1004535809;
int a[maxn],b[maxn],tot,n,X,p,K,logs[maxn],f[maxn],ans[maxn],c[maxn];
void exgcd(int a,int b,int &x,int &y){
    if(!b){x=1;y=0;}
    else{exgcd(b,a%b,y,x);y-=x*(a/b);}
}
int inv(int a){int x,y;exgcd(a,MOD,x,y);return (x%MOD+MOD)%MOD;}
int power(int x,int k,int p){
    int ans=1;
    while(k){
        if(k&1)ans=1ll*ans*x%p;
        x=1ll*x*x%p;
        k>>=1;
    }
    return ans;
}
namespace ntt{
    int o[maxn],oi[maxn];
    void init(int n){
        int g=3,x=power(g,(MOD-1)/n,MOD);
        for(int i=0;i<n;i++){
            o[i]=(i==0?1:1ll*o[i-1]*x%MOD);
            oi[i]=inv(o[i]);
        }
    }
    void transform(int *a,int n,int *o){
        int k=0;
        while((1<<k)<n)k++;
        for(int i=0;i<n;i++){
            int t=0;
            for(int j=0;j<k;j++)if((1<<j)&i)t|=(1<<(k-j-1));
            if(i<t)swap(a[i],a[t]);
        }
        for(int l=2;l<=n;l*=2){
            int m=l>>1;
            for(int *p=a;p!=a+n;p+=l){
                for(int i=0;i<m;i++){
                    int t=1ll*o[n/l*i]*p[i+m]%MOD;
                    p[i+m]=(p[i]-t+MOD)%MOD;
                    p[i]=(p[i]+t)%MOD;
                }
            }
        }
    }
    void dft(int *a,int n){transform(a,n,o);}
    void idft(int *a,int n){
        transform(a,n,oi);
        int nn=inv(n);
        for(int i=0;i<n;i++)a[i]=1ll*a[i]*nn%MOD;
    }
    void multply(int *a,int *b,int n){
        for(int i=0;i<n;i++)c[i]=b[i];
        dft(a,n);dft(c,n);
        for(int i=0;i<n;i++)a[i]=1ll*a[i]*c[i]%MOD;
        idft(a,n);
        for(int i=p-1;i<n;i++)if(a[i])a[i%(p-1)]=(a[i%(p-1)]+a[i])%MOD,a[i]=0;
    }
}
int find(int p){
    int sq=(int)(sqrt(p)+0.5),P=p-1;
    for(int i=2;i<=sq;i++)if(P!=1){
        if(P%i==0){
            b[++tot]=i;
            while(P%i==0)P/=i;
        }
    }
    if(P!=1)b[++tot]=P;
    for(int i=2;i<=p;i++){
        bool ok=1;
        for(int j=1;j<=tot;j++)if(power(i,(p-1)/b[j],p)==1){ok=0;break;}
        if(ok)return i;
    }
    return 0;
}
void pre_log(){
    int g=find(p),x=1;
    for(int i=0;i<p-1;i++){
        logs[x]=i;
        x=1ll*x*g%p;
    }
}
void POWER(){
    int N=1;
    while(N<p+p-2)N*=2;
    ntt::init(N);
    ans[0]=1;
    while(K){
        if(K&1)ntt::multply(ans,f,N);
        ntt::multply(f,f,N);
        K>>=1;
    }
}
int main(){
    scanf("%d%d%d%d",&K,&p,&X,&n);
    pre_log();
    int x;
    for(int i=1;i<=n;i++){
        scanf("%d",&x);
        if(!x)continue;
        f[logs[x]]=1;
    }
    POWER();
    printf("%d",ans[logs[X]]);
    return 0;
}
View Code

 

posted @ 2018-02-22 19:56  ONION_CYC  阅读(336)  评论(0编辑  收藏  举报