先考虑最基础的dp

dp[i][j]表示

第i项的前缀积为j的方案数

然后用FFT+快速幂优化

先求出m的原根 g

每个数字都表示成g的k次

每次都乘上这个多项式

用快速幂优化就可以了

复杂度O(nlog2n)

代码:

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define Rep(i,x,y) for(int i=x;i<y;++i)
#define For(i,x,y) for(int i=x;i<=y;++i)
#define Forn(i,x,y) for(int i=x;i>=y;--i)
using namespace std;
const int N = 100005;
const int P = 1004535809;
int n,m,x,S,g;
int s[N];
int w0[N],w1[N];
int L,invL;
int id[N];
int pw(int x,int y){
    int p=1;for(;y;y>>=1,x=1ll*x*x%P) if(y&1) p=1ll*p*x%P;return p;
}
int inv(int x){
    return pw(x,P-2);
}
void init_w(int n){
    w0[0]=w1[0]=1;
    w0[1]=pw(3,(P-1)/n);
    For(i,2,n) w0[i]=1ll*w0[i-1]*w0[1]%P;
    w1[1]=inv(w0[1]);
    For(i,2,n) w1[i]=1ll*w1[i-1]*w1[1]%P;
}
int tmp[N];
int res[N],a[N];
void ntt(int n,int*buf,int beg,int step,int*w){
    if(n==1) return;
    int m=n>>1;
    ntt(m,buf,beg,step<<1,w);
    ntt(m,buf,beg+step,step<<1,w);
    Rep(i,0,m){
        int pos=i*step*2;
        tmp[i]  =(buf[beg+pos]+1ll*w[i*step]*buf[beg+pos+step]%P)%P;
        tmp[i+m]=(buf[beg+pos]-1ll*w[i*step]*buf[beg+pos+step]%P+P)%P;
    }
    Rep(i,0,n) buf[beg+i*step]=tmp[i];
}
void mul(int*a,int*b){
    ntt(L,a,0,1,w0);
    ntt(L,b,0,1,w0);
    Rep(i,0,L) a[i]=1ll*a[i]*b[i]%P;
    ntt(L,a,0,1,w1);
    ntt(L,b,0,1,w1);
    Rep(i,0,L) a[i]=1ll*a[i]*invL%P;
    Rep(i,0,L) b[i]=1ll*b[i]*invL%P;
    For(i,0,L) if(i>=m-1) {a[i%(m-1)]=(a[i%(m-1)]+a[i])%P;a[i]=0;}
}
void sqr(int*a){
    ntt(L,a,0,1,w0);
    Rep(i,0,L) a[i]=1ll*a[i]*a[i]%P;
    ntt(L,a,0,1,w1);
    Rep(i,0,L) a[i]=1ll*a[i]*invL%P;
    For(i,0,L) if(i>=m-1) {a[i%(m-1)]=(a[i%(m-1)]+a[i])%P;a[i]=0;}
}
bool vis[8001];
void getG(){
    for(g=2;g<m;++g){
        memset(vis,0,sizeof(vis));
        int d=1;bool flg=1;
        Rep(i,0,m-1){
            if(vis[d]){flg=0;break;}
            vis[d]=1;
            d=d*g%m;
        }
        if(flg){
            int d=1;
            Rep(i,0,m-1){
                id[d]=i;
                d=d*g%m;
            }
            break;
        }
    }
}
void work(int y){
    for(;y;y>>=1){
        if(y&1) mul(res,a); 
        sqr(a);
    }
}
int main(){
    scanf("%d%d%d%d",&n,&m,&x,&S);
    For(i,1,S) scanf("%d",s+i); 
    for(L=1;L<=m+m;L<<=1);
    getG();
    invL=inv(L);
    init_w(L);
    res[id[1]]=1;
    For(i,1,S) if(s[i]) a[id[s[i]]]++;
    work(n);
    printf("%d\n",res[id[x]]);
    return 0;
}

 

 posted on 2017-05-12 20:52  rwy  阅读(188)  评论(0编辑  收藏  举报