【SDOI2015】序列统计
题面
https://www.luogu.org/problem/P3321
题解
首先贡献是$f[a_ib_i]+=f1[a_i]\times f2[b_i]$,用原根变成$f[a_i+b_i]+=f1[a_i]\times f2[b_i]$,即形成一个新的映射。
开个桶,即求这个多项式的$n$次幂。
$NTT+$分治快速幂。
自己编的$NTT$口诀:
上倍增,中加二倍,下加加。
上界$1,0,0$,下界$li,li,mi$
原根$3$,$2^{mid}$次单位根$w0=phi(p)/2mid$($FFT$:$w0=\{cos(pi/mid),sin(pi/mid)\times opt\}$)
加($mid$)乘($w$),加($mid$)减($y$)。
$-1$时,$reverse(1,lim)$除以$lim$
NTT(copy from Gloid orz Gloid)
没啥区别就是模意义下的。那用原根代替复数就好了。
原根随便都能求出来。
模数需要为2^n*k+1的形式且为质数。因为需要求得2^i次根,也即原根的(p-1)/2^i次,这个东西显然需要是整数。
常用模数有
998244353=2^23*119+1
1004535809=2^21*479+1
469762049=2^26*7+1
都能跑几百万项。
IDFT时用原根的逆元。最后乘项数的逆元。
#include<cmath> #include<stack> #include<cstdio> #include<cstring> #include<iostream> #include<algorithm> #define ri register int #define N 100000 #define mod 1004535809 #define LL long long using namespace std; inline int read() { int ret=0; char ch=getchar(); while (ch<'0' || ch>'9') ch=getchar(); while (ch>='0' && ch<='9') ret*=10,ret+=(ch-'0'),ch=getchar(); return ret; } int n,m,x,t; int f[N],s[N]; int a[N],mp[N],b[N],ftr[30]; int lim=1,l=0,r[N],ret[N]; int pow(int a,int b,int p) { int ret=1; for (;b;b>>=1,a=a*1LL*a%p) if (b&1) ret=ret*1LL*a%p; return ret; } int getg(int m) { int tmp=m-1; int c=0; for (ri i=2;i*i<=tmp;i++) if (tmp%i==0) { ftr[++c]=i; while (tmp%i==0) tmp/=i; } if (tmp>1) ftr[++c]=tmp; for (ri g=2;g<=m-1;g++) { bool fl=0; for (ri j=1;j<=c;j++) if (pow(g,(m-1)/ftr[j],m)==1) {fl=1;break;} if (!fl) return g; } return -1; } void NTT(int *p,int opt) { for (ri i=0;i<lim;i++) if (i<r[i]) swap(p[i],p[r[i]]); for (ri i=1;i<lim;i<<=1) { int w0=pow(3,(mod-1)/(i<<1),mod); for (ri j=0;j<lim;j+=2*i) { int w=1; for (ri k=0;k<i;k++,w=w*1LL*w0%mod) { int x=p[j+k],y=p[i+j+k]*1LL*w%mod; p[j+k]=(x+y)%mod; p[i+j+k]=(x-y+mod)%mod; } } } if (opt==-1) { reverse(&p[1],&p[lim]); int inv=pow(lim,mod-2,mod); for (ri i=0;i<lim;i++) p[i]=p[i]*1LL*inv%mod; } } void mul(int *a1,int *a2,int *c) { memset(a,0,sizeof(a)); memset(b,0,sizeof(b)); for (ri i=0;i<m-1;i++) a[i]=a1[i],b[i]=a2[i]; NTT(a,1); NTT(b,1); for (ri i=0;i<lim;i++) a[i]=a[i]*1LL*b[i]%mod; NTT(a,-1); memset(ret,0,sizeof(ret)); for (ri i=0;i<m-1;i++) ret[i]=(a[i]+a[i+m-1])%mod; for (ri i=0;i<m-1;i++) c[i]=ret[i]; } int main() { n=read(); m=read(); x=read(); t=read(); int g=getg(m); for (ri i=0;i<m-1;i++) mp[pow(g,i,m)]=i; while (lim<=2*(m-2)) lim<<=1,l++; for (ri i=0;i<lim;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(l-1)); for (ri i=1;i<=t;i++) { int xx=read()%m; if (xx) f[mp[xx]]++; } s[mp[1]]=1; for (;n;n>>=1,mul(f,f,f)) if (n&1) mul(s,f,s); printf("%d\n",s[mp[x]]); }