FFT 学习笔记(自认为详细)
引入
什么是 \(\text{FFT}\) ?
反正我看到 \(\text{wiki}\) 上是一堆奇怪的东西。
快速傅里叶变换(英语:Fast Fourier Transform, FFT),是快速计算序列的离散傅里叶变换(DFT)或其逆变换的方法。傅里叶分析将信号从原始域(通常是时间或空间)转换到频域的表示或者逆过来转换。FFT会通过把DFT矩阵分解为稀疏(大多为零)因子之积来快速计算此类变换。—— \(\text{wikipedia}\)
反正我没脑子我看不懂。
对我来说,\(\text{FFT}\) 就是能把多项式乘法从 \(O(n^2)\) 变成 \(O(n\log n)\) 的神仙玩意。
正文
系数表示法和点值表示法
对于系数表示法,就是用多项式的系数来表示这个多项式。
比如说:
那么对于点值表示法,相对应的就是用该函数上的若干个点表示多项式。
学过小学数学的同学们一定知道:\(n+1\) 个点确定一个 \(n\) 次多项式。
证明的话可以考虑数学归纳法。/xyx
同样举一个例子,点值表示法是这样的:
上面讲到要把系数表示法转换成点值表示法。那么这是为什么呢?
下面就先来展示一下点值表示法的多项式乘法:
复数
\(复数 = 实数 + 虚数\)
实在一点吧,直接上干货,我们定义 :
这样我们就可以表示我们在实数范围内不能表示的数了。
那么如何表示一个复数呢:
接着我们把 \(Num=a+bi\) 看成一个函数,把 \(a\) 和 \(b\) 分别对应 \(x\) 轴和 \(y\) 轴。
就可以得到复数平面,大概长这样:
其中横坐标是实数轴,纵坐标是虚数轴,这样就可以把每个虚数看为一个向量了。
对应的,虚数可以用普通坐标和极坐标表示:
下面给出两个复数相乘的意义:
\(\tt DFT\) (离散傅里叶变换)
现在已经介绍完了点值表示法和复数的相关知识,接下来就是干货部分了。
上面我们已经通过这样的例子说明了点值表示法算多项式乘法的方便。
接下来我们来看怎么先把多项式从系数表示法转换为点值表示法,这种过程叫 \(\text{DFT}\) 。
所谓的点值表示法,也就是在 \(n\) 多项式上取 \(n+1\) 个点,来进行表示。
形式化的,可以表示成这样:
然后可以惊喜的发现,随便带几个 \(x_i\) 进去在算算 \(F(x_i)\) 就好了。
但是如果你小学毕业了,你就可以发现这样的话不如直接 \(O(n^2)\) 暴力。
所以该怎么办?
我们猜想是否存在一些 \(x\) 使得 \(x^n\ (n\in \tt Z^+)\) 的结果都是 \(1\) 。
这看上去是一个非常好的思路,但是这样的数有多少个呢?
我能脱口说出两个 \(1\) 和 \(-1\) ,想一想可以发现其实 \(i\) 和 \(-i\) 也都可以。
但是经过认真思考(看题解)可以发现下图的单位圆上所有的点都满足条件。
为了方便,我们在取这 \(n\) 个点时会把这个单位圆平分。
我们从 \((1, 0)\) 这个点开始,按照逆时针的方向从 \(0\) 开始进行编号,形如 \(\omega_n^k\) 。
其中 \(n\) 表示一共选择了 \(n\) 个点,\(k\) 表示当前点的编号。
由我们之前介绍的复数乘法的 模长相乘,度数相加:
并且结合单位圆的性质(所有的点到原点的距离为 \(1\))。
可以得到由 \(\omega_n^1\) 转换到 \(\omega_n^k\) 的公式:
我们称 \(\omega_n^1\) 为 \(n\) 次单位根。
所以可以发现,我们直接带入 \(\omega_n^i\) 就可以了。
单位根的一些有用的性质
在了解一切的性质之前,我们要先知道单位根 \(\omega_n^i\) 如何表示:
这东西的证明你直接照着单位圆上画一个点然后三角函数入门知识即可。
性质一
证明的话直接照着上面给出的式子套即可,然后发现可以约分。
那我认为进一步的可以得到:
很显然不过好像没有什么大用。
性质二
证明的话稍微写一下吧:
都化成这一步了就不在进行下一步证明,还看不懂的建议重修初中数学。
性质三
比较憨,我就不讲为什么了。
\(\tt FFT\) (快速傅里叶变换)
他来了,他来了,等到现在他终于来了。。。。
之前讲到我们直接带入 \(\omega_n^i\) 来计算点值。
是的,我认为这种方法高效,巧妙,逼格高,体现了人类智慧。
但是等等,虽然算系数的过程免掉了,但是对于每一个 \(\omega_n^i\) 我们还是要 \(O(n)\) 算结果啊。
然后我搬来搬手指算了一下,发现一共有 \(n\) 个 \(\omega_n^i\) 的值,然后就又 \(O(n^2)\) 了。
所以我们该怎么办?
认真地看看题解,发现可以从分治的角度入手。
注意:以下的内容保证 \(n\) 为 \(2\) 的整数次方。
我们设一个多项式:
然后想办法把 \(F(x)\) 分成两个部分。
这里采用的方法是按照 \(F(x)\) 下标的奇偶性分成两个部分。
接下来我们发现拆出来的这两个多项式的结构是一模一样的。
我们再分别设这两个多项式为 \(F_1(x)\) 和 \(F_2(x)\) 。
发现这样的系数不连续,没有那么完美,于是我们再变化一下。
此时看可以发现这样的形式非常的优美。
接下来就是直接带入 \(\omega_n^i\) 的操作了。
我们接着设 \(k<\frac{n}{2}\) 然后把 \(\omega_n^k\) 直接带入。
第一步直接带入,有问题的话小学建议重修。
第二步的话我之前写过,公式是这样的:
当然,在这里运用是具有普遍性的,有问题的话直接推一下。
至于第三步,直接算比例我认为会更加快速一些。
对于 \(F(\omega_n^{k+\frac{n}{2}})\) 直接带入:
每一步一一介绍比较麻烦,大家直接手头一下或者翻翻前面的公式。
观察第一个式子和第二个式子,发现唯一不一样的地方就是符号了。
然后直接分治求解即可,时间复杂度 \(O(n\log n)\) 。
\(\tt IFF\) (快速傅里叶逆变换)
就是把点值表示法转换成为我们要的系数表示法。
这里给出结论,证明的话属实比较恶心,所以我就不证明了。
一个多项式在分治的过程中乘上单位根的共轭复数,分治完的每一项除以 \(n\) 即为原多项式的每一项系数
也就是再做一遍 \(\tt FFT\) 输出时每一位除以 \(n\) 就可以了。
代码实现及其优化
Code 复数类型封装
struct cp {
double x, y;
cp (double xx = 0, double yy = 0) {x = xx; y = yy;};
friend cp operator +(cp p, cp q) {return cp(p.x + q.x, p.y + q.y);}
friend cp operator -(cp p, cp q) {return cp(p.x - q.x, p.y - q.y);}
friend cp operator *(cp p, cp q) {return cp(p.x * q.x - p.y * q.y, p.y * q.x + p.x * q.y);}
}a[N], b[N];
Code 无优化
不是我写的代码,反正就是照着之前的公式模拟,看看就好了。
点击查看代码
#include<complex>
#define cp complex<double>
void fft(cp *a, int n, int inv) //inv是取共轭复数的符号
{
if (n == 1)return;
int mid = n / 2;
static cp b[MAXN];
fo(i, 0, mid - 1)b[i] = a[i * 2], b[i + mid] = a[i * 2 + 1];
fo(i, 0, n - 1)a[i] = b[i];
fft(a, mid, inv), fft(a + mid, mid, inv); //分治
fo(i, 0, mid - 1)
{
cp x(cos(2 * pi * i / n), inv * sin(2 * pi * i / n)); //inv取决是否取共轭复数
b[i] = a[i] + x * a[i + mid], b[i + mid] = a[i] - x * a[i + mid];
}
fo(i, 0, n - 1)a[i] = b[i];
}
cp a[MAXN], b[MAXN];
int c[MAXN];
fft(a, n, 1), fft(b, n, 1); //1系数转点值
fo(i, 0, n - 1)a[i] *= b[i];
fft(a, n, -1); //-1点值转系数
fo(i, 0, n - 1)c[i] = (int)(a[i].real() / n + 0.5); //注意精度
注意:\(\tt FFT\) 之前要先把 \(n\) 调成 \(2\) 的整数次幂。
很显然上面的那个是连模板题都过不了的。
所以在这里我们才需要去考虑怎么去优化 \(\tt FFT\) 。
观察一下原序列和反转后的序列,需要求的序列实际是原序列下标的二进制反转!
因此我们对序列按照下标的奇偶性分类的过程其实是没有必要的。
这样我们可以 \(O(n)\) 的利用某种操作得到我们要求的序列,然后不断向上合并就好了。
—— \(\tt luogu\) 某题解
Code 有优化,可过
点击查看代码
#include <bits/stdc++.h>
#define file(a) freopen(a".in", "r", stdin), freopen(a".out", "w", stdout)
#define Enter putchar('\n')
#define quad putchar(' ')
#define N 3000005
namespace IO {
template <class T>
inline void read(T &a);
template <class T, class ...rest>
inline void read(T &a, rest &...x);
template <class T>
inline void write(T x);
}
struct cp {
double x, y;
cp (double xx = 0, double yy = 0) {x = xx; y = yy;};
friend cp operator +(cp p, cp q) {return cp(p.x + q.x, p.y + q.y);}
friend cp operator -(cp p, cp q) {return cp(p.x - q.x, p.y - q.y);}
friend cp operator *(cp p, cp q) {return cp(p.x * q.x - p.y * q.y, p.y * q.x + p.x * q.y);}
}a[N], b[N];
const double pi = acos(-1.0);
int n1, n2, n, rev[N], c[N];
inline void FFT(cp *a, int n, int inv) {
int bit = 0;
while ((1 << bit) < n) bit++;
for (int i = 1; i < n; ++i) {
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
if (i < rev[i])
std::swap(a[rev[i]], a[i]);
}
for (int mid = 1; mid < n; mid <<= 1) {
cp temp(cos(pi / mid), inv * sin(pi / mid));
for (int i = 0; i < n; i += mid * 2) {
cp omega(1,0);
for (int j = 0; j < mid; ++j, omega = omega * temp) {
cp x = a[i + j], y = omega * a[i + j + mid];
a[i + j] = x + y;
a[i + j + mid] = x - y;
}
}
}
}
signed main(void) {
// file("P3803");
IO::read(n1, n2);
n = std::max(n1, n2);
for (int i = 0, num; i <= n1; ++i) IO::read(num), a[i].x = num;
for (int i = 0, num; i <= n2; ++i) IO::read(num), b[i].x = num;
n = n1 + n2;
for (int i = 0; i <= 30; ++i)
if ((1 << i) > n) {
n = (1 << i);
break;
}
FFT(a, n, 1); FFT(b, n, 1);
for (int i = 0; i < n; ++i) a[i] = a[i] * b[i];
FFT(a, n, -1);
for (int i = 0; i <= n1 + n2; ++i)
c[i] = (int)(a[i].x / n + 0.5);
for (int i = 0; i <= n1 + n2; ++i)
IO::write(c[i]), quad;
Enter;
}
namespace IO {
template <class T>
inline void read(T &a) {
T s = 0, t = 1;
char c = getchar();
while ((c < '0' || c > '9') && c != '-')
c = getchar();
if (c == '-')
c = getchar(), t = -1;
while (c >= '0' && c <= '9')
s = (s << 1) + (s << 3) + (c ^ 48), c = getchar();
a = s * t;
}
template <class T, class ...rest>
inline void read(T &a, rest &...x) {
read(a);
read(x...);
}
template <class T>
inline void write(T x) {
if (x == 0) putchar('0');
if (x < 0) putchar('-'), x = -x;
int top = 0, sta[55] = {0};
while (x)
sta[++top] = x % 10, x /= 10;
while (top)
putchar(sta[top] + '0'), top--;
return ;
}
}
在这里推荐 某知乎专栏 ,把 \(\tt FFT\) 优化讲的很清楚。
\(\tt NTT\) 还是会看的,但是 \(\tt FFT\) 把我给些虚脱了。。。