bzoj3992: [SDOI2015]序列统计
传送门:http://www.lydsy.com:808/JudgeOnline/problem.php?id=3992
思路:M是一个质数,问题又是求乘积,于是我们就可以想到利用M的原根g把问题变成求和(我怎么想不到啊。。。)
根据原根的性质,我们可以把1到M-1中的数i表示为(g^b[i])%M,且指数互不相同
那么X就可以表示成(g^b[x])%M
问题就转化为:然后问题转化成了在序列b中,选出n个数(一个数可以取多次),且它们的和s满足s=b[x]
这个问题我们可以用母函数+NTT解决,答案就是多项式的b[x]次项的系数
因为n很大,所以再套一个快速幂即可
#include<cmath> #include<cstdio> #include<cstring> #include<algorithm> using namespace std; const int mod=(479<<21)+1,G=3,maxn=17000; int n,k,sum,m,inv_G,inv_N,T[maxn],vis[maxn],tim,N,rev[maxn],pos[maxn],root; int qpow(int a,int b){ int res=1; for (;b;b>>=1){ if (b&1) res=1ll*res*a%mod; a=1ll*a*a%mod; } return res; } bool check(int x){ int now=1;++tim; for (int i=1;i<n;i++,now=now*x%n){ if (vis[now]==tim) return 0; vis[now]=tim; } return 1; } int rever(int x){int res=0,len=N;while (len--) res<<=1,res^=(x&1),x>>=1;return res;} int findroot(){for (int i=2;i<=n;i++) if (check(i)) return i;} struct DFT{ int a[maxn]; void ntt(int op){ for (int i=0;i<N;i++) if (rev[i]>i) swap(a[rev[i]],a[i]); int g=op==1?G:inv_G; for (int sz=2;sz<=N;sz<<=1){ int t=qpow(g,(mod-1)/sz); for (int bg=0;bg<N;bg+=sz) for (int po=bg,w=1;po<bg+(sz>>1);po++){ int x=a[po],y=1ll*a[po+(sz>>1)]*w%mod; a[po]=(x+y)%mod,a[po+(sz>>1)]=(x-y+mod)%mod; w=1ll*w*t%mod; } } if (op==-1) for (int i=0;i<N;i++) a[i]=1LL*a[i]*inv_N%mod; } }a,b; void qpow(){ b.a[0]=1; for (;k;k>>=1){ a.ntt(1); if (k&1){ b.ntt(1);for (int i=0;i<N;i++) b.a[i]=1ll*b.a[i]*a.a[i]%mod; b.ntt(-1);for (int i=N-1;i>=n-1;i--) b.a[i-n+1]=(b.a[i-n+1]+b.a[i])%mod,b.a[i]=0; //因为是在mod m(代码里的n)的条件下进行的,所以要把和超过m的后半部分的答案加到前半部分去,而不是简单的清空 } for (int i=0;i<N;i++) a.a[i]=1ll*a.a[i]*a.a[i]%mod;a.ntt(-1); for (int i=N-1;i>=n-1;i--) a.a[i-n+1]=(a.a[i-n+1]+a.a[i])%mod,a.a[i]=0; } } int main(){ inv_G=qpow(G,mod-2); scanf("%d%d%d%d",&k,&n,&sum,&m); for (int i=1;i<=m;i++) scanf("%d",&T[i]); N=(int)ceil(log2(n))+1; for (int i=0;i<(1<<N);i++) rev[i]=rever(i); N=1<<N,inv_N=qpow(N,mod-2),root=findroot(); for (int i=0,res=1;i<n-1;i++) pos[res]=i,res=res*root%n; for (int i=1;i<=m;i++) if (T[i]) a.a[pos[T[i]]]++; qpow(),printf("%d\n",b.a[pos[sum]]); return 0; }