FFT/NTT 多项式学习笔记
FFT(快速傅立叶变换)和NTT(快速数论变换)看上去很高端,真正搞懂了就很simple了辣。
首先给出多项式的一些定义(初中数学内容):
形如Σaixi的式子就是多项式!
多项式中每个单项式叫做多项式的项。
这些单项式中的最高次数,就是这个多项式的次数。
有几个不同的元也是多项式,但在下面将不被考虑。
注意:(n+1)个点可以唯一确定一个n次多项式(两点定线啊之类的)。
然后就是一些比较高明的东西了。
首先在掌握FFT之前我们要掌握一下知识:
1.复数的计算法则.
形如(a+bi)的数叫复数,分为实部和虚部。
i是这么一个东西:i*i+1=0,虚数单位。
复数的加减法:实部虚部分别相加减。
复数的乘法:(a+bi)*(c+di)=(ac-bd)+(ad+bc)i;
除法太难打所以请戳这里。
2.复数的表达形式
感谢一位叫卜卜的热心网友大晚上不看番教我数学。
第一种形式就是代数式:(a+bi),高中数学内容。
第二种形式也许?叫三角式:r(cosθ+isinθ)。
具体来说,将代数式里的a,b放到二维笛卡尔坐标系平面直角坐标系里,横坐标为实部,纵坐标为虚部。把原点和(a,b)相连,记这条向量与X轴的夹角为θ,模长为r,上面那个式子就很好理解了。
那么来看看三角式下的乘法运算?
r1(cosθ1+isinθ1)*r2(cosθ2+isinθ2) = r1r2(cos(θ1+θ2)+isin(θ1+θ2))
没错就是这样。
于是就有显而易见的n次方式:
(r(cosθ+isinθ))^n=r^n(cos(nθ)+sin(nθ))
这在FFT中会用到。
还有一个公式是(cosθ+isinθ)=eiθ。推理过程要用到。
然后是多项式乘法。一个n次的多项式乘上一个m次的多项式,结果是(n+m)次的。
朴素的多项式相乘时间复杂度是O(n^2)的,不够优秀。
而FFT则是利用了单位复根的优秀性质来解决了这么一个问题。
首先我们需要把多项式转化成点值表示法,称为求值。其逆过程称为插值。
这样有一个好处:
两个多项式A,B分别取点(X,Ya)和(X,Yb),A×B就会取到点(X,Ya*Yb);
具体是什么原因?我认为生命需要留下一点遗憾(啧)。
其实很好理解。
T(x)=f(x)*g(x),所以T(3)=f(3)*g(3)。
显而易见。
所以转化成点值表示法后,"相乘"反倒成为最简单的了。
所以多项式相乘的基本步骤:
对A,B求值 » 点值乘法 » 插值。
若能将求值和插值的复杂度降低,就能达到我们的目的了!
FFT的核心思想:
通过恰当选取x的值,并采用分治策略使得求值和插值的复杂度降下来。
首先我们要了解的是n次单位复数根。
记为Wn...Wnn。Wnn = 1 = 1+0*i;
并且有n次单位复数根的个数为n。
算法导论告诉我们,nn个单位复数根均匀的分布在以复平面的原点为圆心的单位半径的圆周上。
憋问我为什么
(Pic from Xlightgod)
记Wn=r(cosθ+isinθ)。
那么我们可以得知:Wnn=r^n(cos(nθ)+i*sin(nθ))=(1+0*i);
r=1,然后你稍稍推一下就知道θ=0。
设nθ=φ+2kπ,则φ=θ/n+2kπ/n;
因为θ=0,所以就是2kπ/n有值。
所以φ=2kπ/n;
sin和cos都是以2π为周期的,所以可以用φ代替θ
所以Wn=cosφ+isinφ=e(i*2kπ/n)。
接下来就可以证明一个重要的定理:
消去定理:Wakbk=(e(i*2kπ/ak))bk=e(i*2kπ*bk/ak)=e(i*2kπ*b/a)=Wab;
然后用这个定理可以证明:
折半定理:(Wnk)2=e(2*2kπ/n)=e(2kπ/(n/2))=Wn/2k;
这样的话,一次平方下来,取值就少一倍。
接下来就是很简单的 分治 了。
《论折半定理在信息学竞赛中的简单应用》 傅立叶
把这个多项式A(x)=Σaixi分治一下,构建新的多项式。
A[0](x)=a0+a2x+a4x2+...+an-2x(n-2)/2;
A[1](x)=a1+a3x+a5x2+...+an-1x(n-1)/2;
A(x)=A[0](x2)+x*A[1](x2);
因为这个是严格分治的,所以最高次项必须要是2的n次方。
(你问我n不是2的幂怎么办?扩大一下,高位系数全为0不就完了
所以说常数大得要死。
所以我们利用快速傅立叶变换求出了离散傅立叶变换(DFT)。
好像是把求值叫DFT,把插值叫IDFT。
然后又有人证明出插值只要将Wn变成Wn-1,再将结果除以n即可。
再做一遍FFT就可以了。
Congratulations! 复杂度已经被我们降到了O((n+m)log(n+m))。
代码实现起来竟然这么短,关键语句只有9行!
//uoj模板题
#include <iostream> #include <cstdio> #include <cstdlib> #include <algorithm> #include <vector> #include <cstring> #include <queue> #include <cmath> #include <complex> #define LL long long int using namespace std; const int N = 262145; const double pi = acos(-1.0); typedef complex<double> dob; int n,m; dob a[N],b[N]; int gi() { int x=0,res=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')res*=-1;ch=getchar();} while(ch<='9'&&ch>='0')x=x*10+ch-48,ch=getchar(); return x*res; } inline void FFT(dob *A,int len,int f) { if(len==1)return; dob wn(cos(2.0*pi/len),sin(f*2.0*pi/len)),w(1,0),t; dob A0[len>>1],A1[len>>1]; for(int i=0;i<(len>>1);++i)A0[i]=A[i<<1],A1[i]=A[i<<1|1]; FFT(A0,len>>1,f);FFT(A1,len>>1,f); for(int i=0;i<(len>>1);++i,w*=wn){ t=w*A1[i]; A[i]=A0[i]+t; A[i+(len>>1)]=A0[i]-t; } } int main() { n=gi();m=gi(); for(int i=0;i<=n;++i)a[i]=gi(); for(int i=0;i<=m;++i)b[i]=gi(); m+=n; for(n=1;n<=m;n<<=1); FFT(a,n,1);FFT(b,n,1); for(int i=0;i<=n;++i)a[i]*=b[i]; FFT(a,n,-1); for(int i=0;i<=m;++i) printf("%d ",int(a[i].real()/n+0.5)); return 0; }
然而我们早就知道递归有着巨大的常数,加上FFT的巨大常数(三角函数计算),导致奇慢无比。
我们来欣赏一下这个美丽的蝴蝶递归。
把最后一行的数化成二进制:
000,100,010,110,001,101,011,111;
然后把每一个数顺序反过来:
000,001,010,011,100,101,110,111;
是个递增的对不对?十分优美对不对?
优美个鬼啊
于是就有人喜(丧)大(心)普(病)奔(狂)推出了三层for人工合并的东西。
#include <iostream> #include <cstdio> #include <cstdlib> #include <algorithm> #include <vector> #include <cstring> #include <queue> #include <cmath> #include <complex> #define LL long long int using namespace std; const int N = 262145; const double pi = acos(-1.0); typedef complex<double> dob; int n,m,L,R[N]; dob a[N],b[N]; inline int gi() { int x=0,res=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')res*=-1;ch=getchar();} while(ch<='9'&&ch>='0')x=x*10+ch-48,ch=getchar(); return x*res; } inline void FFT(dob *A,int f) { for(int i=0;i<n;++i)if(i<R[i])swap(A[i],A[R[i]]); for(int i=1;i<n;i<<=1){ dob wn(cos(pi/i),sin(f*pi/i)),x,y; for(int j=0;j<n;j+=(i<<1)){ dob w(1,0); for(int k=0;k<i;++k,w*=wn){ x=A[j+k];y=w*A[j+i+k]; A[j+k]=x+y; A[j+i+k]=x-y; } } } } int main() { n=gi();m=gi(); for(int i=0;i<=n;++i)a[i]=gi(); for(int i=0;i<=m;++i)b[i]=gi(); m+=n; for(n=1;n<=m;n<<=1)++L; for(int i=0;i<n;++i)R[i]=(R[i>>1]>>1)|((i&1)<<(L-1)); FFT(a,1);FFT(b,1); for(int i=0;i<=n;++i)a[i]*=b[i]; FFT(a,-1); for(int i=0;i<=m;++i)printf("%d ",int(a[i].real()/n+0.5)); return 0; }
第一层i是枚举合并到了哪一层。
第二层j是枚举合并区间。
第三层k是枚举区间内的下标。
j*k=(n+m);i是log级的。
所以说复杂度没变,常数降下来了。
其实常数还能降一点。
1.勿使用系统的复数库,自己手写结构体,只需重载加减乘即可,大概可以压到原来时间的60%。
2.预处理所有要用到的单位复根及其幂(用三角函数式计算),这样还可以保证精度(cogs 释迦,不这么写必挂),大概卷积上界达到1e14就需要预处理了。
至于更多FFT技巧,可以移步myy2016的集训队论文。
我们不得不承认FFT是一个优秀而鬼畜的东西。
因为有三角函数和浮点数的参与,FFT有时候会出现尴尬的爆精度现象。
这种病医生说是救不了的。
有些题目要求答案要对一个质数取模(998244353),我们知道取模是数论内容。
那么有没有什么东西可以替代单位复根呢?
当然有!原根!
设原根为g。
Wnn≡gP-1≡1(mod P);
所以可以把g(P-1)/n看成Wn的等价。
好的NTT学完了。
所以说这种质数必须是NTT质数(费马质数),即(P-1)有超过序列长度的2的正整数幂因子的质数,如998244353,1004535809,469762049等。
不是这种质数怎么办?找几个找乘积大于p^2*n的费马质数做,用中国剩余定理合并就好了。
#include <iostream> #include <cstdio> #include <cstdlib> #include <algorithm> #include <vector> #include <cstring> #include <queue> #include <complex> #include <stack> #define LL long long int #define ls (x << 1) #define rs (x << 1 | 1) #define MID int mid=(l+r)>>1 using namespace std; const int N = 300010; const int Mod = 998244353; int n,m,L,R[N],g[N],a[N],b[N]; int gi() { int x=0,res=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')res*=-1;ch=getchar();} while(ch<='9'&&ch>='0')x=x*10+ch-48,ch=getchar(); return x*res; } inline int QPow(int d,int z) { int ans=1; for(;z;z>>=1,d=1ll*d*d%Mod) if(z&1)ans=1ll*ans*d%Mod; return ans; } inline void NTT(int *A,int f) { for(int i=0;i<n;++i)if(i<R[i])swap(A[i],A[R[i]]); for(int i=1;i<n;i<<=1){ int gn=QPow(3,(Mod-1)/(i<<1)),x,y; for(int j=0;j<n;j+=(i<<1)){ int g=1; for(int k=0;k<i;++k,g=1ll*g*gn%Mod){ x=A[j+k];y=1ll*g*A[i+j+k]%Mod; A[j+k]=(x+y)%Mod;A[i+j+k]=(x-y+Mod)%Mod; } } } if(f==1)return;reverse(A+1,A+n); int y=QPow(n,Mod-2); for(int i=0;i<n;++i)A[i]=1ll*A[i]*y%Mod; } int main() { n=gi();m=gi(); for(int i=0;i<=n;++i)a[i]=gi(); for(int i=0;i<=m;++i)b[i]=gi(); m+=n;for(n=1;n<=m;n<<=1)++L; for(int i=0;i<n;++i)R[i]=(R[i>>1]>>1)|((i&1)<<(L-1)); NTT(a,1);NTT(b,1); for(int i=0;i<n;++i)a[i]=1ll*a[i]*b[i]%Mod; NTT(a,-1); for(int i=0;i<=m;++i)printf("%d ",a[i]); printf("\n"); return 0; }