#NTT,原根#洛谷 3321 JZOJ 4051 [SDOI2015]序列统计

题目


分析

首先朴素dp方程
\(dp[i][j]\)表示\(i\)个数的数列乘积为\(j\)的方案
那么\(dp[i][j*a[k]\bmod m]=itself+dp[i-1][j]\)
这可以用矩阵乘法优化到\(O(m^3log_2n)\),然而考场真的不想写
其实这个方程不明显,考虑到\(n\)超级大,不是矩阵乘法就是快速幂(推测)
能不能用2的多少次方拼凑出长度为\(n\)的数列
刚刚的方程可以变为$$dp[2i][t]=\sum_{xy\bmod m==t}dp[i][x]+dp[i][y]$$
那乘法怎么办呢,用原根把乘法变成加法,那就可以用NTT做了
至于原根我不好解释,但是原根会比较小,
具体做法就是\(g^k\neq 1(k|p-1)\)那么\(g\)就是\(p\)的一个原根(准确来说应该是\(\varphi(p)\),但是这里是质数)


代码

#include <cstdio>
#include <cctype>
#include <algorithm>
#define rr register
using namespace std;
const int mod=1004535809,invG=334845270,N=8011;
int n,r[N<<2],inv1,inv2,m,G,F[N<<2],Ans[N<<2],fi[N],KsM,S,X;
inline signed iut(){
	rr int ans=0; rr char c=getchar();
	while (!isdigit(c)) c=getchar();
	while (isdigit(c)) ans=(ans<<3)+(ans<<1)+(c^48),c=getchar();
	return ans;
}
inline void print(int ans){
	if (ans>9) print(ans/10);
	putchar(ans%10+48);
}
inline signed ksm(int x,int y,int p){
	rr int ans=1;
	for (;y;y>>=1,x=1ll*x*x%p)
	    if (y&1) ans=1ll*ans*x%p;
	return ans;
}
inline signed Getroot(int p){
    for (rr int g=2;;++g){
        rr bool flag=1;
        for (rr int i=2;i*i<p;++i)
        if ((p-1)%i==0&&(ksm(g,i,p)==1||ksm(g,(p-1)/i,p)==1)) flag=0;
        if (flag) return g;
    }
}
inline signed mo(int x,int y){return x+y>=mod?x+y-mod:x+y;}
inline void ntt(int *f,int op){
	for (rr int i=0;i<n;++i)
    if (i<r[i]) swap(f[i],f[r[i]]);
    for (rr int p=2;p<=n;p<<=1){
    	rr int len=p>>1,w=ksm(op==1?3:invG,(mod-1)/p,mod);
    	for (rr int i=0;i<n;i+=p)
    		for (rr int j=i,t=1;j<i+len;++j,t=1ll*t*w%mod){
    			rr int z=1ll*f[len+j]*t%mod;
                f[len+j]=mo(f[j],mod-z),f[j]=mo(f[j],z);
			}
	}
}
inline void cheng(int *c,int *a,int *b){
    rr int f[N<<2],g[N<<2],C[N<<2];
    for (rr int i=0;i<n;++i) f[i]=a[i],g[i]=b[i];
    ntt(f,1),ntt(g,1); for (rr int i=0;i<n;++i) f[i]=1ll*f[i]*g[i]%mod;
	ntt(f,-1); for (rr int i=0;i<n;++i) f[i]=1ll*f[i]*inv1%mod;
    for (rr int i=0;i<m-1;++i) C[i]=mo(f[i],f[i+m-1]);
    for (rr int i=0;i<n;++i) c[i]=C[i];
}
signed main(){
    KsM=iut(),m=iut(),X=iut(),S=iut(),G=Getroot(m),fi[1]=0;
    for (rr int i=1,t=1;i<m-1;++i) fi[t=t*G%m]=i;
    for (rr int i=1;i<=S;++i){
        rr int t=iut();
        if (t) ++F[fi[t]];
    }
    Ans[0]=1;
	for (n=1;n<=2*m-2;n<<=1); inv1=ksm(n,mod-2,mod);
	for (rr int i=0;i<n;++i) r[i]=(r[i>>1]>>1)|((i&1)?n>>1:0);
    for (;KsM;KsM>>=1,cheng(F,F,F)) if (KsM&1) cheng(Ans,Ans,F);
    printf("%d",Ans[fi[X]]);
	return 0;
}
posted @ 2020-02-14 17:23  lemondinosaur  阅读(103)  评论(0编辑  收藏  举报