快速傅立叶变换(FFT)
多项式
系数表示法
设\(f(x)\)为一个\(n-1\)次多项式,则 \(f(x)=\sum\limits_{i=0}^{n-1}a_i*x^i\)
其中\(a_i\)为\(f(x)\)的系数,用这种方法计算两个多项式相乘(逐位相乘)复杂度为\(O(n^2)\)
点值表示法
根据小学知识,一个\(n-1\)次多项式可以唯一地被\(n\)个点确定
即,如果我们知道了对于一个多项式的\(n\)个点\((x_1,y_1),(x_2,y_2)……(x_n,y_n)\)
那么这个多项式唯一满足,对任意\(1\le i \le n\),满足\(y_i=\sum\limits_{j=0}^{n-1}a_j*x_i^j\)
那么用点值实现多项式相乘是什么复杂度呢?
首先我们需要选\(n\)个点,每个点需要求出其在多项式中的值,复杂度为\(O(n^2)\)
然后把两个点值表示的多项式相乘,由于\(c(x_i)=a(x_i)*b(x_i)\),复杂度为\(O(n)\)
最后插值法用点值求出系数,复杂度为\(O(n^2)\)(我还不会插值)
考虑如果可以快速实现点值转系数和系数转点值,岂不是可以快速计算多项式乘法(说的简单,你倒是告诉我怎么快速转化啊)
前置芝士
复数
定义虚数单位\(i=\sqrt{-1}\),\(a,b\)为实数,则形如\(a+bi\)的数叫复数
其中\(a\)为复数的实部,\(b\)为复数的虚部
在复平面中,复数可以被表示为向量,所以和向量具有很多相似的性质,我们也可以用向量来理解复数,但是复数具有更多性质,比如作为一个数代入多项式
其中模长定义为\(\sqrt{a^2+b^2}\),幅度定义为\(x\)轴正半轴到向量转角的有向角
复数运算法则:
加减法与向量相同,重点是乘法:
几何定义为:模长相乘,幅角相加
代数定义为:
单位根
我们首先定义圆心为坐标原点,半径为\(1\)的圆叫做单位圆
我们将这个圆\(n\)等分,得到\(n\)个圆上的点,以这\(n\)个圆上点的横坐标作为实部,纵坐标作为虚部,就得到了\(n\)个复数
网上扒的图片,侵删\(QwQ\)
首先我们不自找麻烦,以\((1,0)\)作为这\(n\)个点的起点,记作\({\omega _n}^0\),逆时针方向第\(k\)个点记作\({\omega _n}^k\)
根据模长相乘幅角相加,我们可以看出\(\omega _n^k\)是\(\omega _n^0\)的\(k\)次方,其中\(\omega _n^1\)被称为\(n\)次单位根
根据幅角,我们可以计算出\(\omega _n^k\)表示的复数为\(cos{\frac{k}{n}2\pi}+i*sin{\frac{k}{n}2\pi}\)
单位根具有一些性质:
\(1.\omega _n^k=cos{\frac{k}{n}2\pi}+i*sin{\frac{k}{n}2\pi}=e^{\frac{2\pi ki}{n}}\)
证明:这个第一步到第二步由定义得出,第二步到第三步由欧拉公式得出
\(2.\omega _{2n}^{2k}=\omega_{n}^{k}\)
证明:\(\omega _{2n}^{2k}=e^{\frac{2\pi 2ki}{2n}}=e^{\frac{2\pi ki}{n}}=\omega_{n}^{k}\)
\(3.\omega _{n}^{k+\frac{n}{2}}=-\omega_{n}^{k}\)
证明:\(\omega _{n}^{k+\frac{n}{2}}=\omega _{2n}^{2k+n}=\omega _{2n}^{2k}*\omega _{2n}^{n}=\omega_{n}^{k}*(cos\pi+i*sin\pi)=-\omega_{n}^{k}\)
\(4.\omega _{n}^{0}=\omega _{n}^{n}=1\)
证明:不用证了吧……
正文之前
这段话有可能有助于您理解本算法:
傅立叶这个大神仙根本就没见过计算机长什么样,所以他提出的傅立叶变换和逆变换只是一种将系数转点值和将点值转系数的方法,没有任何降低复杂度的功效,至于快速傅立叶变换是后人再研究傅立叶变换发现的一种加速方法,是对\(DFT\)和\(IDFT\)的优化
离散傅立叶变换(DFT)
假设\(f(x)=\sum\limits_{i=0}^{n-1}a_i*x_i\)
\(DFT(a)=(f(1),f(\omega _{n}),f(\omega _{n}^{2}),……,f(\omega _{n}^{n-1}))\)
通俗点说,就是对于一个系数表示法的多项式\(f(x)\),将\((1,\omega _{n},\omega _{n}^{2},……,\omega _{n}^{n-1})\)带入求出该多项式的点值表示法
离散傅立叶逆变换(IDFT)
将\(f(x)\)在\(n\)个\(n\)次单位根处的点值表示转化为系数表示
这里就可以回答,为什么我们要让\(n\)次单位根作为\(x\)代入多项式
假设\((y_0,y_1,y_2,……,y_{n-1})\)是多项式\(A(x)=\sum\limits_{i=0}^{n-1}a_i*x_i\)的离散傅立叶变换
我们另有一个多项式\(B(x)=\sum\limits_{i=0}^{n-1}=y_i*x_i\)
将上述\(n\)次单位根的倒数\((1,\omega _{n}^{-1},\omega _{n}^{-2},……,\omega _{n}^{-(n-1)})\)代入\(B(x)\)得到新的离散傅立叶变换\((z_0,z_1,z_2,……,z_{n-1})\)
则我们发现
对于\(\sum\limits_{i=0}^{n-1}(\omega _n^{j-k})^i\)我们单独考虑:
当\(j-k=0\)时
答案为\(n\)
当\(j\ne k\)时
等比数列求和得到\(\frac{(\omega _n^{j-k})^n-1}{\omega _n^{j-k}-1}=\frac{(\omega _n^n)^{j-k}-1}{\omega _n^{j-k}-1}=\frac{1^{j-k}-1}{\omega _n^{j-k}-1}=0\)
所以
即
得出结论:对于以\(A(x)\)的离散傅立叶变换作为系数的多项式\(B(x)\),取单位根的倒数\((1,\omega _{n}^{-1},\omega _{n}^{-2},……,\omega _{n}^{-(n-1)})\)作为\(x\)代入,再将结果除以\(n\)即为\(A(x)\)的系数
这个结论实现了将多项式点值转化为系数
快速傅立叶变换
我们一顿分析最后发现复杂度……仍然是\(O(n^2)\)
我们学这破玩意的意义不就是降低复杂度嘛,所以我们接下来讲怎么降复杂度
我们先设\(A(x)=\sum\limits_{i=0}^{n-1}a_i*x_i\)
我们将\(A(x)\)的下标按奇偶分类,得到
设两个多项式
那么就可以发现\(A(x)=A_1(x^2)+xA_2(x^2)\)
将\(x=\omega _n^{k}(k<\frac{n}{2})\)代入
将\(x=\omega _n^{k+\frac{n}{2}}(k<\frac{n}{2})\)代入
我们一点也不惊喜地发现,只要求出\(A_1(x)\)和\(A_2(x)\)在\((\omega _{\frac{n}{2}}^0,\omega _{\frac{n}{2}}^1,……,\omega _{\frac{n}{2}}^{\frac{n}{2}-1})\)的点值表示,就可以\(O(n)\)地求出\(A(x)\)在\((1,\omega _{n},\omega _{n}^{2},……,\omega _{n}^{n-1})\)
所以我们可以递归实现\(O(nlogn)\)求解多项式乘法了
注意:我们假设\(f*g=h\),那么对于\(f\)和\(g\)都要直接求出大于\(n+m+1\)个的\(2^k\)个点值(由于分治要求,点数一定是\(2\)的整次幂)
\(code\)
#include<bits/stdc++.h>
using namespace std;
namespace red{
#define mid ((l+r)>>1)
#define eps (1e-8)
inline int read()
{
int x=0;char ch,f=1;
for(ch=getchar();(ch<'0'||ch>'9')&&ch!='-';ch=getchar());
if(ch=='-') f=0,ch=getchar();
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
return f?x:-x;
}
const int N=5e6+10;
const double pi=acos(-1.0);
int n,m;
struct complex
{
double x,y;
complex(double tx=0,double ty=0){x=tx,y=ty;}
inline complex operator + (const complex t) const
{
return complex(x+t.x,y+t.y);
}
inline complex operator - (const complex t) const
{
return complex(x-t.x,y-t.y);
}
inline complex operator * (const complex t) const
{
return complex(x*t.x-y*t.y,x*t.y+y*t.x);
}
}a[N],b[N];
inline void fft(int limit,complex *a,int inv)
{
if(limit==1) return;
complex a1[limit>>1],a2[limit>>1];
for(int i=0;i<limit;i+=2)
{
a1[i>>1]=a[i],a2[i>>1]=a[i+1];
}
fft(limit>>1,a1,inv);
fft(limit>>1,a2,inv);
complex Wn=complex(cos(2.0*pi/limit),inv*sin(2.0*pi/limit)),w=complex(1,0);
for(int i=0;i<(limit>>1);++i,w=w*Wn)
{
a[i]=a1[i]+w*a2[i];
a[i+(limit>>1)]=a1[i]-w*a2[i];
}
}
inline void main()
{
n=read(),m=read();
for(int i=0;i<=n;++i) a[i].x=read();
for(int i=0;i<=m;++i) b[i].x=read();
int limit=1;
while(limit<=n+m) limit<<=1;
fft(limit,a,1);
fft(limit,b,1);
for(int i=0;i<=limit;++i)
{
a[i]=a[i]*b[i];
}
fft(limit,a,-1);
for(int i=0;i<=n+m;++i) printf("%d ",(int)(a[i].x/limit+0.5));
}
}
signed main()
{
red::main();
return 0;
}
然而我们发现好像有点慢
迭代优化
众所周知递归比较慢,我们有没有什么方法可以用迭代代替递归呢?
扒图时间到
通过一顿找规律,我们根本不能发现,每个数字在分治后的位置就是它所在位置的二进制翻转
这个规律也有一个好听的名字,叫蝴蝶定理
那么我们只要预处理出每个数字在最后一次递归中的位置,然后自底向上合并,岂不是可以摆脱递归
#include<bits/stdc++.h>
using namespace std;
namespace red{
#define eps (1e-8)
inline int read()
{
int x=0;char ch,f=1;
for(ch=getchar();(ch<'0'||ch>'9')&&ch!='-';ch=getchar());
if(ch=='-') f=0,ch=getchar();
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
return f?x:-x;
}
const int N=5e6+10;
const double pi=acos(-1.0);
int n,m;
int limit=1,len;
int pos[N];
struct complex
{
double x,y;
complex(double tx=0,double ty=0){x=tx,y=ty;}
inline complex operator + (const complex t) const
{
return complex(x+t.x,y+t.y);
}
inline complex operator - (const complex t) const
{
return complex(x-t.x,y-t.y);
}
inline complex operator * (const complex t) const
{
return complex(x*t.x-y*t.y,x*t.y+y*t.x);
}
}a[N],b[N],buf[N];
inline void fft(complex *a,int inv)
{
for(int i=0;i<limit;++i)
if(i<pos[i]) swap(a[i],a[pos[i]]);
for(int mid=1;mid<limit;mid<<=1)
{
complex Wn(cos(pi/mid),inv*sin(pi/mid));
for(int r=mid<<1,j=0;j<limit;j+=r)
{
complex w(1,0);
for(int k=0;k<mid;++k,w=w*Wn)
{
buf[j+k]=a[j+k]+w*a[j+k+mid];
buf[j+k+mid]=a[j+k]-w*a[j+k+mid];
}
}
for(int i=0;i<limit;++i) a[i]=buf[i];
}
}
inline void main()
{
n=read(),m=read();
for(int i=0;i<=n;++i) a[i].x=read();
for(int i=0;i<=m;++i) b[i].x=read();
while(limit<=n+m) limit<<=1,++len;
for(int i=0;i<limit;++i)
pos[i]=(pos[i>>1]>>1)|((i&1)<<(len-1));
fft(a,1);
fft(b,1);
for(int i=0;i<=limit;++i) a[i]=a[i]*b[i];
fft(a,-1);
for(int i=0;i<=n+m;++i) printf("%d ",(int)(a[i].x/limit+0.5));
}
}
signed main()
{
red::main();
return 0;
}
蝴蝶操作
考虑这里
for(int r=mid<<1,j=0;j<limit;j+=r)
{
complex w(1,0);
for(int k=0;k<mid;++k,w=w*Wn)
{
buf[j+k]=a[j+k]+w*a[j+k+mid];
buf[j+k+mid]=a[j+k]-w*a[j+k+mid];
}
}
for(int i=0;i<limit;++i) a[i]=buf[i];
之所以加\(buf\)数组是因为两次赋值\(a\)的值会变化,我们可以提前存储\(a\)数组的值,然后优化掉\(buf\)数组
for(int k=0;k<mid;++k,w=w*Wn)
{
complex x=a[j+k],y=w*a[j+k+mid];
a[j+k]=x+y;
a[j+k+mid]=x-y;
}
三次变两次优化
观察到上面的代码我们跑了三次肥肥兔,现在我们有一种方法可以少跑一次
假设我们求\(f(x)*g(x)\)
设复多项式\(h(x)=f(x)+i*g(x)\),实部为\(f(x)\),虚部为\(g(x)\)
那么\(h(x)^2=(f(x)+i*g(x))^2=f(x)^2-g(x)^2+i*2*f(x)*g(x)\)
我们只要把\(h(x)^2\)的虚部除以\(2\)就得到了结果
完全版:
#include<bits/stdc++.h>
using namespace std;
namespace red{
#define eps (1e-8)
inline int read()
{
int x=0;char ch,f=1;
for(ch=getchar();(ch<'0'||ch>'9')&&ch!='-';ch=getchar());
if(ch=='-') f=0,ch=getchar();
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
return f?x:-x;
}
const int N=5e6+10;
const double pi=acos(-1.0);
int n,m;
int limit=1,len;
int pos[N];
struct complex
{
double x,y;
complex(double tx=0,double ty=0){x=tx,y=ty;}
inline complex operator + (const complex t) const
{
return complex(x+t.x,y+t.y);
}
inline complex operator - (const complex t) const
{
return complex(x-t.x,y-t.y);
}
inline complex operator * (const complex t) const
{
return complex(x*t.x-y*t.y,x*t.y+y*t.x);
}
}a[N];
inline void fft(complex *a,int inv)
{
for(int i=0;i<limit;++i)
if(i<pos[i]) swap(a[i],a[pos[i]]);
for(int mid=1;mid<limit;mid<<=1)
{
complex Wn(cos(pi/mid),inv*sin(pi/mid));
for(int r=mid<<1,j=0;j<limit;j+=r)
{
complex w(1,0);
for(int k=0;k<mid;++k,w=w*Wn)
{
complex x=a[j+k],y=w*a[j+k+mid];
a[j+k]=x+y;
a[j+k+mid]=x-y;
}
}
}
}
inline void main()
{
n=read(),m=read();
for(int i=0;i<=n;++i) a[i].x=read();
for(int i=0;i<=m;++i) a[i].y=read();
while(limit<=n+m) limit<<=1,++len;
for(int i=0;i<limit;++i)
pos[i]=(pos[i>>1]>>1)|((i&1)<<(len-1));
fft(a,1);
for(int i=0;i<=limit;++i) a[i]=a[i]*a[i];
fft(a,-1);
for(int i=0;i<=n+m;++i) printf("%d ",(int)(a[i].y/limit/2+0.5));
}
}
signed main()
{
red::main();
return 0;
}
注意三次变两次优化会令精度误差平方,请根据题目值域考虑是否使用
参考博客
%%%attack
%%%rabbithu