FFT笔记
FFT笔记
前言:
这个算法对于我来讲比较抽象、高深,因为里面涉及了一些复数等一些对我而言很难很难的知识。
终于,花了几节文化课的时间冥思苦想,终于算是搞懂一点了。所以我决定趁脑子清醒的时候记录下来。
与其他文章不同的是,本文可能没有太多的公式证明,主要是以通俗易懂的方式去讲解,也是为了方便大家(包括以后可能忘记的我)更好地理解。建议:在电脑前放好演算纸与笔。
注:本篇文章是我这个小蒟弱写的,真正的dalao请看个玩笑便好,不必争论对错(但是欢迎指出文章存在的小错误)。
FFT 有什么用
快速傅里叶变换(FFT)可以在\(O(n\log n)\)的时间解决两个多项式的相乘问题。
为什么要用 FFT
我们先假设一个要计算多项式乘多项式的情景:
设
\(A(x)=a_0+a_1x+a_2x^2+a_3x^3+\dots\)
\(B(x)=b_0+b_1x+b_2x^2+b_3x^3+\dots\)
假设已经得到结果多项式为\(C(x)=c_0+c_1x+\dots\)
根据我们在数学中的做法可以得到:
- \(\;x\)的指数是由\(A(x),B(x)\)中任意两项的指数相加得到的
- \(\;c_n\)的组成是由能得到\(x^n\)的项的系数相乘得到的
如果不能理解,建议推一下下面这个式子:
那么多项式的乘法用计算机实现,第一个想到的就是\(O(n^2)\)的朴素算法(两层循环模拟)。针对于这种普通的多项式乘法,我们把它称作系数表示法。因为一旦能确定所有系数,多项式也随之确定。
如果我们把多项式和函数图像融合在一起,会发现一个多项式\(A(x)=a_0+a_1x+a_2x^2+a_3x^3+\dots+a_nx^n\)可以用\(y=a_0+a_1x+a_2x^2+a_3x^3+\dots+a_nx^n\)的函数图像所表示。
那根据确定函数的性质,如果我们可以知道\(n+1\)个\(y\)值,那么我们可以唯一地确定这个多项式。像这种方法,我们称之为点值表示法。
有趣的是,点值表示法时间复杂度是\(O(n)\)的!实现:因为\(C(x)=A(x)\times B(x)\),所以\(O(n)\)暴力枚举\(x_i\)即可。
相信此时的你蠢蠢欲动:如果我们将系数表示法转换为点值表示法,那不就可以大大降低我们多项式相乘的时间复杂度了吗?!
……遗憾的是,很难转换。无论是朴素算法\(O(n^2)\)转换,或是——可能你还不知道的拉格朗日插值法转换,时间复杂度没有降下来,仍然是\(O(n^2)\)。
那现在思路很清晰,我们要想将多项式相乘速度提快,就必须将系数表示法转换为点值表示法这一大关打下来!(也要完成逆转换)
在世界第一台计算机(1946年)横空出世的139年前(1807年),一位伟大的数学家信誓旦旦——
傅里叶:这个我会!
DFT(离散FFT/朴素FFT)
这个办法就是:用\(n\)个模长为\(1\)的复数代替点值表示法中\(n\)个\(x\)(下文默认\(n\)为\(2\)的幂)。
你可能觉得复数特别厉害,高中知识,溜了溜了。但是我们理不理解复数并没有直接关系。
复数:
“如果你学过了,可以跳过。
如果你不会复数,可以当做是向量。
如果你不会向量,可以看成是平面直角坐标系上的一个点。
如果你不会平面直角坐标系,可以看成是c++中的pair容器。
如果你还是什么都不会,…………(出门右拐先学习一下平面直角坐标系吧……”复数有一个实部还有一个虚部,类似于一个向量(或点)的横纵坐标。例如复数\(3+2i\),\(3\)是实部,\(2\)是虚部,\(i=\sqrt {-1}\)。可以想象成向量\((3,2)\)或点\((3,2)\)。
复数的运算规则是:模长相乘,幅角相加。如果你只是想学FFT,记住幅角相加就好了。
那为什么要用复数呢?有趣在于——它是一种数,可以带入多项式\(A(x)\)中去,显然你不能把一个向量(或点)带入一个多项式。
更有趣的是,c++提供了复数的模板!
头文件:#include<complex>
定义:complex<double> a[1000],b,c;
运算:直接加减乘除
上面说了要找\(n\)个模长为\(1\)的复数,可是这可不是乱找的。傅里叶精心地挑选了\(n\)个点,而这\(n\)个点,实在平面直角坐标系中,将一个单位圆平均分成\(n\)等分,这\(n\)个点的横坐标为实部、纵坐标为虚部,便可以构成\(n\)个虚数。
详见图:
从\((1,0)\)开始,然后逆时针从\(0\)开始编号,第\(k\)个点记作虚数\(\omega_n^k\)。还记得模长相乘,幅角相加吗?所以\(\omega_n^k\)是\(\omega_n^1\)的\(k\)次方,因此\(\omega_n^1\)就是单位根。
根据每个复数的幅角,可以确定这个向量或点的位置。\(\omega_n^k\)对应的向量或点坐标为\((\cos{\frac{k}{n}2\pi},\sin{\frac{k}{n}2\pi})\),那么这个复数也就为\(\cos{\frac{k}{n}2\pi}+i\times \sin{\frac{k}{n}2\pi}\)。
那么把\(n\)个\(\omega_n^0,\omega_n^1,\dots,\omega_n^{n-1}\)带入多项式,就可以得到特殊的点值表示法。傅里叶开心地将这种点值表示称为离散傅里叶变换!
为什么要使用单位根代入
这肯定是有讲究的,因为这里有一些有趣的性质。
相关证明就不写了,直接上结论:
将多项式\(A(x)\)的傅里叶变换结果作为多项式\(B(x)\)的系数,再将单位根的倒数作为\(x\)代入\(B(x)\),得到的每个数再\(\div n\),得到的正是\(A(x)\)的各项系数!
总而言之,这也顺带完成了将点值表示法逆转换成系数表示法的任务,这也是离散傅里叶变换神奇的特殊性质。
FFT
傅里叶发明了如此高深的DFT,完成了我们的主要任务——完成转换与逆转换。但是DFT仍是朴素的\(O(n^2)\)……
傅里叶:我都没见过计算机,为什么要优化时间复杂度……
但是,运用信息的思维,可以立马想到分治!因此快速傅里叶变换由此应运而生!
显然也是可以证明的,但是因为太懒、掌握不牢固,所以决定待以后再证。
分治可以用dfs实现,边界条件为\(n=1\)时,直接return。
递归实现FFT
code:
#define cp complex<double>
cp omega(int n, int k)
{
return cp(cos(2 * PI * k / n), sin(2 * PI * k / n));
}
void fft(cp *a, int n, bool inv)
{
if(n == 1) return;
static cp buf[N];
int m = n / 2;
for(int i = 0; i < m; i++) //将每一项按照奇偶分为两组
{
buf[i] = a[2 * i];
buf[i + m] = a[2 * i + 1];
}
for(int i = 0; i < n; i++)
a[i] = buf[i];
fft(a, m, inv); //递归处理两个子问题
fft(a + m, m, inv);
for(int i = 0; i < m; i++) //枚举x,计算A(x)
{
cp x = omega(n, i);
if(inv) x = conj(x);
//conj是一个自带的求共轭复数的函数,精度较高。当复数模为1时,共轭复数等于倒数
buf[i] = a[i] + x * a[i + m]; //根据之前推出的结论计算
buf[i + m] = a[i] - x * a[i + m];
}
for(int i = 0; i < n; i++)
a[i] = buf[i];
}
inv表示单位根\(\omega_n^1\)取不取倒数。
然鹅这个只是1.0版本,让我们学一些优化吧!
优化实现FFT
迭代版本非递归FFT
在递归时我们不断分组分成两侧,观察一下有没有什么规律:
这些“\(|\)”表示分割线。
明显地(或者说非常不明显地),将它们转换成二进制后有镜面效果:
这里的镜面反转指的是同一个位置进行了变化。
那么我们可以把每位数放在最后一个位置上,然后不断向上还原,同时求出点值表示。
代码部分建议跟着代码手推一下,因为全都是位运算,有点难理解。
镜面操作单独code1:
for(int i = 0; i < n; i++)
{
int t = 0;
for(int j = 0; j < lim; j++)
if((i >> j) & 1) t |= (1 << (lim - j - 1));
if(i < t) swap(a[i], a[t]);
// i < t 的限制使得每对点只被交换一次(否则交换两次相当于没交换)
}
镜面操作单独code2(这个代码可以不用到lim变量):
for(int i=0,j=0;i<n;++i)
{
if(i>j) swap(a[i],a[j]);
for(int l=n>>1;(j^=l)<l;l>>=1);
}
融合进FFT的镜面操作code(调用的是操作1):
cp a[N], b[N], omg[N], inv[N];
int lim;
void init()
{
while((1 << lim) < n) lim++;
for(int i = 0; i < n; i++)
{
omg[i] = cp(cos(2 * PI * i / n), sin(2 * PI * i / n));
inv[i] = conj(omg[i]);
// conj这一行等价于 inv[i]=cp(cos(2*PI*i/n),-sin(2*PI*i/n));
}
}
void fft(cp *a, cp *omg)
{
for(int i = 0; i < n; i++) // 这一个for就是用位运算进行二进制的镜面操作
{
int t = 0;
for(int j = 0; j < lim; j++)
if((i >> j) & 1) t |= (1 << (lim - j - 1));
if(i < t) swap(a[i], a[t]);
// i < t 的限制使得每对点只被交换一次(否则交换两次相当于没交换)
}
static cp buf[N];
for(int l = 2; l <= n; l *= 2)
{
int m = l / 2;
for(int j = 0; j < n; j += l)
for(int i = 0; i < m; i++)
{
buf[j + i] = a[j + i] + omg[n / l * i] * a[j + i + m];
buf[j + i + m] = a[j + i] - omg[n / l * i] * a[j + i + m];
// 为什么是omg[n / l * i]?因为这是在用递推模拟递归,只有这样写才符合上面递归代码中n的变化
}
for(int j = 0; j < n; j++)
a[j] = buf[j];
}
}
如果我们预先处理好\(\omega_n^k\)与\(\omega_n^{-k}\)并存入omg,inv数组里,那我们就可以根据需求在fft函数中传入不同的数组。
蝴蝶操作
这个优化听起来贼nb,但我们可以从本质上思考一下:buf[]有何用?
是为了做这一件事:
a[j + i] = a[j + i] + omg[n / l * i] * a[j + i + m];
a[j + i + m] = a[j + i] - omg[n / l * i] * a[j + i + m];
同时不能使这两行互相干涉,所以我们才需要buf[]。
但是如果我们用一个临时变量代替:
cp t = omg[n / l * i] * a[j + i + m];
a[j + i + m] = a[j + i] - t;
a[j + i] = a[j + i] + t;
就没有任何问题啦!!
全样长这样:
cp a[N], b[N], omg[N], inv[N];
int lim;
void init()
{
while((1 << lim) < n) lim++;
for(int i = 0; i < n; i++)
{
omg[i] = cp(cos(2 * PI * i / n), sin(2 * PI * i / n));
inv[i] = conj(omg[i]);
// conj这一行等价于 inv[i]=cp(cos(2*PI*i/n),-sin(2*PI*i/n));
}
}
void fft(cp *a, cp *omg)
{
for(int i = 0; i < n; i++)
{
int t = 0;
for(int j = 0; j < lim; j++)
if((i >> j) & 1) t |= (1 << (lim - j - 1));
if(i < t) swap(a[i], a[t]);
// i < t 的限制使得每对点只被交换一次(否则交换两次相当于没交换)
}
for(int l = 2; l <= n; l *= 2)
{
int m = l / 2;
for(cp *p = a; p != a + n; p += l)
for(int i = 0; i < m; i++)
{
cp t = omg[n / l * i] * p[i + m];
p[i + m] = p[i] - t;
p[i] += t;
}
}
}
现在,这个终极无敌优化版FFT就比原递归FFT快得多了。
后记&版题
终于!笔记写完啦!
本篇文章借鉴了一些内容,是源于这位dalao的这篇博客,非常感谢!!
下面我也贴一个板子吧(^^)!
code:
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define cp complex<double>
#define rp(i,o,p) for(ll i=o;i<=p;++i)
#define pr(i,o,p) for(ll i=o;i>=p;--i)
const ll MAXN=1e6+5;
const double PI=acos(-1.0);
ll n=1;
char s1[MAXN],s2[MAXN];
ll la,lb,ans[MAXN];
cp a[MAXN],b[MAXN],omg[MAXN],inv[MAXN];
void init()
{
for(ll i=0;i<n;++i)
{
omg[i]=cp(cos(2*PI*i/n),sin(2*PI*i/n));
inv[i]=conj(omg[i]);
// <=> inv[i]=cp(cos(2*PI*i/n),-sin(2*PI*i/n));
}
}
void fft(cp *a,cp *omg)
{
ll lim=0;
while((1<<lim)<n) ++lim;
for(ll i=0;i<n;++i)
{
ll t=0;
for(ll j=0;j<lim;++j)
if(((i>>j)&1))
t|=(1<<(lim-j-1));
if(i<t)
swap(a[i],a[t]);
}
for(ll l=2;l<=n;l<<=1)
{
ll m=l>>1;
for(cp *p=a;p!=a+n;p+=l)
{
for(ll i=0;i<m;++i)
{
cp t=omg[n/l*i]*p[i+m];
p[i+m]=p[i]-t;
p[i]+=t;
}
}
}
}
int main()
{
scanf("%s%s",s1,s2);
la=strlen(s1),lb=strlen(s2);
while(n<la+lb) n<<=1;
for(ll i=0;i<la;++i)
a[i].real(s1[la-i-1]-'0');
for(ll i=0;i<lb;++i)
b[i].real(s2[lb-i-1]-'0');
init();
fft(a,omg);
fft(b,omg);
for(ll i=0;i<n;++i)
a[i]*=b[i];
fft(a,inv);
for(ll i=0;i<n;++i)
{
ans[i]+=(ll)(a[i].real()/n+0.5);
ans[i+1]+=ans[i]/10;
ans[i]%=10;
}
ll i;
for(i=la+lb-1;ans[i]==0&&i>=0;--i)
if(i==0)
putchar('0'),i=-1;
while(i>=0)
putchar('0'+ans[i--]);
putchar('\n');
return 0;
}