畜中牲都不一定能理解的 FFT
前言
看了一上午的 FFT 竟然学会了。于是写下这篇来纪念。
期间涉及复平面的相关知识,我这个畜中牲竟然懂了,真是神奇,请不要望而却步,勇于面对,死磕一下总是好的。
FFT 中文名 快速傅里叶变换
OIer 经常拿它来解决高精度乘法的问题。朴素高精乘是 \(O(n^2)\) 的,只能解决两个 \(10^{10000}\) 级别的数字相乘,而 FFT 是 \(O(n\log n)\) 的,可以解决两个 \(10^{1000000}\) 级别的数乘起来的问题。
以下保证 \(\log_2n \in \N^*\)
What is FFT ?
快速傅里叶变换(FFT)是一种能在 \(O(n\log n)\) 的时间内将一个多项式从它的系数表示法转换成它的点值表示的算法。
什么是多项式的系数表示法和点值表示法
一个 \(n - 1\) 次的多项式 \(A(x) = a_0 + a_1 x^1 + a_2 x^2 + ... + a_{n - 1} x^{n - 1}\) ,这个就是系数表示法。
而一个 \(n - 1\) 次的多项式,代入 \(n\) 个不同的 \(x\) 值,将会得到 \(n\) 个对应的 \(y\) 值。那么如果我们知道这些数对 \((x_k, y_k)\) ,就可以计算得这个多项式中的每个 \(a\) 。所以这些数对就是点值表示法。
朴素傅里叶变换(DFT,FFT 的理论基础)
强大的傅里叶前辈发明了一种方法,将 \(x\) 取 \(x^n = 1\) 的 \(n\) 个复数解。
牛逼的 C++ 给了复数模板:
头文件 #include <complex>
定义 complex<double> x;
令 \(\omega_n^k = e^{i \frac{2k \pi}{n}} = \cos(\frac{2k \pi}{n}) + i \cdot \sin(\frac{2k \pi}{n})\) (即 \(x^n = 1\) 的 \(n\) 个复数解),并称 \(\omega\) 为单位根。
于是可以得到单位根的几个性质:
- \(\omega_{2n}^{2k} = \omega_{n}^{k}\) 显然,因为表示的都是一个数(
实际上你可以带进去计算一下); - \(\omega_n^{k + \frac{n}{2}} = -\omega_n^k\) 因为关于原点对称。
- \((\omega_n^k)^m = \omega_n^{mk}\) ,显然符合复数相乘的幅角相加法则。
为什么要选择这些来代入式子呢?
这就牵扯到 DFT 的优美性质了。
我们将 \((\omega_n^0, \omega_n^1, \omega_n^2, ..., \omega_n^{n - 1})\) 代入 \(A(x) = a_0 + a_1 x^1 + a_2 x^2 + ... + a_{n - 1} x^{n - 1}\) 。令 \(y_k = A(\omega_n^k)\),那么就会有 \(n\) 个 \(y\) 值 \((y_0, y_1, y_2, ..., y_{n - 1})\),再新来一个多项式 \(B(x) = y_0 + y_1 x^1 + y_2 x^2 + ... +y_{n - 1} x^{n - 1}\),将 \((\omega_n^0, \omega_n^{-1}, \omega_n^{-2}..., \omega_n^{-(n - 1)})\) (单位根的倒数)代入 \(B(x)\) ,令 \(z_k = B(\omega_n^{-k})\) 。
\(\sum_{i = 0}^{n - 1} (\omega_n^{j - k})^i\) 是可以求的。
- 当 \(j = k\) 时,原式 \(= n\);
- 当 \(j \neq k\) 时,原式 \(= \frac{(\omega_n^{j - k})^n - 1}{\omega_n^{j - k} - 1} = \frac{(\omega_n^n)^{j - k} - 1}{\omega_n^{j - k} - 1} = \frac{1 - 1}{\omega_n^{j - k} - 1} = 0\) (等比序列求和)
\(\therefore z_k = n a_k\)
\(\therefore a_k = \frac{z_k}{n}\)
结论
将 \((\omega_n^0, \omega_n^1, \omega_n^2, ..., \omega_n^{n - 1})\) 代入 \(A(x) = a_0 + a_1 x^1 + a_2 x^2 + ... + a_{n - 1} x^{n - 1}\) 得到 \((y_0, y_1, y_2, ..., y_{n - 1})\),将 \((\omega_n^0, \omega_n^{-1}, \omega_n^{-2}..., \omega_n^{-(n - 1)})\) 代入 \(B(x) = y_0 + y_1 x^1 + y_2 x^2 + ... +y_{n - 1} x^{n - 1}\) 得到 \((z_0, z_1, z_2, ..., z_{n - 1})\),那么对于任何一个 \(a_k\) ,都有 \(a_k = \frac{z_k}{n}\) 。
快速傅里叶变换(FFT)
虽然我们搞出了伟大的 DFT 的结论,但如果暴力代入还是 \(O(n^2)\) ,不能接受。
于是 FFT 油然而生。
数学证明
设\(A(x) = a_0 + a_1 x^1 + a_2 x^2 + ... + a_{n - 1} x^{n - 1}\)
先将 \(A(x)\) 的每一项按奇偶进行划分:
设两个多项式:
\(\therefore\)
假设 \(k < \frac{n}{2}\) ,将 \(x = \omega_n^k\) 代入。
那么对于 \(A(\omega_n^{k + \frac{n}{2}})\) :
如果我们知道 \(A_1(x)\) 和 \(A_2(x)\) 在 \((\omega_{\frac{n}{2}}^0, \omega_{\frac{n}{2}}^1, \omega_{\frac{n}{2}}^2, ..., \omega_{\frac{n}{2}}^{{\frac{n}{2}} - 1})\) 的点值表示,就可以 \(O(n)\) 求出 \(A(x)\) 在 \((\omega_n^0, \omega_n^1, \omega_n^2, ..., \omega_n^{n - 1})\) 的点值表示了。而 \(A_1(x)\) 和 \(A_2(x)\) 都是规模缩小一半的子问题,总时间复杂度 \(O(n\log n)\) 。分治条件是 \(n = 1\) 此时什么都不干,直接 return 。
#include <bits/stdc++.h>
using cpd = std::complex<double>;
#define rep(i, a, b) for(int i = (a); i <= (b); ++i)
#define il inline
const int N = 1e6 + 10;
const double pi = acos(-1.0);
il cpd omega(const int n, const int k) {
return cpd(cos(2.0 * k * pi / n), sin(2.0 * k * pi / n));
}
il void FFT(cpd *a, int n, bool inv) {//inv 的作用是标记是否要取倒数
if (n == 1) return;
static cpd tmp[N];
int m = n / 2;
rep(i, 0, m - 1)//按照奇偶分为两组
{
tmp[i] = a[2 * i];
tmp[i + m] = a[2 * i + 1];
}
rep(i, 0, n - 1) a[i] = tmp[i];
FFT(a, m, inv);//递归处理两个子问题
FFT(a + m, m, inv);
rep(i, 0, m - 1)
{
cpd x = omega(n, i);
if(inv) x = conj(x);
tmp[i] = a[i] + x * a[i + m];
tmp[i + m] = a[i] - x * a[i + m];
}
rep(i, 0, n - 1) a[i] = tmp[i];
}
cpd a[N];
int main() {
int n;
std::cin >> n;
rep(i, 0, n - 1)
{
double t;
std::cin >> t;
a[i].real(t);
}
FFT(a, n, true);
rep(i, 0, n - 1)
std::cout << a[i] << '\n';
}
递归真的超级慢,也超级难写。于是我们学一些优化。
FFT 优化
from 递归 to 非递归
在进行 FFT 时,我们要把各个系数不断分组并放到两侧,那么一个系数原来的位置和最终的位置有什么规律呢?
| 初始位置 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
|---|---|---|---|---|---|---|---|---|
| 第一轮后 | 0 | 2 | 4 | 6 | 1 | 3 | 5 | 7 |
| 第二轮后 | 0 | 4 | 2 | 6 | 1 | 5 | 3 | 7 |
| 第三轮后 | 0 | 4 | 2 | 6 | 1 | 5 | 3 | 7 |
可以发现,这你都能发现,原本位置为 a 的数,最后所在的位置是 a 二进制翻转得到的数,如 \(6 = (011)_2\) 最后的位置就是 \(3 = (110)_2\) ,\(1 = (001)_2\) 到了 \(5 = (100)_2\) 。
那么我们可以据此写出非递归版本 FFT :先把每个数放到最后的位置上,然后不断向上还原,同时求出点值表示。
#include <bits/stdc++.h>
using cpd = std::complex<double>;
#define rep(i, a, b) for(int i = (a); i <= (b); ++i)
#define il inline
const int N = 1e6 + 10;
const double pi = acos(-1.0);
int n;
cpd a[N], b[N], omg[N], inv[N];
il void init() {
rep(i, 0, n - 1)
{
omg[i] = cpd(cos(2 * i * pi / n), sin(2 * i * pi / n));
inv[i] = conj(omg[i]);
}
}
il void FFT(cpd *a, cpd *omg) {
int lim = 0;
while((1 << lim) < n) lim++;
rep(i, 0, n - 1)
{
int t = 0;
rep(j, 0, lim - 1)
if((i >> j) & 1)
t |= (1 << (lim - j - 1));
if(i < t) swap(a[i], a[t]); // i < t 的限制使得每对点只被交换一次(否则交换两次相当于没交换)
}
static cpd buf[N];
for(int l = 2; l <= n; l *= 2)
{
int m = l / 2;
for(int j = 0; j < n; j += l)
rep(i, 0, m - 1)
{
buf[j + i] = a[j + i] + omg[n / l * i] * a[j + i + m];
buf[j + i + m] = a[j + i] - omg[n / l * i] * a[j + i + m];
}
rep(j, 0, n - 1) a[j] = buf[j];
}
}
int main() {
std::cin >> n;
init();
rep(i, 0, n - 1)
{
double t;
std::cin >> t;
a[i].real(t);
}
FFT(a, inv);
rep(i, 0, n - 1)
{
std::cout << a[i] << '\n';
}
}
蝴蝶优化
别跑呀,实际上很简单。
buf[j + i] = a[j + i] + omg[n / l * i] * a[j + i + m];
buf[j + i + m] = a[j + i] - omg[n / l * i] * a[j + i + m];
我们发现 buf 数组的作用实际上是为了让这两行不互相影响,但是如果写成:
cpd t = omg[n / l * i] * a[j + i + m]
a[j + i + m] = a[j + i] - t
a[j + i] = a[j + i] + t
就可以抛弃 buf 数组了。
最终版本
#include <bits/stdc++.h>
using cpd = std::complex<double>;
#define rep(i, a, b) for(int i = (a); i <= (b); ++i)
#define il inline
const int N = 1e6 + 10;
const double pi = acos(-1.0);
int n;
cpd a[N], b[N], omg[N], inv[N];
il void init() {
rep(i, 0, n - 1)
{
omg[i] = cpd(cos(2 * i * pi / n), sin(2 * i * pi / n));
inv[i] = conj(omg[i]);
}
}
void FFT(cpd *a, cpd *omg) {
int lim = 0;
while((1 << lim) < n) lim++;
for(int i = 0; i < n; i++)
{
int t = 0;
for(int j = 0; j < lim; j++)
if((i >> j) & 1)
t |= (1 << (lim - j - 1));
if(i < t) swap(a[i], a[t]); // i < t 的限制使得每对点只被交换一次(否则交换两次相当于没交换)
}
for(int l = 2; l <= n; l *= 2)
{
int m = l / 2;
for(cpd *p = a; p != a + n; p += l)
for(int i = 0; i < m; i++)
{
cpd t = omg[n / l * i] * p[i + m];
p[i + m] = p[i] - t;
p[i] += t;
}
}
}
int main() {
std::cin >> n;
init();
rep(i, 0, n - 1)
{
double t;
std::cin >> t;
a[i].real(t);
}
FFT(a, inv);
rep(i, 0, n - 1)
{
std::cout << a[i] << '\n';
}
}

浙公网安备 33010602011771号