再探快速傅里叶变换(FFT)学习笔记(其三)(循环卷积的Bluestein算法+分治FFT+FFT的优化+任意模数NTT)
再探快速傅里叶变换(FFT)学习笔记(其三)(循环卷积的Bluestein算法+分治FFT+FFT的优化+任意模数NTT)
写在前面
为了不使篇幅过长,预计将把基于论文的学习笔记分为三部分:
- DFT,IDFT,FFT的定义,实现与证明:快速傅里叶变换(FFT)学习笔记(其一)
- NTT的实现与证明:快速傅里叶变换(FFT)学习笔记(其二)
- 任意模数NTT与FFT的优化技巧
一些约定
- \([p(x)]=\begin{cases}1,p(x)为真 \\ 0,p(x)为假 \end{cases}\)
- 本文中序列的下标从0开始
- 若\(s\)是一个序列,\(|s|\)表示\(s\)的长度
- 若大写字母如\(F(x)\)表示一个多项式,那么对应的小写字母如\(f\)表示多项式的每一项系数,即\(F(x)=\sum_{i=0}^{n-1} f_ix^i\)
循环卷积
DFT卷积的本质
考虑在(其一)中提到的卷积的定义式。
我们一般做FFT时忽略了式子中的\(\bmod\),其实它是在\(\bmod 2^q\)的意义下的循环卷积,只是因为\(|a|,|b|,|c|<2^q\),所以取不取模都没什么影响。
如果序列长度\(n\)是2的整数次幂,那么直接做就可以了。
如果序列长度\(n\)不是2的整数次幂考虑暴力的做法:先做一次普通FFT,再把\(c_{k+n}\)加到\(c_k\)上。但是这样在做多次FFT时就必须一次一次做,比如多项式快速幂。下面给出了一种在\(O(n \log n)\)的时间内实现任意长度循环卷积的算法:Bluestein’s Algorithm
Bluestein’s Algorithm
注:原论文的推导可能有误
考虑DFT的式子
不妨设
\(x_j=a_j \omega_n^{\frac{j^2}{2}}=a_j(\cos\frac{j^2\pi}{n}+ \text{i}\sin{\frac{j^2\pi}{n}})\)
\(y_j=\omega_n^{-\frac{j^2}{2}}= \cos \frac{\pi j^2}{n}-\text{i}\sin \frac{\pi j^2}{n}\)
那么\(a_i'=\omega_n^{\frac{j^2}{2}}\sum_{j=0}^{n-1} x_j y_{i-j}\)
这已经很类似卷积的形式了,但是注意到\(j\)的上界是\(n-1\)而不是\(i\),\(j-i\)可能为负数。那么我们把\(y\)数组的长度扩大到\(2n\),定义:
\(y_j=\omega_n^{-\frac{(j-n)^2}{2}}= \cos \frac{\pi (j-n)^2}{n}-\text{i}\sin \frac{\pi (j-n)^2}{n}\).
这样\(j<n\)的时候就对应了\(j-i\)为负数的情形,\(j\geq n\)就对应了\(j-i\)为正的情形。然后对\(x\)和\(y\)用一般的FFT,最后的答案存储在\(i+n\)的位置上,也就是说真正的\(a'_i\)实际上对应了乘积结果的\((x \cdot y)_{i+n}\)
这样,我们就只做了3次FFT就求出了任意长度循环DFT。逆变换同理,只是换成共轭复数。注意到在上述的推导中我们没有用到单位根\(\omega\)的任何性质,因此这里的\(\omega\)可以换成任意复数\(z\),这样的变换称为Chirp Z-Transform,CZT.可见,CZT实际上是DFT的广义形式。
代码实现:
//com是手写复数类,省略 void fft(com *x,int *rev,int n,int type){ //为节约篇幅,fft部分省略,x为系数序列,rev为反转数组,n为长度,type=1表示DFT,type=-1表示IDFT } void bluestein(com *a,int n,int type){ //a为系数序列,n为长度,type=1表示DFT,type=-1表示IDFT static com x[maxn*4+5],y[maxn*4+5]; static int rev[maxn*4+5]; memset(x,0,sizeof(x)); memset(y,0,sizeof(y)); //FFT前的预处理 int N=1,L=0; while(N<n*4){ L++; N*=2; } for(int i=0;i<N;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1)); //x[i],y[i]的定义见上式 for(int i=0;i<n;i++) x[i]=com(cos(pi*i*i/n),type*sin(pi*i*i/n))*a[i]; for(int i=0;i<n*2;i++) y[i]=com(cos(pi*(i-n)*(i-n)/n),-type*sin(pi*(i-n)*(i-n)/n)); fft(x,rev,N,1); fft(y,rev,N,1); for(int i=0;i<N;i++) x[i]*=y[i]; fft(x,rev,N,-1); for(int i=0;i<n;i++){ a[i]=x[i+n]*com(cos(pi*i*i/n),type*sin(pi*i*i/n));//记得乘上常数 if(type==-1) a[i]/=n;//一定记得除以n,因为做一次Bluestein相当于一次FFT,IFFT最后要除n,这里也要除n } }
例题
[POJ 2821]TN's Kindom III(任意长度循环卷积的Bluestein算法)
分治FFT
一般我们用FFT的时候,序列的所有元素都已知。但是,如果序列本身是根据卷积定义的,就无法直接套FFT
举一个最简单的例子\(f_i =\sum_{j=1}^i f_{i-j}g_j\).其中\(g\)给定,求\(f\). 由于我们卷积的时后后面的数基于前面的数,无法快速计算,时间复杂度退化到\(O(n^2)\). (虽然这个式子可以用(其四)中将会提到的多项式求逆解决,但是分治FFT更通用,可以处理很复杂的式子)
考虑分治: 设当前分治区间为\([l,r]\),假设我们求出了\([l,mid]\)的答案,那么可以求出这些点对\([mid+1,r]\)的影响。那么右半边的点\(x \in [mid+1,r]\)得到的贡献是\(\Delta_x=\sum_{i=l}^{mid} f_i g_{x-i}\).只需要把下标偏移一下(如\([l,mid]\)偏移成\([0,mid-l]\),就是一个卷积的形式,可以运用FFT或NTT计算,计算完之后,把答案累加到数组上.
伪代码如下:
poly f,g;//上述的f,g procedure calc(L,mid,R){ for i in [L,mid] : a[i-L] <- f[i]//下标偏移 for i in [1,R-L] : b[i-1] <- g[i] a <- mul(a,b);//fft或ntt做多项式乘法 for i in [mid+1,R] f[i] <- f[i]+a[i-l-1]//累加贡献 } procedure solve(l,mid){ if(l==r) return; mid <- (l+r)/2 solve(l,mid); calc(l,mid,r); solve(mid+1,r) }
时间复杂度分析:
\(T(n)=2T(\frac{n}{2})+n \log_2n\), 总复杂度\(\Theta(n \log^2n)\)
下面是基于NTT的模板代码(Luogu 4721)
#include<iostream> #include<cstdio> #include<cstring> #include<cmath> #define maxn 300000 #define G 3 #define invG 332748118 #define inv2 499122177 #define mod 998244353 using namespace std; typedef long long ll; inline ll fast_pow(ll x,ll k){ ll ans=1; while(k){ if(k&1) ans=ans*x%mod; x=x*x%mod; k>>=1; } return ans; } inline ll inv(ll x){ return fast_pow(x,mod-2); } void NTT(ll *x,int n,int type){ static int rev[maxn+5]; int tn=1; int k=0; while(tn<n){ tn*=2; k++; } for(int i=0;i<tn;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1)); for(int i=0;i<n;i++){ if(i<rev[i]) swap(x[i],x[rev[i]]); } for(int len=1;len<n;len*=2){ int sz=len*2; ll gn1=fast_pow((type==1?G:invG),(mod-1)/sz); for(int l=0;l<n;l+=sz){ int r=l+len-1; ll gnk=1; for(int i=l;i<=r;i++){ ll tmp=x[i+len]; x[i+len]=(x[i]-gnk*tmp%mod+mod)%mod; x[i]=(x[i]+gnk*tmp%mod)%mod; gnk=gnk*gn1%mod; } } } if(type==-1){ int invsz=inv(n); for(int i=0;i<n;i++) x[i]=x[i]*invsz%mod; } } void mul(ll *a,ll *b,ll *ans,int sz){ NTT(a,sz,1); NTT(b,sz,1); for(int i=0;i<sz;i++) ans[i]=a[i]*b[i]%mod; NTT(ans,sz,-1); } void cdq_divide(ll *f,ll *g,int l,int r){ static ll tmpa[maxn+5],tmpb[maxn+5]; if(l==r) return; int mid=(l+r)>>1; cdq_divide(f,g,l,mid); int tn=1,k=0; while(tn<r-l){ k++; tn*=2; } for(int i=0;i<tn;i++) tmpa[i]=tmpb[i]=0; for(int i=l;i<=mid;i++) tmpa[i-l]=f[i]; for(int i=1;i<=r-l;i++) tmpb[i-1]=g[i]; mul(tmpa,tmpb,tmpa,tn); for(int i=mid+1;i<=r;i++) f[i]=(f[i]+tmpa[i-l-1])%mod; cdq_divide(f,g,mid+1,r); } int n; ll f[maxn+5],g[maxn+5]; int main(){ scanf("%d",&n); for(int i=1;i<n;i++) scanf("%lld",&g[i]); f[0]=1; cdq_divide(f,g,0,n-1); for(int i=0;i<n;i++) printf("%lld ",f[i]); }
容易发现,许多dp方程都有分治FFT的形式。对于此类dp方程,我们可以用分治FFT将转移复杂度由\(O(n^2)\)降到\(O(n \log^2 n)\)
例题
[Codeforces 553E]Kyoya and Train(期望DP+Floyd+分治FFT)
FFT的弱常数优化
下面介绍一些优化FFT的常数的技巧。虽然这些技巧都只是对FFT的一些小优化,但是在某些题目中优化效果极其明显。
复杂算式中减少FFT次数
如果我们要计算一个复杂的多项式,如\(A(x)=B(x)C(x)+D(x)E(x)\)
最简单的方法是分别计算\(B(x)C(x)\)和\(D(x)E(x)\),这样需要做6次FFT. 但是如果先对\(B,C,D,E\)做DFT,然后直接用点值表达式计算\(a_i=b_ic_i+d_ie_i\),再把\(a\)IDFT回去。这样只需要做5次FFT,且多项式越复杂,这样的常数就越优秀。
例题
[BZOJ 3771] Triple(FFT+容斥原理+生成函数)
利用循环卷积
考虑对于两个长度为\(n\)的序列\(a,b\),计算它们的卷积\(c\)的第\(0.5n\)项到第\(1.5n\)项。传统的方法是补0扩充到\(2n\)的序列。但是因为FFT求得实际上是我们已经提到过的循环卷积,所以如果只补0到\(1.5n\)(上取整),对第\(0.5n\)项到第\(1.5n\)项无影响
在基于牛顿迭代的算法中,能起到较明显的优化作用。会在(其四)中详细介绍这些算法。
小范围暴力
由于FFT的常数较大。在数据范围较小的时候甚至不如\(O(n^2)\)的暴力卷积的优秀。因此在做多次FFT和分治FFT的时候,如果当前的序列长度较小,可以采用暴力算法。
例题
[BZOJ 3509] [CodeChef] COUNTARI (FFT+分块)
快速幂乘法次数的优化
这个东西实际上比较鸡肋。因为多项式快速幂可以通过多项式\(\ln\)和\(\exp\)优化到\(O(n \log n)\).但是为了应对考场上时间不够的情况,我们来考虑如何通过简单的实现来减少\(O(n \log^2n)\)的倍增快速幂的复杂度。
倍增法的思路是根据前面算过的乘积快速算出当前的乘积,如\(1 \to 2 \to 4 \to 8\).最坏情况下需要\(2 \log_2n+C\)次乘法。但这并不是下界。我们定义additional chain为一条链,最开始是1,后一个数减前一个数的差是链上这个是前面的某一个数。例如\(1 \to 2 \to 4 \to 6\).\(6-4=2\)在前面出现过,\(4-2=2\)在前面出现过。那么根据这条additional chain计算6次幂的时候,可以从1次幂出发,用1次幂乘1次幂得到2次幂,再乘2次幂得到4次幂,再乘2次幂得到6次幂。
很可惜,对于数\(k\)求出得到\(k\)的最短additional chain是NP-hard的。但是有很好的近似算法。近似算法基于BFS。每次我们对于队头的数\(x\),枚举它对应的additional chain中的数\(y\),如果\(x+y\)还没有访问过那么将其入队,并将\(x\)对应的链后面接上\(x+y\). 这个预处理是\(O(k)\)的,且对快速幂的常数优化很显著。
如果\(k\)很大,比如\(10^{10000}\),可以采用十进制快速幂。但是用Method of Four Russians(俗称四毛子算法),可以将乘法次数减少到\(\log_2n+O(\frac{\log n}{\log \log n})\).具体方法见2017年国家集训队论文《非常规大小分块算法初探》
FFT的强常数优化
FFT的强常数优化一般是通过减少FFT次数来实现的
在这一节中,我们记\(DFT(A(x))\)表示多项式\(A(x)\)(或序列)做DFT之后的结果,\(IDFT(A(x))\)同理
我们现在考虑最常见的一个模型:给出两个长度为\(n+1\)和\(m+1\)的多项式\(A(x),B(x)\),我们要计算他们的线性卷积。假设长度已经补齐为第一个大于\(n+m+1\)的2的整数幂\(L\)。
显然直接搞需要3次长度为\(L\)的FFT。毒瘤的Vladimir Smykalov在cf上最先给出了这个问题的优化算法。
DFT的合并
DFT的合并是指,对于两个序列\(a\),\(b\),我们只通过一次FFT就求出\(DFT(a),DFT(b)\)
不妨设:
接下来我们开始推导公式。注意为了简洁,我们记\(X=\frac{2 \pi jk}{2L}\),\(\text{conj}(z)\)表示\(z\)的共轭复数
也就是说,只要一次DFT算出\(DFT(p)\),就可以把序列反转再取共轭复数得到\(DFT(q)\).
由于DFT是线性变换,
其中\(j\)为\(k\)翻转后的数,即\(j=\begin{cases}0,k=0 \\ L-k ,k>0 \end{cases}\)
又由\((4.1),(4.2)\)式
这样我们就可以从\(q'\)推出\(a',b'\),也就是说一次DFT就能得到\(a'\)和\(b'\)了.
我们一共做了2次长度为\(L\)的FFT.
代码(UOJ#34):
#include<iostream> #include<cstdio> #include<cstring> #include<cmath> #define maxn 1000000 const double pi=acos(-1.0); using namespace std; typedef long long ll; struct com{ double real; double imag; com(){ } com(double _real,double _imag){ real=_real; imag=_imag; } com(double x){ real=x; imag=0; } void operator = (const com x){ this->real=x.real; this->imag=x.imag; } void operator = (const double x){ this->real=x; this->imag=0; } friend com operator + (com p,com q){ return com(p.real+q.real,p.imag+q.imag); } friend com operator + (com p,double q){ return com(p.real+q,p.imag); } void operator += (com q){ *this=*this+q; } void operator += (double q){ *this=*this+q; } friend com operator - (com p,com q){ return com(p.real-q.real,p.imag-q.imag); } friend com operator - (com p,double q){ return com(p.real-q,p.imag); } void operator -= (com q){ *this=*this-q; } void operator -= (double q){ *this=*this-q; } friend com operator * (com p,com q){ return com(p.real*q.real-p.imag*q.imag,p.real*q.imag+p.imag*q.real); } friend com operator * (com p,double q){ return com(p.real*q,p.imag*q); } void operator *= (com q){ *this=(*this)*q; } void operator *= (double q){ *this=(*this)*q; } friend com operator / (com p,double q){ return com(p.real/q,p.imag/q); } void operator /= (double q){ *this=(*this)/q; } com conj(){ return com(real,-imag); } void print(){ printf("%lf + %lf i ",real,imag); } }; int rev[maxn+5]; com w[maxn+5]; void fft(com *x,int n){ for(int i=0;i<n;i++) if(i<rev[i]) swap(x[i],x[rev[i]]); for(int len=1;len<n;len*=2){ int sz=len*2; for(int l=0;l<n;l+=sz){ int r=l+len-1; for(int i=l;i<=r;i++){ com tmp=x[i+len]; x[i+len]=x[i]-tmp*w[n/sz*(i-l)];//w(sz,k)=w(n,n/sz*k) x[i]=x[i]+tmp*w[n/sz*(i-l)]; } } } } void mul(ll *a,ll *b,ll *c,int n){ static com p[maxn+5],r[maxn+5]; for(int i=0;i<n;i++) w[i]=com(cos(2*pi*i/n),sin(2*pi*i/n));//预处理单位根 for(int i=0;i<n;i++) p[i]=com(a[i],b[i]);//p[i]=a[i]+ib[i] fft(p,n); for(int i=0;i<n;i++){ int j=(i>0?(n-i):0);//0的位置需要特判一下 com q=p[j]; r[j]=(p[i]*p[i]-q.conj()*q.conj())*com(0,-0.25);//按照上面的式子 } fft(r,n);//这里是用了第一篇中提到的反转技巧 for(int i=0;i<n;i++) c[i]=r[i].real/n+0.5; } int n,m; ll a[maxn+5],b[maxn+5],c[maxn+5]; int main(){ scanf("%d %d",&n,&m); for(int i=0;i<=n;i++) scanf("%lld",&a[i]); for(int i=0;i<=m;i++) scanf("%lld",&b[i]); int N=1,L=0; while(N<n+m+1){ L++; N*=2; } for(int i=0;i<N;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1)); mul(a,b,c,N); for(int i=0;i<n+m+1;i++) printf("%lld\n",c[i]); }
IDFT的合并
IDFT的合并是指,对于两个序列\(a\),\(b\),我们只通过一次FFT就求出\(IDFT(a),IDFT(b)\)
IDFT的合并非常简单。
设\(r(x)=a(x)+\text{i}b(x)\)
由于IDFT是线性变换
\(IDFT(r(x))=IDFT(a(x))+\text{i}IDFT(b(x))\)
又因为\(a(x)\)和\(b(x)\)都是实数序列,那么\(IDFT(r(x))\)的实部就是\(IDFT(a(x))\),虚部就是\(IDFT(b(x))\)
形如\((A+B)(C+D)\)的卷积的优化
在这一节中我们讨论\((A(x)+B(x))(C(x)+D(x))\)形式的卷积的优化.
一般的做法是对\(A,B,C,D\)都做一次DFT,然后按照这个式子直接计算,最后再IDFT回来。需要5次FFT.
而根据上面的合并技巧,先把\(A(x),B(x)\)合并DFT,\(C(x),D(x)\)合并DFT得到点值表达式.
由于\((A(x)+B(x))(C(x)+D(x))=A(x)C(x)+A(x)D(x)+B(x)C(x)+B(x)D(x)\)
我们可以直接把点值表达式相乘得到这4个多项式。对于这4个多项式,分成2组合并做IDFT即可。
总共需要4次FFT.
大致代码如下:
void mul(ll *a,ll *b,ll *c,ll *d,ll *ans,int n){ static com p[maxn+5],q[maxn+5]; static com r[maxn+5],s[maxn+5]; for(int i=0;i<n;i++) w[i]=com(cos(2*pi*i/n),sin(2*pi*i/n)); for(int i=0;i<n;i++){ p[i]=com(a[i],b[i]);//打包A,B q[i]=com(c[i],d[i]);//打包C,D } fft(p,n); fft(q,n); for(int i=0;i<n;i++){ int j=(i==0?0:n-i); //得到DFT(A),DFT(B),DFT(C),DFT(D) com da=(p[i]+p[j].conj())*0.5; com db=(p[i]-p[j].conj())*com(0,-0.5); com dc=(q[i]+q[j].conj())*0.5; com dd=(q[i]-q[j].conj())*com(0,-0.5); r[j]=da*dc+da*dd*com(0,1);//打包AC,AD s[j]=db*dc+db*dd*com(0,1); //打包BC,BD } fft(r,n); fft(s,n); for(int i=0;i<n;i++){ ll ac,ad,bc,bd; ac=(ll)(r[i].real/n+0.5); ad=(ll)(r[i].imag/n+0.5); bc=(ll)(s[i].real/n+0.5); bd=(ll)(s[i].imag/n+0.5); ans[i]=ac+ad+bc+bd; } }
卷积的终极优化
上述优化中我们只用到了DFT的思想。现在我们利用FFT的思想继续优化
同样拆分奇偶项,\(A(x)=A_0(x^2)+xA_1(x^2)\)
我们只需要知道上式中\(x^0,x^1,x^2\)的系数
发现\(A_0(x^2)B_1(x^2)+A_1(x^2)B_0(x^2)\)是奇数项的系数,\(A_0(x^2)B_0(x^2)\)和\(A_1(x^2)B_1(x^2)\)是偶数项的系数,而偶数项的两个东西都可以看成一个关于\(x^2\)的多项式。
我们先优化DFT的过程,观察\((4.6)\)式的乘积形式\((A_0(x^2)+xA_1(x^2))(B_0(x^2)+xB_1(x^2))\).
我们发现,这个形式和上一节的\((A+B)(C+D)\)很像,可以类似地优化。
令\(p_k={a_0}_k+\text{i}{a_1}_k,q_k={b_0}_k+\text{i}{b_1}_k\)
然后合并IDFT,再设两个辅助多项式
(注意我们把\(x^2\)换元成\(x\),做DFT的时候要乘上单位根)
那么我们只需要计算出\(IDFT(G(x))\)和\(IDFT(F(x))\)
设\(R(x)=G(x)+\mathrm{i} F(x)\)
那么因为IDFT是线性变换,\(IDFT(R(x))=IDFT(G(x))+\mathrm{i} IDFT(F(x))\)
(IDFT的线性性这里不做证明,容易发现两个点值表达式相加再IDFT回来,显然系数也会相加)
显然这两个多项式IDFT的结果是实数。故我们只要求出\(IDFT(R(x))\),每一项系数的实部就是偶数项系数\(G(x)\),虚部就是奇数项系数\(F(x)\)
我们再考虑把合并DFT弄进去,即式\((4.3)(4.4)(4.5)\)
接下来我们尝试用\(DFT(p_k),DFT(q_k)\)来表示\(R(x)=G(x)+\text{i}F(x)\),为了推导简洁,我们省略\(DFT\)不写
那么
和上一节的\((A+B)(C+D)\)不同,我们只用了3次长度为\(L/2\)的FFT,就求出了答案,这是由于FFT本身的性质。因为长度缩减了一半,我们不妨称它为\(1.5\)次FFT.
#include<iostream> #include<cstdio> #include<cstring> #include<cmath> #define maxn 1000000 const double pi=acos(-1.0); using namespace std; typedef long long ll; struct com{ double real; double imag; com(){ } com(double _real,double _imag){ real=_real; imag=_imag; } com(double x){ real=x; imag=0; } void operator = (const com x){ this->real=x.real; this->imag=x.imag; } void operator = (const double x){ this->real=x; this->imag=0; } friend com operator + (com p,com q){ return com(p.real+q.real,p.imag+q.imag); } friend com operator + (com p,double q){ return com(p.real+q,p.imag); } void operator += (com q){ *this=*this+q; } void operator += (double q){ *this=*this+q; } friend com operator - (com p,com q){ return com(p.real-q.real,p.imag-q.imag); } friend com operator - (com p,double q){ return com(p.real-q,p.imag); } void operator -= (com q){ *this=*this-q; } void operator -= (double q){ *this=*this-q; } friend com operator * (com p,com q){ return com(p.real*q.real-p.imag*q.imag,p.real*q.imag+p.imag*q.real); } friend com operator * (com p,double q){ return com(p.real*q,p.imag*q); } void operator *= (com q){ *this=(*this)*q; } void operator *= (double q){ *this=(*this)*q; } friend com operator / (com p,double q){ return com(p.real/q,p.imag/q); } void operator /= (double q){ *this=(*this)/q; } com conj(){ return com(real,-imag); } void print(){ printf("%lf + %lf i ",real,imag); } }; int rev[maxn+5]; com w[maxn+5]; void fft(com *x,int n){ for(int i=0;i<n;i++) if(i<rev[i]) swap(x[i],x[rev[i]]); for(int len=1;len<n;len*=2){ int sz=len*2; for(int l=0;l<n;l+=sz){ int r=l+len-1; for(int i=l;i<=r;i++){ com tmp=x[i+len]; x[i+len]=x[i]-tmp*w[n/sz*(i-l)];//w(sz,k)=w(n,n/sz*k) x[i]=x[i]+tmp*w[n/sz*(i-l)]; } } } } void mul(ll *a,ll *b,ll *c,int n){ static com p[maxn+5],q[maxn+5],r[maxn+5]; for(int i=0;i<n;i++){//合并做DFT if(i%2==1){ p[i/2].imag=a[i]; q[i/2].imag=b[i]; }else{ p[i/2].real=a[i]; q[i/2].real=b[i]; } } n/=2; for(int i=0;i<n;i++) w[i]=com(cos(2*pi*i/n),sin(2*pi*i/n)); fft(q,n); fft(p,n); for(int i=0;i<n;i++){ int j=(i>0?(n-i):0); r[j]=p[i]*q[i]-(w[i]+1)*(p[i]-p[j].conj())*(q[i]-q[j].conj())*0.25; } fft(r,n); for(int i=0;i<n;i++){ c[i*2]=r[i].real/n+0.5; c[i*2+1]=r[i].imag/n+0.5; } } int n,m; ll a[maxn+5],b[maxn+5],c[maxn+5]; int main(){ scanf("%d %d",&n,&m); for(int i=0;i<=n;i++) scanf("%lld",&a[i]); for(int i=0;i<=m;i++) scanf("%lld",&b[i]); int N=1,L=0; while(N<=n+m+1){ L++; N*=2; } for(int i=0;i<N/2;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-2));//注意这里的rev数组是对N/2做的,L要-1 mul(a,b,c,N); for(int i=0;i<n+m+1;i++) printf("%lld\n",c[i]); }
任意模数NTT
三模数NTT
这是任意模数NTT的算法中最好理解的一种,它基于中国剩余定理。
定理5.1 若\(m_1,m_2 ,\dots m_n\)两两互质,则对于\(\forall a_1,a_2 \dots a_n\)同余方程组
\[\begin{cases} x \equiv a_1 (\bmod m_1) \\ x \equiv a_2 (\bmod m_2) \\ \dots \\ x \equiv a_n (\bmod m_n)\end{cases} \]有整数解解,且可以用如下方式构造解
- 设\(M=\prod_{i=1}^n m_i,M_i=\frac{M}{m_i}\)
- 设\(M_i^{-1}\)为模\(m_i\)意义下\(M_i\)的逆元
- 则该方程组在模\(M\)意义下的唯一解为\(x=\sum_{i=1}^n a_iM_iM_i^{-1}\) ,方程组的通解可以表示为\(x+kM(k \in \mathbb{Z})\)
这就是著名的中国剩余定理(Chinese Reminder Theorem,CRT)
证明:
对于\(k \neq i\),\(a_iM_iM_i^{-1} \bmod m_k=0\), 而根据逆元的定义,\(a_iM_iM_i^{-1} \bmod m_i =a_i\). 再代入到\(\sum_{i=1}^n a_iM_iM_i^{-1}\),原方程组成立。
回到任意模数NTT问题
模\(M\)意义下长度为\(n\)的序列做卷积,最大值可以到\(n^2M\).一般的题目中\(n \leq 10^5,M\leq 10^{9}\),那么结果会到\(10^{23}\)级别。用long double
等存储会丢失精度。那么我们可以选三个乘起来大于\(10^{23}\)的NTT模数998244353,1004535809,469762049(选这三个模数的好处是他们的原根都是3,所以NTT部分写起来比较简洁)。然后分别在这三个模数的意义下做卷积。最后考虑把答案合并,我们只考虑某一位上的值\(ans\),容易写出:
显然\(m_1,m_2,m_3\)互质,那么我们可以利用中国剩余定理直接合并。但是,直接合并把三个模数乘起来的时候会超出long long
的范围。注意到两个模数相乘还是在long long
范围内的,可以两两合并,具体方法如下,
记\(inv(a,m)\)表示\(a\)在模\(m\)下的逆元.根据CRT合并\((5.2)(5.3)\)有:
不妨设\(ans=km_1m_2+r\),根据\(5.4\)有
\(ans=km_1 m_2+r=q m_3+a_3 \tag{5.6}\),
在模 \(m_3\) 意义下有
\(km_1 m_2+r \equiv a_3 (\bmod m_3) \tag{5.7}\)
因此\(k=(a_3-r_2)inv(m_1m_2,m_3) (\bmod m_3)\),不妨设\(k=dm_3+e\),代入\(5.6\)得
由于\(m_1m_2m_3>ans\),所以\(d=0\),也就是说,\(ans=em_1m_2+r\),其中\(r=a_1m_2inv(m_1,m_1m_2)+a_2m_1inv(m_2,m_1m_2),e=(a_3-r_2)inv(m_1m_2,m_3)\)
const ll mm=m1*m2; inline ll inv(ll a,ll m); ll mul(ll a,ll b,ll m);//要用按位乘防止溢出 ll CRT(ll a1,ll a2,ll a3){ ll r=(mul(a1*m2%mm,inv(m2,m1),mm)+mul(a2*m1%mm,inv(m1,m2),mm))%mm; ll e=((a3-r)%m3+m3)%m3*inv(mm,m3)%m3; return ((e%C)*(mm%C)%C+r%C)%C; }
完整代码(LuoguP4245 【模板】任意模数NTT)
#include<iostream> #include<cstdio> #include<cstring> #define m1 998244353ll #define m2 1004535809ll #define m3 469762049ll #define G 3 #define maxn 1048576 using namespace std; typedef long long ll; const ll mm=m1*m2; ll C; ll fast_pow(ll x,ll k,ll m){ ll ans=1; while(k){ if(k&1) ans=ans*x%m; x=x*x%m; k>>=1; } return ans; } inline ll inv(ll a,ll m){ return fast_pow(a%m,m-2,m); //一定要取模m } ll mul(ll a,ll b,ll m){ ll ans=0; while(b){ if(b&1) ans=(ans+a)%m; a=(a+a)%m; b>>=1; } return ans; } ll CRT(ll a1,ll a2,ll a3){ //[Warning]You are not expected to understand this. ll r=(mul(a1*m2%mm,inv(m2,m1),mm)+mul(a2*m1%mm,inv(m1,m2),mm))%mm; ll e=((a3-r)%m3+m3)%m3*inv(mm,m3)%m3; return ((e%C)*(mm%C)%C+r%C)%C; } int n,m,N,L; int rev[maxn+5]; void NTT(ll *x,int n,int type,ll mod){ ll invG=inv(G,mod); for(int i=0;i<n;i++) if(i<rev[i]) swap(x[i],x[rev[i]]); for(int len=1;len<n;len*=2){ int sz=len*2; ll gn1=fast_pow((type==1?G:invG),(mod-1)/sz,mod); for(int l=0;l<n;l+=sz){ int r=l+len-1; ll gnk=1; for(int i=l;i<=r;i++){ ll tmp=x[i+len]; x[i+len]=(x[i]-gnk*tmp%mod+mod)%mod; x[i]=(x[i]+gnk*tmp%mod)%mod; gnk=gnk*gn1%mod; } } } if(type==-1){ ll invn=inv(n,mod); for(int i=0;i<n;i++) x[i]=x[i]*invn%mod; } } void fmul(ll *a,ll *b,ll *ans,int n,ll mod){ static ll ta[maxn+5],tb[maxn+5]; for(int i=0;i<n;i++) ta[i]=a[i]; for(int i=0;i<n;i++) tb[i]=b[i]; NTT(ta,n,1,mod); if(a!=b) NTT(tb,n,1,mod); for(int i=0;i<n;i++) ans[i]=ta[i]*tb[i]%mod; NTT(ans,n,-1,mod); } ll a[maxn+5],b[maxn+5],c[3][maxn+5]; int main(){ scanf("%d %d %lld",&n,&m,&C); for(int i=0;i<=n;i++){ scanf("%lld",&a[i]); a[i]%=C; } for(int i=0;i<=m;i++){ scanf("%lld",&b[i]); b[i]%=C; } N=1,L=0; while(N<n+m+1){ N*=2; L++; } for(int i=0;i<N;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1)); fmul(a,b,c[0],N,m1); fmul(a,b,c[1],N,m2); fmul(a,b,c[2],N,m3); for(int i=0;i<n+m+1;i++){ printf("%lld ",CRT(c[0][i],c[1][i],c[2][i])); } }
容易发现,三模数NTT需要9次FFT,不是很优秀
拆系数FFT
我们之前讨论的优化都是针对FFT的,那不妨尝试用FFT解决任意模数NTT
最简单的想法是不取模,FFT完再取模。但是上文提到数值过大,long double
会丢失精度。
int128
是一个方法,但在OI比赛中不一定能使用。所以需要拆系数。
设\(M_0=[\sqrt{M}]\)
相当于把模数换成\(M_0\),降低大小。
代入对应的多项式
这不就是我们提到的\((A+B)(C+D)\)形的卷积吗?
由于\(k,b\)都不超过\(2^{15}\),于是就不容易被卡精度了。实际操作中我们不必取\(M_0=\sqrt{M}\),直接取\(M_0=2^{15}\)即可。这样取模运算可以换成位运算,进一步减小常数。
#include<iostream> #include<cstdio> #include<cstring> #include<cmath> #define maxn 1000000 const double pi=acos(-1.0); using namespace std; typedef long long ll; struct com{ double real; double imag; com(){ } com(double _real,double _imag){ real=_real; imag=_imag; } com(double x){ real=x; imag=0; } void operator = (const com x){ this->real=x.real; this->imag=x.imag; } void operator = (const double x){ this->real=x; this->imag=0; } friend com operator + (com p,com q){ return com(p.real+q.real,p.imag+q.imag); } friend com operator + (com p,double q){ return com(p.real+q,p.imag); } void operator += (com q){ *this=*this+q; } void operator += (double q){ *this=*this+q; } friend com operator - (com p,com q){ return com(p.real-q.real,p.imag-q.imag); } friend com operator - (com p,double q){ return com(p.real-q,p.imag); } void operator -= (com q){ *this=*this-q; } void operator -= (double q){ *this=*this-q; } friend com operator * (com p,com q){ return com(p.real*q.real-p.imag*q.imag,p.real*q.imag+p.imag*q.real); } friend com operator * (com p,double q){ return com(p.real*q,p.imag*q); } void operator *= (com q){ *this=(*this)*q; } void operator *= (double q){ *this=(*this)*q; } friend com operator / (com p,double q){ return com(p.real/q,p.imag/q); } void operator /= (double q){ *this=(*this)/q; } com conj(){ return com(real,-imag); } void print(){ printf("(%lf,%lf)\n",real,imag); } }; int rev[maxn+5]; com w[maxn+5]; void fft(com *x,int n){ for(int i=0;i<n;i++) if(i<rev[i]) swap(x[i],x[rev[i]]); for(int len=1;len<n;len*=2){ int sz=len*2; for(int l=0;l<n;l+=sz){ int r=l+len-1; for(int i=l;i<=r;i++){ com tmp=x[i+len]; x[i+len]=x[i]-tmp*w[n/sz*(i-l)]; x[i]=x[i]+tmp*w[n/sz*(i-l)]; } } } } ll mod; void mul(ll *ina,ll *inb,ll *inc,int n){ static ll a[maxn+5],b[maxn+5],c[maxn+5],d[maxn+5]; static com p[maxn+5],q[maxn+5]; static com r[maxn+5],s[maxn+5]; for(int i=0;i<n;i++){ ina[i]=(ina[i]+mod)%mod; inb[i]=(inb[i]+mod)%mod; a[i]=ina[i]>>15; b[i]=ina[i]&((1<<15)-1); c[i]=inb[i]>>15; d[i]=inb[i]&((1<<15)-1); } for(int i=0;i<n;i++) w[i]=com(cos(2*pi*i/n),sin(2*pi*i/n)); for(int i=0;i<n;i++){ p[i]=com(a[i],b[i]);//打包A,B q[i]=com(c[i],d[i]);//打包C,D } fft(p,n); fft(q,n); for(int i=0;i<n;i++){ // p[i].print(); int j=(i==0?0:n-i); //得到DFT(A),DFT(B),DFT(C),DFT(D) com da=(p[i]+p[j].conj())*0.5; com db=(p[i]-p[j].conj())*com(0,-0.5); com dc=(q[i]+q[j].conj())*0.5; com dd=(q[i]-q[j].conj())*com(0,-0.5); r[j]=da*dc+da*dd*com(0,1);//打包AC,AD s[j]=db*dc+db*dd*com(0,1); //打包BC,BD } fft(r,n); fft(s,n); for(int i=0;i<n;i++){ ll ac,ad,bc,bd; ac=(ll)(r[i].real/n+0.5)%mod; ad=(ll)(r[i].imag/n+0.5)%mod; bc=(ll)(s[i].real/n+0.5)%mod; bd=(ll)(s[i].imag/n+0.5)%mod; inc[i]=((ac<<30)+((ad+bc)<<15)+bd)%mod; } } int n,m; ll a[maxn+5],b[maxn+5],c[maxn+5]; int main(){ scanf("%d %d %lld",&n,&m,&mod); for(int i=0;i<=n;i++) scanf("%lld",&a[i]); for(int i=0;i<=m;i++) scanf("%lld",&b[i]); int N=1,L=0; while(N<=n+m+1){ L++; N*=2; } for(int i=0;i<N;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1)); mul(a,b,c,N); for(int i=0;i<n+m+1;i++) printf("%lld ",c[i]); }
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】博客园社区专享云产品让利特惠,阿里云新客6.5折上折
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 用 C# 插值字符串处理器写一个 sscanf
· Java 中堆内存和栈内存上的数据分布和特点
· 开发中对象命名的一点思考
· .NET Core内存结构体系(Windows环境)底层原理浅谈
· C# 深度学习:对抗生成网络(GAN)训练头像生成模型
· 趁着过年的时候手搓了一个低代码框架
· 本地部署DeepSeek后,没有好看的交互界面怎么行!
· 为什么说在企业级应用开发中,后端往往是效率杀手?
· 用 C# 插值字符串处理器写一个 sscanf
· 乌龟冬眠箱湿度监控系统和AI辅助建议功能的实现