bzoj 3992 [SDOI2015] 序列统计 —— NTT (循环卷积+快速幂)
题目:https://www.lydsy.com/JudgeOnline/problem.php?id=3992
(学习NTT:https://riteme.github.io/blog/2016-8-22/ntt.html
https://www.cnblogs.com/Mychael/p/9297652.html
http://blog.miskcoo.com/2015/04/polynomial-multiplication-and-fast-fourier-transform#i-15 )
首先,如果把方案数和乘积分别放在系数和次数上,就可以用多项式做了;
方案数放在系数上好说,但次数是相加的,如何表示乘积?
考虑乘积与加法的关系 —— 幂的相乘就是指数相加;
所以可以找出乘积的模数 m 的原根,用其次数相加代表乘积,这个次数好像被称为“指标”;
构造出多项式,由于要取模,所以用 NTT 做;
也就是要把初始的多项式做 n 次幂,可以用快速幂,但注意累乘起来的是系数而不是点值;
指标从0开始或从1开始都可以,也就是把 0 次方作为 1 和把 m-1 次方作为 1 的区别,对应系数的时候要根据这个注意一下(代码中注释里的方案也可);
初始化一个多项式并不是把每个系数都赋值1!而只有第0项是1,这样别的多项式乘过来还是那个多项式;
然后要特别注意读入时去掉0!因为原根系列中没有模出0的,所以以原根为基础的 NTT 算的时候不能考虑0,而反正最后要求的方案中,x >= 1,一旦有0,乘积就是0了,所以0对答案没有影响,就当没给这个数算即可;
一下午的心血...
代码如下:
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> using namespace std; typedef long long ll; int const xn=(1<<14),xm=8005,mod=1004535809; int n,m,rev[xn],g,a[xn],b[xn],lim,r[xm],cnt,pri[xm],inv; int rd() { int ret=0,f=1; char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=0; ch=getchar();} while(ch>='0'&&ch<='9')ret=(ret<<3)+(ret<<1)+ch-'0',ch=getchar(); return f?ret:-ret; } int pw(ll a,int b,int md) { ll ret=1; for(;b;b>>=1,a=(a*a)%md)if(b&1)ret=(ret*a)%md; return ret; } void div(int x) { for(int i=2;i*i<=x;i++) { if(x%i)continue; pri[++cnt]=i; while(x%i==0)x/=i; } if(x>1)pri[++cnt]=x; } void init() { lim=1; int l=0; while(lim<=m+m)lim<<=1,l++; //while(lim<=2*(m-1))lim<<=1,l++; for(int i=0;i<lim;i++) rev[i]=((rev[i>>1]>>1)|((i&1)<<(l-1))); inv=pw(lim,mod-2,mod); if(m==2){g=1; return;} div(m-1); for(g=2;;g++) { bool f=0; for(int j=1;j<=cnt;j++) if(pw(g,(m-1)/pri[j],m)==1){f=1; break;} if(!f)break; } for(int i=1,k=g;i<m;i++,k=(ll)k*g%m)r[k]=i; //for(int i=0,k=1;i<m-1;i++,k=(ll)k*g%m)r[k]=i;//k=1 } int upt(int x){while(x>=mod)x-=mod; while(x<0)x+=mod; return x;} void ntt(int *a,int tp) { for(int i=0;i<lim;i++) if(i<rev[i])swap(a[i],a[rev[i]]); for(int mid=1;mid<lim;mid<<=1) { int wn=pw(3,(mod-1)/(mid<<1),mod); if(tp==-1)wn=pw(wn,mod-2,mod);// for(int j=0,len=(mid<<1);j<lim;j+=len) { int w=1; for(int k=0;k<mid;k++,w=(ll)w*wn%mod) { int x=a[j+k],y=(ll)w*a[j+mid+k]%mod; a[j+k]=upt(x+y); a[j+mid+k]=upt(x-y); } } } } void pww() { ntt(a,1); for(int i=0;i<lim;i++)a[i]=(ll)a[i]*a[i]%mod; ntt(a,-1); for(int i=0;i<lim;i++)a[i]=(ll)a[i]*inv%mod; for(int i=m;i<lim;i++)a[i%m+1]=upt(a[i%m+1]+a[i]),a[i]=0;//%m+1 //for(int i=m-1;i<lim;i++)a[i%(m-1)]=upt(a[i%(m-1)]+a[i]),a[i]=0; } void mul() { ntt(a,1); ntt(b,1);// for(int i=0;i<lim;i++)b[i]=(ll)b[i]*a[i]%mod; ntt(a,-1); ntt(b,-1);// for(int i=0;i<lim;i++) a[i]=(ll)a[i]*inv%mod,b[i]=(ll)b[i]*inv%mod; for(int i=m;i<lim;i++) a[i%m+1]=upt(a[i%m+1]+a[i]),a[i]=0, b[i%m+1]=upt(b[i%m+1]+b[i]),b[i]=0; /* for(int i=m-1;i<lim;i++) a[i%(m-1)]=upt(a[i%(m-1)]+a[i]),a[i]=0, b[i%(m-1)]=upt(b[i%(m-1)]+b[i]),b[i]=0; */ } int main() { n=rd(); m=rd(); init(); int p=rd(),num=rd(); for(int i=1,x;i<=num;i++) { x=rd(); if(x)a[r[x]]=1;//x!=0 !! } int t=n; //for(int i=0;i<lim;i++)b[i]=1; b[0]=1;//! for(;t;t>>=1,pww())if(t&1)mul(); printf("%d\n",b[r[p]]); return 0; }