[BZOJ3992][SDOI2015]序列统计(DP+原根+NTT)
3992: [SDOI2015]序列统计
Time Limit: 30 Sec Memory Limit: 128 MB
Submit: 1888 Solved: 898
[Submit][Status][Discuss]Description
小C有一个集合S,里面的元素都是小于M的非负整数。他用程序编写了一个数列生成器,可以生成一个长度为N的数列,数列中的每个数都属于集合S。小C用这个生成器生成了许多这样的数列。但是小C有一个问题需要你的帮助:给定整数x,求所有可以生成出的,且满足数列中所有数的乘积mod M的值等于x的不同的数列的有多少个。小C认为,两个数列{Ai}和{Bi}不同,当且仅当至少存在一个整数i,满足Ai≠Bi。另外,小C认为这个问题的答案可能很大,因此他只需要你帮助他求出答案mod 1004535809的值就可以了。Input
一行,四个整数,N、M、x、|S|,其中|S|为集合S中元素个数。第二行,|S|个整数,表示集合S中的所有元素。1<=N<=10^9,3<=M<=8000,M为质数0<=x<=M-1,输入数据保证集合S中元素不重复x∈[1,m-1]集合中的数∈[0,m-1]
Output
一行,一个整数,表示你求出的种类数mod 1004535809的值。
Sample Input
4 3 1 2
1 2Sample Output
8
【样例说明】
可以生成的满足要求的不同的数列有(1,1,1,1)、(1,1,2,2)、(1,2,1,2)、(1,2,2,1)、
(2,1,1,2)、(2,1,2,1)、(2,2,1,1)、(2,2,2,2)HINT
Source
[Submit][Status][Discuss]
好久没有调过这么痛苦了。。
首先有一个简单的DP方程:$dp[i+j][(x*y)\%p]+=dp[i][x]*dp[j][y]$,dp[i][j]表示前i个数凑成余数为j的方案数。
$n^2$转移很简单,然后可以矩阵优化,但M的范围仍然过不了。
这时候就要敢往FFT方面去想。
现在这个方程之所以不能用FFT的原因在于x,y转移到的不是x+y而是x*y%mod,我们可以想到,乘法运算在作为指数的时候就是加法,又发现M是质数,于是我们考虑用原根的次幂替代S中的每个数。
设$g^{x'} \equiv x (mod\ p)$,$g^{y'} \equiv y (mod\ p)$,这样方程就变为$dp[i+j][(x'+y')\% \phi(p)]+=dp[i][x']*dp[j][y']$,这个就是标准的循环卷积了。
注意循环卷积次数界要再放大一倍!!还有S中要忽略0!!两个问题调到崩溃。
1 #include<cstdio> 2 #include<algorithm> 3 #define rep(i,l,r) for (int i=l; i<=r; i++) 4 typedef long long ll; 5 using namespace std; 6 7 const int N=20100,mod=1004535809,G=3; 8 int k,p,x,n,s,g,inv,cnt[N],f[N],pw[N],ind[N],rev[N],c[N]; 9 10 int ksm(int a,int b,int p){ 11 int res; 12 for (res=1; b; a=(1ll*a*a)%p,b>>=1) 13 if (b & 1) res=(1ll*res*a)%p; 14 return res; 15 } 16 17 bool chk(){ 18 for (int i=2; i*i<=p; i++) if ((p-1)%i==0 && ksm(g,(p-1)/i,p)==1) return 0; 19 return 1; 20 } 21 22 void getroot(){ 23 if (p==2) g=1; else for (g=2; !chk(); g++); 24 ind[1]=0; pw[0]=1; 25 for (int i=1; i<p-1; i++) pw[i]=pw[i-1]*g%p,ind[pw[i]]=i; 26 } 27 28 namespace NTT{ 29 int n,L,rev[N]; 30 void init(int m){ 31 n=1; L=0; 32 for (; n<=m; n<<=1) L++; 33 n<<=1; L++; inv=ksm(n,mod-2,mod); 34 for (int i=0; i<n; i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1)); 35 } 36 void DFT(int a[],int n,int f){ 37 for (int i=0; i<n; i++) if (i<rev[i]) swap(a[i],a[rev[i]]); 38 for (int i=1; i<n; i<<=1){ 39 int wn=ksm(G,(f==1) ? (mod-1)/(i<<1) : (mod-1)-(mod-1)/(i<<1),mod); 40 for (int p=i<<1,j=0; j<n; j+=p){ 41 int w=1; 42 for (int k=0; k<i; k++,w=1ll*w*wn%mod){ 43 int x=a[j+k],y=1ll*w*a[i+j+k]%mod; 44 a[j+k]=(x+y)%mod; a[i+j+k]=(x-y+mod)%mod; 45 } 46 } 47 } 48 if (f==1) return; 49 for (int i=0; i<n; i++) a[i]=1ll*a[i]*inv%mod; 50 } 51 void mul(int a[],int b[]){ 52 for (int i=0; i<n; i++) c[i]=b[i]; 53 DFT(a,n,1); DFT(c,n,1); 54 for (int i=0; i<n; i++) a[i]=1ll*a[i]*c[i]%mod; 55 DFT(a,n,-1); 56 for (int i=n-1; i>=p-1; i--) a[i-p+1]=(a[i-p+1]+a[i])%mod,a[i]=0; 57 } 58 }; 59 60 int main(){ 61 freopen("bzoj3992.in","r",stdin); 62 freopen("bzoj3992.out","w",stdout); 63 scanf("%d%d%d%d",&k,&p,&x,&n); getroot(); NTT::init(p); 64 rep(i,1,n) { scanf("%d",&s); if (s) cnt[ind[s]]++; } 65 for (f[0]=1; k; k>>=1,NTT::mul(cnt,cnt)) if (k & 1) NTT::mul(f,cnt); 66 printf("%d\n",f[ind[x]]); 67 return 0; 68 }