知识点:FFT详解
前言
FFT其实在很早的时候就已经接触到了,但是那个时候学起来有点仙,感觉这东西离实际解题的距离有点远,不如那些其他的数据结构那么直接。但是半年多下来的做题,发现FFT其实应用的十分广泛,并且很多数学题推出公式之后就可以套用FFT进行计算。所以对于FFT的理解也不能仅仅只是停留于背板子的阶段了,而应该更加深入的去理解它。
前置知识
复数的运算及性质,多项式
知识点讲解
概要
首先讲什么是FFT,FFT的全称为快速傅里叶变换(Fast Fourier Transformation),是基于离散傅里叶变换的快速求法。最基础的运用就是解决多项式相乘的问题,可以将朴素算法的\(O(n^2)\)优化成\(O(nlogn)\),是一种比较高效的方法。
多项式相乘的朴素算法
我们假设我们现在有两个多项式:
那么我们令这两个的多项式相乘为\(h(x)\),即:
那么我们可以得到:
所以我们就可以得到一个复杂度为\(O(n^2)\)的解法了,即枚举两个多项式的每一位的系数,那么第一个多项式第\(i\)次项的系数与第二个多项式第\(j\)次项的系数相乘得到的即为答案的第\(i+j\)项的系数。
系数表示法与点值表示法
首先,我们需要知道,多项式的一个表示方法为系数表示法,即一个n次的多项式可以表示为\(\sum_{i=0}^na_ix^i\)。这里的每一个\(a_i\)表示每一项的系数。然后我们发现,对于一个多项式,我们只会关心这个多项式的系数,而并不需要真正的记录下它的指数,因为在系数表示法中,\(a_i\)的指数就是\(i\)。所以一个多项式可以表示成这个样子:
然后我们就可以引出点值表示法了。点值表示法顾名思义就是用几个点来表示一个多项式。两点确定一条直线,三点可以确定一条抛物线,所以一个\(n\)次多项式就需要\(n+1\)个点来确定。而我们刚刚记录下的系数也刚刚好是\(n+1\)个,所以我们的多项式可以进一步的表示成这个样子:
然后多项式的乘积就可以这样表示:
这样子我们就把一个多项式转化成了一些离散的点,这样的过程就叫做离散傅里叶变换(DFT)。
然后将一些离散的点重新转化成一个多项式,这个过程就叫做离散傅里叶反变换(IDFT)。
这样子我们对于FFT就会有一个大致的思路了:先将原来的两个多项式进行DFT,转化成点值之后再进行相乘,最后做IDFT重新变成答案的多项式。这样在相乘过程中的复杂度就可以被优化成\(O(n)\)了。但是由于我们在相乘时用的是点值表达式,而解题的时候一般不会给你点值表达式,所以我们还需要将系数表示转化为点值表示,于是下面开始进入DFT与IDFT。
复数的引入
我们在一般的计算当中都不会对\(\sqrt{-1}\)进行定义,然而在复数中,\(\sqrt{-1}\)等于一个神奇的数:\(i\),这个数在复数的定义下相当于\(1\)的作用。下面列举一些有关\(i\)的计算:
复数分为实部和虚部,所以一个复数可以表示为\(a+bi\),这里的\(a\)表示的是这个复数的实部,\(bi\)表示的是这个复数的虚部。同时,一个复数也是可以用坐标系来表示的,只不过这个坐标系\(y\)轴的单位不再是1,而是\(i\),这样我们就可以引入极角与复数的乘法了。
一个复数一共有三种表示方法(在坐标系中):\(a+bi\),\((a,b)\),\((r,\theta)\)(这里的r表示到圆心的距离,\(\theta\)表示极角)。于是复数的乘法也可以定义了:
通过这个,我们对于另一种表示法也可以定义出它的乘法了:
对于一个复数\(a+bi\),\(r=\sqrt{a^2+b^2},\theta=arctan(\frac{b}{a})\)
复数\(a+bi\)与\(c+di\)相乘,\(r_1=\sqrt{a^2+b^2}\),\(r_2=\sqrt{c^2+d^2}\),\(\theta _1=arctan(\frac{b}{a})\),\(\theta _2=arctan(\frac{d}{c})\)
那么乘积为\((ac-bd)+(ad+bc)i\)
\(R=\sqrt{(ac-bd)^2+(ad+bc)^2}=\sqrt{a^2c^2+b^2d^2-2abcd+b^2c^2+a^2d^2+2abcd}=\sqrt{(a^2+b^2)(b^2+d^2)}=r1\times r2\)
\(\Theta=arctan(\frac{bc+ad}{ac-bd})=arctan(\frac{\frac{bc+ad}{ac}}{\frac{ac-bd}{ac}})=arctan(\frac{\frac{b}{a}+\frac{d}{c}}{1-\frac{bd}{ac}})=arctan(\frac{\frac{b}{a}+\frac{d}{c}}{1-\frac{b}{a}\cdot \frac{d}{c}})=arctan(\frac{b}{a})+arctan(\frac{d}{c})=\theta_1+\theta_2\)
所以\((r1,\theta1)\times(r2,\theta_2)=(r1\times r2,\theta_1+\theta_2)\),总结下来就是两个复数相乘,长度相乘,极角相加。由于c++自带的\(complex\)函数有点菜菜的,所以我们可以自己手写\(complex\)函数。然后我们可以想象一下,如果两个复数到原点的距离都为1的话,那么这两个复数相乘就相当于绕着以原点为圆心的,半径为1的圆进行旋转。
struct comp {
double r,i;
comp() {}
comp(double r,double i):r(r),i(i) {}
};
comp operator + (comp a,comp b) {
return comp(a.r+b.r,a.i+b.i);
}
comp operator - (comp a,comp b) {
return comp(a.r-b.r,a.i-b.i);
}
comp operator * (comp a,comp b) {
return comp(a.r*b.r-a.i*b.i,a.i*b.r+a.r*b.i);
}
单位复根
我们现在已经知道了答案多项式的点值表达式,我们现在所需要做的任务就是把这个点值表达式转变成系数表达式。最明显的一种方法就是用高斯消元法,暴力去解除这个方程的解,可是这样的复杂度又会变成\(O(n^2)\)了,因为我们在计算\(x_0,x_0^2,x_0^3\cdots x_0^n\)的时候会重复计算很多次,而这个计算在实数域中似乎是不可避免的,所以我们就可以回到复数域中。在复数中,我们需要的数是\(\omega^k=1\)的数,而由于复数的乘法的定义,我们可以发现,对于所有的\((1,\theta)\)的复数,他的\(i\)次方总是在同一个圆上,并且经过一定次数的乘方之后一定是可以等于1的,所以我们就可以有效的利用这种数,将这个数带入,就可以很妙妙的求值了。
我们定义这种\(\omega^k=1\)的数为\(k\)次单位复根,计作\(\omega_k^n\),而这个\(n\)实际上是一个序号,表示的是将所有的\(k\)次单位根按极角序进行排序之后从零开始编号的单位根。
而由于\(\omega_k^0=1\),所以我们只要知道\(\omega_k^1\),就可以计算出所有的\(k\)次单位根了。
有关定理的证明
有关单位复根有很多比较有用的性质,这里对这些性质进行介绍并证明。
基本性质
既然单位复根相乘满足“长度相乘,极角相加”,而单位复根的长度又是1,所以我们可以发现所有的\(k\)次单位根是均匀的排布在一个半径为1,以原点为圆心的圆上的。根据欧拉公式\(e^{\pi i}=-1\),而\(\omega_n^n=1\),所以\(\omega_n^n=(-1)^2=(e^{\pi i})^2=e^{2\pi i}\),所以:
这样我们就可以快速的算出我们需要的\(k\)次单位根了。
消去引理
先给出引理:
这个其实很好理解的,就居\(\omega_8^2\)来说吧,根据引理,我们可以得出\(\omega_4^1\),这就相当于我们原本把一个圆划分成了8份,然后取的是其中的2份,而这又与将圆划分成4份,取其中的1份是等价的,由此可以发现这个消去引理是正确的。
由此我们可以推出\(\omega_{2n}^n=\omega_2^1=-1\)。
折半引理
引理:
这个定理看起来是比较的玄,首先前半部分是毋庸置疑的,因为\(\omega_n^k=-\omega_n^{k+\frac{1}{2}n}\),所以这两个的平方也是相同的。然后考虑后半部分,我们还是以\(n=8\)为例,然后画出圆被分成八份的图像:
然后我们可以想象一下,如果我们将任意一个单位根进行平方,那么相当于他的极角扩大了一倍,而扩大后的单位根一定落在4次单位根上,并且单位根的编号与原来单位根的编号是相同的(这里只考虑\(k\leq \frac{1}{2}n\))。
求和引理
引理:
这个定理看起来很高妙,仔细看会发现这是一个等比数列求和,于是我们套用等比数列求和公式,那么对于第一个公式:
\(S=\frac{1-(\omega_n^k)^n}{1-\omega_n^k}=\frac{1-\omega_n^{nk}}{1-\omega_n^k}=\frac{1-1^k}{1-\omega_n^k}=0\)
对于第二个公式,就更加简单了:
\(S=\sum_{j=0}^{n-1}(\omega_n^{dn})^j=\sum_{j=0}^{n-1}1^j=n\)
这样这两个公式就推完了。
DFT
首先我们先回到原本的系数表示法:
将我们的单位复根带入之后,就会得到:
然后我们得到的\(\lbrace y_0,y_1,\cdots \rbrace\)就是\(\lbrace a_0,a_1,\cdots \rbrace\)的点值了。
现在知道了求DFT的方法了,我们就要开始想能够让DFT的复杂度在\(O(nlogn)\)的做法了。
这里用到的是分治的思想,分治计算所有出\(k\)次单位根所对应的点值,而由于我们一共需要\(n+1\)个点值,所以对于一个\(n\)次的多项式,我们需要计算的就是它的\(n\)次单位根对应的点值。分治的思想主要体现在将多项式分为奇数项与偶数项,那么一个多项式可以表示为:
令\(G(x)=(a_0+a_2x+a_4x^2+a_6x^3)\),\(H(x)=(a_1+a_3x+a_5x^2+a_7x^3)\)
那么\(f(x)=G(x^2) +x\times H(x^2)\)
当\(x=\omega^k\)时,\(DFT(f(x))_k=DFT(G(x^2) +\omega^k \times H(x^2))\)
而由于折半引理,我么可以得出:
\(DFT(f(x+\frac{n}{2}))_k=DFT(G(x^2) + (-\omega^k)\times H(x^2))\)
所以我们发现,对于一个多项式,我们只要求出了前一半部分的值,我们就可以推出后一半的值,这样我们利用了分治在DFT上做到了复杂度为\(O(nlogn)\)了。但是这样也是有一个问题的,就是这样子只能处理出多项式系数的个数为2的整数次幂的多项式,所以我们在做DFT的时候,还需要将不满2的整数次幂的多项式在前面补零,这样就能做到不影响结果并且正确计算出答案。
DFT的优化
在这个分治上,我们采用的是递归的形式,然而由于每一次递归都要下传数组,导致这样的常数非常大,空间也有了一些无意义上的消耗,所以我们希望的是能够采用迭代的形式。然后我们开始思考,这个分治的实际结果是什么。我们每次将数组中的数按奇偶分组,实际上就是对于二进制下的这一位进行1与0的分组。并且我们可以观察一下最后分组的结果,以长度为8为例:
或许这样还看不出来什么,但是我们将每个元素的下表用二进制表示之后,就可以发现一些有趣的事了:
我们可以发现,每一位的在分治之后的序号正好与分治之前的序号在二进制表示上是翻转过来的。于是我们就可以考虑预处理出每一位在分治之后的序号,这样就不用进行递归的过程,而是用迭代来代替了。对于预处理这个序号,我们可以类似DP一样的进行:
void GetRev() {
for(int i=0;i<lim;i++) {
rev[i]=(rev[i>>1]>>1)|((i&1)<<(len-1));
}
}
这样就得出了我们在翻转之后,每一位的序号了,接下来就是简单的迭代了。
IDFT
处理到这里,我们的FFT就只剩下最后的一步了,就是用IDFT把点值表达式转化成系数表达式。首先我们先考虑我们是如何从系数表达式变成点值表达式的。我们是将单位复根带入原来的系数表达式来计算点值的,我们可以将单位复根带入之后的n个多项式的系数用一个矩阵来表示:
现在我们想要让这所有的多项式的系数还原成用\(a\)表示的多项式,我们就必须在这个矩阵之后乘上一个这个矩阵的逆矩阵。首先我们要先了解逆矩阵的定义。一个矩阵的逆矩阵\(V\)定义:
其中\(I\)为单位矩阵,单位矩阵的样子是这样的:
也就是说这个单位矩阵的主对角线都是1,其余全是0。
在这里我先给出这个矩阵的逆矩阵长什么样:
然后我们开始证明这个矩阵是原矩阵的逆矩阵:
首先我们可以知道原矩阵位于\((j,k)\)位置的值为\(\omega_n^{jk}\),而位于逆矩阵中的\((j,k)\)处的值为\(\omega_n^{-jk}\),于是我们可以计算出\([V\cdot V^{-1}]\)的第\((j_1,j_2)\)位置上的值为:
然后根据求和引理,我们发现只有当\(j_1-j_2=0\)时,\([V\cdot V^{-1}]_{j_1j_2}=1\),其他时候\([V\cdot V^{-1}]_{j_1j_2}=0\),这样就可以证明上图的矩阵就是我们要乘的逆矩阵了。
我们将矩阵带入公式,就可以得到\(a_j=\frac{1}{n}\sum_{k=0}^{n-1}y_k\omega_n^{-jk}\)(这里的\(y_k\)是DFT中的\(y_k\))。然后我们就会惊奇的发现IDFT与DFT在公式上实际上是差不多的:
所以我么完全可以用同一个函数传不同的参数进去进行计算:
void FFT(comp *a,int IDFT) {//IDFT传进来时为-1,DFT传进来时为1
for(int i=0;i<lim;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
for(int mid=1;mid<lim;mid<<=1) {
comp w=comp(cos(PI/mid),IDFT*sin(PI/mid));//IDFT是单位复根应该取负,而DFT时单位复根应该取正的。
for(int l=mid<<1,j=0;j<lim;j+=l) {
comp wn=comp(1.0,0.0);
for(int k=0;k<mid;k++) {
comp x=a[k+j];
comp y=a[k+j+mid]*wn;
a[k+j]=x+y;
a[k+j+mid]=x-y;
wn=wn*w;
}
}
}//这里没有除以n是因为在最后除了
}
到这里FFT就已经讲的差不多了,如果仍然没有听明白的话,那么可以移步这里。
AC代码(luogu3803)
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
bool Finish_read;
template<class T>inline void read(T &x){Finish_read=0;x=0;int f=1;char ch=getchar();while(!isdigit(ch)){if(ch=='-')f=-1;if(ch==EOF)return;ch=getchar();}while(isdigit(ch))x=x*10+ch-'0',ch=getchar();x*=f;Finish_read=1;}
template<class T>inline void print(T x){if(x/10!=0)print(x/10);putchar(x%10+'0');}
template<class T>inline void writeln(T x){if(x<0)putchar('-');x=abs(x);print(x);putchar('\n');}
template<class T>inline void write(T x){if(x<0)putchar('-');x=abs(x);print(x);}
/*================Header Template==============*/
const int maxn=5e6+500;
const double PI=acos(-1);
int n,m;
int rev[maxn];
int lim=1,len;
/*==================Define Area================*/
struct comp {
double r,i;
comp() {}
comp(double r,double i):r(r),i(i) {}
}a[maxn],b[maxn];
comp operator + (comp a,comp b) {
return comp(a.r+b.r,a.i+b.i);
}
comp operator - (comp a,comp b) {
return comp(a.r-b.r,a.i-b.i);
}
comp operator * (comp a,comp b) {
return comp(a.r*b.r-a.i*b.i,a.i*b.r+a.r*b.i);
}
void GetRev() {
for(int i=0;i<lim;i++) {
rev[i]=(rev[i>>1]>>1)|((i&1)<<(len-1));
}
}
void FFT(comp *a,int IDFT) {
for(int i=0;i<lim;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
for(int mid=1;mid<lim;mid<<=1) {
comp w=comp(cos(PI/mid),IDFT*sin(PI/mid));
for(int l=mid<<1,j=0;j<lim;j+=l) {
comp wn=comp(1.0,0.0);
for(int k=0;k<mid;k++) {
comp x=a[k+j];
comp y=a[k+j+mid]*wn;
a[k+j]=x+y;
a[k+j+mid]=x-y;
wn=wn*w;
}
}
}
}
int main() {
read(n);read(m);
while(lim<=n+m) lim<<=1,len++;
GetRev();
for(int i=0;i<=n;i++) scanf("%lf",&a[i].r);
for(int i=0;i<=m;i++) scanf("%lf",&b[i].r);
FFT(a,1);
FFT(b,1);
for(int i=0;i<lim;i++) {
a[i]=a[i]*b[i];
}
FFT(a,-1);
for(int i=0;i<=m+n;i++) {
printf("%d ",(int)(a[i].r/lim+0.5));
}
return 0;
}