快速傅立叶变换
快速傅立叶变换
1 引入
现在有两个多项式 \(f(x)\),\(g(x)\):
我们要求出两者相乘的结果,按照多项式相乘的运算法则,把每一项相乘,总复杂度为 \(O(n^2)\)。
这是传统算法能到达的最好复杂度。能不能进行优化呢?
使用快速傅立叶变换,我们可以实现 \(O(n\log n)\) 的复杂度。
1.1 概念
快速傅里叶变换(FFT),即运用计算机计算离散傅立叶变换 (DFT)的快速算法的统称。于 1965 年提出。
2 多项式表示
多项式的常见表示方法有两种:系数表示和点值表示。
系数表示就是最常见的表示方式,例如上面的 \(f(x)\),可以表示为 \((5,3,7)\)。
而点值表示法,则是在平面上取函数上的点。有一个定理:
\(n\) 次多项式可以有 \(n+1\) 个点唯一确定。
于是上面的 \(f(x)\) 也可表示为 \(\{(0,7),(1,15),(-1,9)\}\)。
而利用点值表示求多项式乘积很方便,取相同的 \(x\),然后将多项式的值相乘,得到新的点值,重复 \(n+m\) 次可以得到最终多项式的点值表示。
我们发现,点值表示是 \(O(n)\) 的,非常优秀。如果我们能在较短时间内将系数表示转化为点值表示,求出结果在转化回系数表示,就能快速求出多项式乘法。
3 单位根
关于复数:虚数与复数与欧拉公式 - 知乎 (zhihu.com)
以复平面上单位圆为起点,单位圆的 \(n\) 等分点为终点,可以唯一得到 \(n\) 个向量,也就是 \(n\) 个复数。设幅角为正数且最小的复数为 \(\omega_n\),称其为 \(n\) 次单位根。则有:
又由欧拉公式得:
特别的,\(\omega_n^0=\omega_n^n=1\)。
4 快速傅立叶变换
4.1 分治法实现
FFT 的基本想法是分治。对于 DFT 来说,它分治的求解 \(x=\omega_n^k\) 时 \(f(x)\) 的值。
对于多项式 \(f(x)=\sum\limits_{i=0}^{n-1}a_ix^i\),同时不妨设 \(n=2^k\)(缺失的部分用 \(a_i=0\) 补齐)。我们将其按 \(a_i\) 下标的奇偶性分开,也就是:
令
则有:
令 \(x=\omega_n^k\),利用偶次单位根性质 \(\omega_n^i=-\omega_n^{i+\frac n2}\),以及 \(f_1(x)\) 与 \(f_2(x)\) 都为偶函数,可以知道当 \(x=\omega_n^k\) 与 \(x=\omega_n^{k+\frac n2}\) 对应值相同,于是有:
同时也可以得到:
因此,在求出 \(f_1(\omega_{\frac n2}^k)\) 和 \(f_2(\omega_{\frac n2}^k)\),可以同时求出 \(f(\omega_n^k)\) 和 \(f(\omega_n^{k+\frac n2})\)。
然后,我们继续对 \(f_1\) 和 \(f_2\) 递归求解即可。这是 FFT 的一种实现:分治 DFT。
我们发现,递归求解时,必须满足 \(n=2^k\) 才能实现,这正是我们早在开头就提到的。
在带入的时候,我们要带入 \(n\) 个不同的值,所以直接带入 \(\omega_n^0,\omega_n^1,\omega_n^2,\cdots,\omega_n^{n-1}(n=2^k)\) 总共 \(2^k\) 个不同值。
代码实现上, STL 给出了复数模板 <complex>
。直接食用即可。
以上就是 FFT 中对于 DFT 的介绍,完成了我们在 2 结尾处所提到的第一步:系数表示转化为点值表示。
代码:
typedef complex<double> comp;
const comp i(0, 1);
const int Maxn = 1 << 20;
const double pi = acos(-1);
comp tmp[Maxn];
void DFT(comp *f, int n) {
if(n == 1) return ;
for(int i = 0; i < n; i++) {
tmp[i] = f[i];
}
for(int i = 0; i < n; i++) {
if(i & 1) {//偶数放左边,奇数放右边
f[n / 2 + i / 2] = tmp[i];
}
else {
f[i / 2] = tmp[i];
}
}
comp *g = f, *h = f + n / 2;//递归求解
DFT(g, n / 2);
DFT(h, n / 2);//分治
//↑:分治
//↓:合并分治
comp cur(1, 0), step(sin(2 * pi / n), sin(2 * pi / n));
//当前单位根为 cur,step 为两个单位根的差。
for(int k = 0; k < n / 2; k++) {
tmp[k] = g[k] + cur * h[k];
tmp[k + n / 2] = g[k] - cur * h[k];//推出的两个公式
cur *= step;
}
for(int i = 0; i < n; i++) {
f[i] = tmp[i];
}
}
时间复杂度 \(O(n\log n)\)。
4.2 倍增法实现
我们从上面的算法继续优化。我们使用递归会耗费更多额外内存。我们在数组中模拟递归中的拆分,然后倍增进行合并。
对于拆分,使用位逆序置换;
对于合并,使用蝶形运算优化。
4.2.1 位逆序置换
以八项多项式为例,拆分过程为(括号中数字为下标):
- 初始序列 \(\{0,1,2,3,4,5,6,7\}\);
- 第一次后 \(\{0,2,4,6\}\{1,3,5,7\}\);
- 第二次后 \(\{0,4\}\{2,6\}\{1,5\}\{3,7\}\);
- 第三次后 \(\{0\}\{4\}\{2\}\{6\}\{1\}\{5\}\{3\}\{7\}\)。
同时我们有规律:对于原先序列,将下标用二进制表示,反转后就是最终位置的下标。例如 \(1\) 是 \((001)_2\),反转后是 \((100)_2\),也就是 \(4\)。而最终 \(1\) 确实在 \(4\) 上。
这个非常巧妙的东西,就叫做位逆序置换,其实名字就是算法本身了。
(证明先咕)
位逆序置换有很简单的 \(O(n \log n)\) 方式,但是有一种更好的 \(O(n)\) 做法。
设 \(R(x)\) 为我们二进制反转后的数。首先有 \(R(0)=0\)。再设 \(len=2^k\),其中 \(k\) 为二进制数的长度。
从小到大求解 \(R(x)\)。这样保证了求 \(R(x)\) 时,\(R(\left\lfloor\dfrac{x}{2}\right\rfloor)\) 已经求出。因此,我们把 \(x\) 右移一位(除以 \(2\)),然后反转,再右移一位,就得到了 \(x\) 除二进制个位之外,其它位的反转结果。
考虑个位的反转结果:如果个位是 \(0\),那么最高位就是 \(0\);如果个位是 \(1\),那么最高位是 \(1\),此时还要加上 \(\dfrac{len}2=2^{k-1}\)。综上有:
举个例子:设 \(k=5\),\(len=32=(100000)_2\)。此时翻转 \((11001)_2\)。
- 右移一位,即 \((1100)_2\),补齐后是 \((01100)_2\),反转后\((00110)_2\),再右移一位得到 \((0011)_2\)。
- 由于个位为 \(1\),所以加上 \(2^{k-1}=(10000)_2\),即 \((10011)_2\).
代码:
int rev[Maxn];//R(x)
void change(comp y[], int len) {
for(int i = 0; i < len; i++) {//求翻转后数组
rev[i] = rev[i >> 1] >> 1;
if(i & 1) {
rev[i] += (len >> 1);
}
}
for(int i = 0; i < len; i++) {//下标交换
if(i < rev[i]) {//保证只交换一次
swap(y[i], y[rev[i]]);
}
}
}
4.2.2 蝶形运算优化
已知 \(f_1(\omega_{\frac n2}^k)\) 和 \(f_2(\omega_{\frac n2}^k)\) 后,用下面两式求出 \(f(\omega_n^k)\) 和 \(f(\omega_n^{k+\frac n2})\) :
使用位逆序置换后,对于给定 \(n,k\):
- \(f_1(\omega_{\frac n2}^k)\) 存储在数组下标为 \(k\) 的位置,\(f_2(\omega_{\frac n2}^k)\) 存储在数组下标为 \(k+\dfrac n2\) 的位置。
- \(f(\omega_n^k)\) 存储在数组下标为 \(k\) 的位置,\(f(\omega_n^{k+\frac n2})\) 存储在数组下标为 \(k+\dfrac n2\) 的位置。
因此可以直接再两个数组下标处进行覆盖,不避开额外数组。这就是蝶形运算。
最后,详细讲解一下倍增法求 FFT 的具体实现方式:
- 令当前段长为 \(s=\dfrac n2\)。
- 同时枚举序列 \(\{f_1(\omega_{\frac n2}^k)\}\) 的左端点 \(l_1=0,2s,4s,\cdots,n-2s\) 和序列 \(\{f_2(\omega_{\frac n2}^k)\}\) 的左端点 \(l_2=s,3s,5s,\cdots,n-s\)。
- 合并时,枚举 \(k=0,1,2,\cdots,s-1\),此时 \(f_1(\omega_{\frac n2}^k)\) 存储在数组下标为 \(l_1+k\) 的位置,\(f_2(\omega_{\frac n2}^k)\) 存储在数组下标为 \(l_2+k\) 的位置。
- 使用蝶形运算算出 \(f(\omega_n^k)\) 和 \(f(\omega_n^{k+\frac n2})\),然后直接覆盖。
代码在文末有。
5 快速傅里叶逆变换
我们在上文已经知道了求解 DFT 的 FFT,成功将系数表示转化为了点值表示。
接下来考虑求解 IDFT,即把点值表示转化为系数表示。
我们从线性代数的角度理解 IDFT 的过程。
首先,DFT 是一个线性变换,理解为目标多项式为向量,左乘一个矩阵得到变换后的向量,模拟带入单位根的过程:
现在我们已经得到左边结果了,中间的 \(x\) 值在目标多项式的点值中也一一对应。所以,根据矩阵的基础知识,我们只要在狮子两边同时左乘中间的逆矩阵即可。
而这个矩阵的逆矩阵也很特殊,就是每一项取倒数,在除以变换长度 \(n\),就能得到逆矩阵。
为了使计算结果为原来的倒数,根据欧拉公式,有:
因此我们可以把单位根 \(\omega_k\) 取成 \(e^{-\frac{2\pi i}{k}}\),这样计算结果就为原先倒数,之后唯一多的操作就只有再除以长度 \(n\),其他操作过程与 DFT 完全一致。我们可以定义一个函数同时完成 DFT 和 IDFT,用一个参数判断即可。
例如前面给到的递归版 FFT,就可以写成:
typedef complex<double> comp;
const comp i(0, 1);
const int Maxn = 1 << 20;
const double pi = acos(-1);
comp tmp[Maxn];
void DFT(comp *f, int n, int rev) { //rev=1:DFT rev=-1:IDFT
if(n == 1) return ;
for(int i = 0; i < n; i++) {
tmp[i] = f[i];
}
for(int i = 0; i < n; i++) {
if(i & 1) {//偶数放左边,奇数放右边
f[n / 2 + i / 2] = tmp[i];
}
else {
f[i / 2] = tmp[i];
}
}
comp *g = f, *h = f + n / 2;//递归求解
DFT(g, n / 2, rev);
DFT(h, n / 2, rev);//分治
//↑:分治
//↓:合并分治
comp cur(1, 0), step(sin(2 * pi / n), rev * sin(2 * pi / n));
//当前单位根为 cur,step 为两个单位根的差。
for(int k = 0; k < n / 2; k++) {
tmp[k] = g[k] + cur * h[k];
tmp[k + n / 2] = g[k] - cur * h[k];//推出的两个公式
cur *= step;
}
for(int i = 0; i < n; i++) {
f[i] = tmp[i];
}
}
同时,我们也可以得出非递归版的 FFT 代码:
void fft(comp y[], int len, int rev) {//rev=1:DFT rev=-1:IDFT
change(y, len);
for(int h = 2; h <= len; h <<= 1) {//枚举长度 len
comp step(cos(2 * pi / h), rev * sin(2 * pi / h));
for(int j = 0; j < len; j += h) {//枚举 0,s,2s,3s,...,n-2s(n-n)
comp w(1, 0);
for(int k = j, k < j + h / 2; k++) {//直接枚举 l_1+k
comp u = y[k];
comp t = w * y[k + h / 2];//l_1+k+s=l_2+k
y[k] = u + t;
y[k + h / 2] = u - t;//覆写
w *= step;
}
}
}
if(rev == -1) {//如果是 IDFT,除以长度 len
for(int i = 0; i < len; i++) {
y[i].x /= len;
}
}
}
6 模板代码
以 P3803 【模板】多项式乘法(FFT) 为例,给出代码:
#include <bits/stdc++.h>
using namespace std;
typedef complex<double> comp;
const comp i(0, 1);
const int Maxn = 6e6 + 5;
const double pi = acos(-1);
int rev[Maxn];//R(x)
int n, m, q = 1;
comp a[Maxn], b[Maxn];
void change(comp y[], int len) {
for(int i = 0; i < len; i++) {//求翻转后数组
rev[i] = rev[i >> 1] >> 1;
if(i & 1) {
rev[i] += (len >> 1);
}
}
for(int i = 0; i < len; i++) {//下标交换
if(i < rev[i]) {//保证只交换一次
swap(y[i], y[rev[i]]);
}
}
}
void fft(comp y[], int len, int rev) {//rev=1:DFT rev=-1:IDFT
change(y, len);
for(int h = 2; h <= len; h <<= 1) {//枚举长度 len
comp step(cos(2 * pi / h), rev * sin(2 * pi / h));
for(int j = 0; j < len; j += h) {//枚举 0,s,2s,3s,...,n-2s(n-n)
comp w(1, 0);
for(int k = j; k < j + h / 2; k++) {//直接枚举 l_1+k
comp u = y[k];
comp t = w * y[k + h / 2];//l_1+k+s=l_2+k
y[k] = u + t;
y[k + h / 2] = u - t;//覆写
w *= step;
}
}
}
}
int main() {
ios::sync_with_stdio(0);
cin >> n >> m;
for(int i = 0; i <= n; i++) {
double x;
cin >> x;
a[i].real(x);
}
for(int i = 0; i <= m; i++) {
double x;
cin >> x;
b[i].real(x);
}
q = 1;
while(q <= n + m) q <<= 1;
fft(a, q, 1);
fft(b, q, 1);
for(int i = 0; i <= q; i++) {
a[i] *= b[i];
}
fft(a, q, -1);
for(int i = 0; i <= n + m; i++) {
cout << (int)(a[i].real() / q + 0.5) << " ";
}
return 0;
}