BZOJ 3992 [SDOI2015]序列统计
数列长度达到了 $10^9$，转移矩阵的边长 m 达到了 8000，除了 FFT 还能怎么写？？！！
当然，这题要对 1004535809 取模（它恰好是 NTT 模数 $479\cdot 2^{21}+1$），浮点 FFT 不适用，必须用 NTT。
同时由于统计的是乘积，所以取 m 的原根 g，把每个非零数换成它的离散对数，乘法就变成了指数上模 m-1 的加法；每次 NTT 卷积完，要把下标超出 m-1 的后半部分加回到前面去（即做长度为 m-1 的循环卷积）。
注意：x 不会为 0，因此 S 集合中一旦出现 0，直接删掉即可——原根的幂取不到 0（0 没有离散对数），原根表示不了 0。
// BZOJ 3992 [SDOI2015] sequence counting.
//
// Input:  n m x |S|, followed by the |S| elements of S (values in [0, m-1]).
// Output: the number of length-n sequences over S whose product is congruent
//         to x modulo the prime m, taken modulo 1004535809.
//
// Method: with a primitive root g of m, every nonzero residue v equals g^e for
// a unique exponent e in [0, m-2], so multiplying residues becomes adding
// exponents modulo m-1. The answer is the coefficient at log_g(x) of the n-th
// power of S's exponent-indicator vector under cyclic convolution of length
// m-1, computed by binary exponentiation with NTT-based multiplication.
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <string>
#include <cmath>
#include <ctime>
#include <algorithm>
#include <map>
#include <set>
#include <queue>
#include <iomanip>
#include <vector>

typedef long long ll;

// NTT-friendly prime 479 * 2^21 + 1: transform lengths up to 2^21 are supported.
const ll MOD = 1004535809;
// Primitive root of MOD used to build the NTT roots of unity.
const ll ROOT = 3;

// Reads the next (optionally negative) decimal integer from stdin.
// Returns 0 if end of input is reached before any digit (instead of looping forever).
ll read_int() {
    int ch = std::getchar();
    ll sign = 1;
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') sign = -1;
        ch = std::getchar();
    }
    ll value = 0;
    while (ch >= '0' && ch <= '9') {
        value = value * 10 + (ch - '0');
        ch = std::getchar();
    }
    return value * sign;
}

// Computes base^exp mod modulus by binary exponentiation (exp >= 0, modulus >= 1).
ll pow_mod(ll base, ll exp, ll modulus) {
    ll result = 1 % modulus;
    base %= modulus;
    while (exp > 0) {
        if (exp & 1) result = result * base % modulus;
        base = base * base % modulus;
        exp >>= 1;
    }
    return result;
}

// Returns the smallest primitive root of the prime p: the smallest g such that
// g^((p-1)/q) != 1 (mod p) for every prime factor q of p-1.
ll primitive_root(ll p) {
    const ll phi = p - 1;
    std::vector<ll> factors;
    ll rest = phi;
    for (ll d = 2; d * d <= rest; ++d) {
        if (rest % d == 0) {
            factors.push_back(d);
            while (rest % d == 0) rest /= d;
        }
    }
    if (rest > 1) factors.push_back(rest);
    // Starting at 1 makes p == 2 (phi == 1, no prime factors) return 1;
    // for p > 2, g == 1 always fails the test below.
    for (ll g = 1; g <= phi; ++g) {
        bool is_root = true;
        for (std::size_t i = 0; i < factors.size() && is_root; ++i) {
            if (pow_mod(g, phi / factors[i], p) == 1) is_root = false;
        }
        if (is_root) return g;
    }
    return 0;  // not reached for prime p
}

// In-place iterative number-theoretic transform modulo MOD.
// a.size() must be a power of two not exceeding 2^21, entries in [0, MOD).
// invert == true computes the inverse transform, including the 1/size scaling.
void ntt(std::vector<ll>& a, bool invert) {
    const int n = static_cast<int>(a.size());
    // Bit-reversal permutation.
    for (int i = 1, j = 0; i < n; ++i) {
        int bit = n >> 1;
        for (; j & bit; bit >>= 1) j ^= bit;
        j ^= bit;
        if (i < j) std::swap(a[i], a[j]);
    }
    for (int len = 2; len <= n; len <<= 1) {
        ll step = pow_mod(ROOT, (MOD - 1) / len, MOD);
        if (invert) step = pow_mod(step, MOD - 2, MOD);
        const int half = len >> 1;
        for (int start = 0; start < n; start += len) {
            ll w = 1;
            for (int k = 0; k < half; ++k) {
                const ll x = a[start + k];
                const ll y = a[start + k + half] * w % MOD;
                a[start + k] = (x + y) % MOD;
                a[start + k + half] = (x - y + MOD) % MOD;
                w = w * step % MOD;
            }
        }
    }
    if (invert) {
        const ll inv_n = pow_mod(n, MOD - 2, MOD);
        for (int i = 0; i < n; ++i) a[i] = a[i] * inv_n % MOD;
    }
}

// Cyclic convolution modulo MOD of two vectors of the same length k >= 1:
// result[t] = sum of a[i] * b[j] over all i, j with (i + j) mod k == t.
// Entries must lie in [0, MOD); requires 2k - 1 <= 2^21.
std::vector<ll> cyclic_multiply(const std::vector<ll>& a, const std::vector<ll>& b) {
    const std::size_t k = a.size();
    std::size_t size = 1;
    while (size < 2 * k - 1) size <<= 1;

    std::vector<ll> fa(a);
    fa.resize(size, 0);
    ntt(fa, false);
    if (&a == &b) {
        // Squaring: a single forward transform is enough.
        for (std::size_t i = 0; i < size; ++i) fa[i] = fa[i] * fa[i] % MOD;
    } else {
        std::vector<ll> fb(b);
        fb.resize(size, 0);
        ntt(fb, false);
        for (std::size_t i = 0; i < size; ++i) fa[i] = fa[i] * fb[i] % MOD;
    }
    ntt(fa, true);

    // The linear product has indices 0..2k-2; wrap indices >= k back by k.
    for (std::size_t i = k; i < 2 * k - 1; ++i) {
        fa[i - k] = (fa[i - k] + fa[i]) % MOD;
    }
    fa.resize(k);
    return fa;
}

int main() {
    // Local-testing convenience kept from the original: use dealing.in /
    // dealing.out when dealing.in exists. Probing first matters because a
    // failed freopen leaves stdin closed, so a judge feeding stdin would
    // otherwise see no input at all.
    if (std::FILE* probe = std::fopen("dealing.in", "r")) {
        std::fclose(probe);
        std::freopen("dealing.in", "r", stdin);
        std::freopen("dealing.out", "w", stdout);
    }

    ll n = read_int();
    const ll m = read_int();
    const ll x = read_int();
    const ll count = read_int();
    std::vector<ll> values(static_cast<std::size_t>(count));
    for (std::size_t i = 0; i < values.size(); ++i) values[i] = read_int();

    if (x == 0) {
        // m is prime, so a product is 0 mod m exactly when some factor is 0:
        // answer = |S|^n - |S without 0|^n.
        ll nonzero = 0;
        for (std::size_t i = 0; i < values.size(); ++i) {
            if (values[i] != 0) ++nonzero;
        }
        const ll answer =
            (pow_mod(count % MOD, n, MOD) - pow_mod(nonzero % MOD, n, MOD) + MOD) % MOD;
        std::printf("%lld\n", answer);
        return 0;
    }

    const ll k = m - 1;  // order of the multiplicative group modulo m
    const ll g = primitive_root(m);
    // log_of[v] = e with g^e == v (mod m), for v in [1, m-1]; exponents in [0, k-1].
    std::vector<ll> log_of(static_cast<std::size_t>(m), 0);
    ll power = 1;
    for (ll e = 0; e < k; ++e) {
        log_of[power] = e;
        power = power * g % m;
    }

    // Exponent-space indicator of S. 0 has no discrete log and, since x != 0,
    // can never occur in a counted sequence, so it is dropped.
    std::vector<ll> base(static_cast<std::size_t>(k), 0);
    for (std::size_t i = 0; i < values.size(); ++i) {
        if (values[i] != 0) ++base[log_of[values[i]]];
    }

    // result = base^n under cyclic convolution, starting from the identity
    // (all weight on exponent 0, i.e. the empty product 1).
    std::vector<ll> result(static_cast<std::size_t>(k), 0);
    result[0] = 1;
    while (n > 0) {
        if (n & 1) result = cyclic_multiply(result, base);
        n >>= 1;
        // Skip the final squaring: its result would never be used.
        if (n > 0) base = cyclic_multiply(base, base);
    }
    std::printf("%lld\n", result[log_of[x]]);
    return 0;
}