快速傅里叶变换学习笔记(更新中)
快速傅里叶变换(FFT)学习笔记
简介
快速傅里叶变换($ \rm Fast\ Fourier\ Transformation $), 简称 \(\rm FFT\), 用于在 $ \Theta(n\log n) $ 时间内求两个多项式的乘积.
快速数论变换($ \rm Fast\ Number\ Theoretic\ Transforms$), 简称 \(\rm NTT\), 用于在 $ \Theta(n\log n) $ 时间内求两个多项式的乘积, 系数对 \(p\) 取模.
前置技能
卷积
(注: 以下所有多项式均只包含 $ x $ 一个变量)
一个 $ n - 1 $ 次 $ n $ 项式 $ f(x) $ 可以表示为 $ f(x) = \sum_{i=0}^{n - 1} a_ix^i $.
设有多项式 $ f(x) = \sum_{i=0}^{n - 1} a_i x^i $ 和 $ g(x) = \sum_{i = 0}^{n - 1} b_ix^i $, 则有:
我们称 $ (f \cdot g)(x) $ 为 $ f(x) $ 和 $ g(x) $ 的卷积.
复数
定义
定义 $ i = \sqrt{-1} $, 则出现了形如 $ a+bi(a,b\in \mathbb{R}) $ 的数.
$ a + bi $ 可以被表示为复平面上 \((a, b)\) 的点, 也可以表示为向量 \((a, b)\).
复数的模(长):
$|a + bi| = \sqrt{a^2 + b^2} $.
为该复数所表示的点在复平面上到原点的距离.
复数的幅角:
一个复数的幅角为该数在复平面上与实轴正半轴的夹角(逆时针)记作 $ arg(a + bi) $, 显然复数的幅角有无穷多个, 每个幅角都相差 \(2\pi\), 若 $ a + bi $ 的幅角 $ \theta$ 满足 $ -\pi \le \theta \le \pi $, 则称 $ \theta $ 是 $ a + bi $ 的幅角主值.
\(0\) 的幅角不确定.
运算
加(减)法:
$ (a + bi) + (c + di) = (a + c) + (b + d)i h(x) = f(x)\cdot g(x) $.
乘法:
$ (a + bi)(c + di) = ac + adi + bci - bd = (ac - bd) + (bc + ad)i $.
如图, $ A \cdot B = C $, 复数的乘法有以下性质: 模长相乘, 幅角相加.
模长相乘可以以如下方法证明:
设 $ z_1 = a + bi $, $ z_2 = c + di $, $ z_3 = z_1z_2 = (ac - bd) + (ad + bc) i $;
有:
$|z_1|\cdot |z_2| = \sqrt{a^2 + b^2}\cdot \sqrt{c^2 + d^2} = \sqrt{(a^2 + b2)(c2 + d^2)} = \sqrt{a2c2 + a2d2 + b2c2 + b2d2} $
\(|z_3| = \sqrt{(ac - bd) ^ 2 + (ad + bc) ^ 2} = \sqrt{a^2c^2 + a^2d^2 + b^2c^2 + b^2d^2} = |z_1|\cdot |z_2|\)
共轭:
$ \overline{a + bi} = a - bi $, 易证任何复数乘上它的共轭都为实数, $ (a + bi)(a - bi) = a ^ 2 + b ^2 $.
除法:
欧拉公式
求证 \(e^{i\theta} = \cos \theta + i \sin \theta\);
设
可得
所以, $ f(\theta) $ 是常数.
求 $ f(0) $, 得 $ f(\theta) = f(0) = \frac{1}{1 + 0} = 1 $.
所以, $ e^{i\theta} = \cos \theta + i \sin \theta $;
证毕.
单位根
定义
定义 $ n(n\ge 1) $ 次单位根的 $ n $ 次幂等于 $ 1 $.
显然单位根的模长等于 $ 1 $, 否则怎么乘都不会是 $ 1 $.
所以, 单位根都在单位圆上.
考虑设单位根幅角为 $ \alpha $, 则有 $ n\alpha = 2k\pi, (k\in \mathbb N) $.
解得:
所以, 单位根 $ n $ 等分单位圆.
注意到, 当 $ k = n $ 时, 进入了循环, 所以只有 \(n\) 个单位根.
对于 $ n $ 次的幅角为 $ \frac{2\pi}{n}k $ 的单位根我们记作 $ \omega_n^k $.
显然, $ \omega_n^k = \omega_n^{k\bmod n} $.
得到幅角 $ \alpha $ 后, 我们可以很轻易的得到该单位根所对应的数.
引理
如上图, $ A'B = OA \cdot \sin \alpha = \sin \alpha $;
$ OB = OA \cdot \cos \alpha = \cos \alpha $.
所以点 $ A' $ 所代表的数为 $ \cos \alpha + i\sin \alpha $.
所以
由上式, 有:
引理 1:
$ \omega_nj\omega_nk = (\omega_n1)j(\omega_n1)k = (\omega_n1) =\omega_n^{j + k} $;
引理 2:
同上, 显然有 $ (\omega_nj)k = \omega_n^{kj} $;
引理 3 (消去引理):
$ \omega_{dn}^{dk} = e^{i\frac{2\pi}{dn}dk} = e^{i\frac{2\pi}{n}k} = \omega_n^k $;
引理 4 (折半引理):
由消去引理可得 $ \omega_{2n}^n = \omega_2^1 = -1 $;
$ \omega_{2n}^{k + n} = \omega_{2n}k\omega_{2n}n = -\omega_{2n}^k $;
引理 5:
当 $ n $ 为偶数时.
根据消去引理, 有:
$ (\omega_nk)2 = \omega_n^{2k} = \omega^k_{\frac{n}{2}} $
快速傅里叶变换
离散傅立叶变换(DFT)
多项式的点值表示
定义
显然, $ n + 1 $ 个点确定一个 $ n $ 项式.
所以我们可以用形如 $ (x_0, y_0),(x_1,y_1),(x_2,y_2),\dots,(x_n,y_n) $ 来表示一个多项式.
点值表示下的卷积
显然, 设有 $ (x_0, y_0),(x_1,y_1),(x_2,y_2),\dots,(x_n,y_n) $ 表示$ f(x) $.
有 $ (x_0, y_0'),(x_1,y_1'),(x_2,y_2'),\dots,(x_n,y_n') $ 表示$ g(x) $.
则由于 $ h(x) = (f\cdot g)(x) = f(x)\cdot g(x) $, 很自然的就有:
$ (x_0, y_0\cdot y_0'),(x_1, y_1\cdot y_1'),(x_2, y_2\cdot y_2'),\cdots,(x_n, y_n\cdot y_n') $ 表示 $ h(x) $.
我们就有很快的 $ \Theta(n) $ 算法来解决已经求出点值表示的两个多项式相乘.
需要注意的是, 由于 $ h $ 理论上有 $ 2n + 1 $ 项, 所以需要在先前多造几组点.
求值
(根据系数表示求点值表示叫做求值)
(以下所有 $ n $ 都保证 $ n = 2^t, t\in \mathbb N $, 实际情况可以补一些系数为 \(0\) 的高次项来凑齐)
设:
我们要求出其点值表示.
于是我们想(luan)到(gao), 给 f 数组代入单位根, 并把奇数项和偶数项分开.
设 $ f_0(x) = \sum_{i = 0} ^ {\frac{n}{2} - 1} a_{2i} x^i $
$ f_1(x) = \sum_{i = 0} ^ {\frac{n}{2} - 1} a_{2i + 1} x^i $
显然有: \(f(x) = f_0(x^2) + xf_1(x^2)\).
我们代入 \(\omega_n^k\), 得:
容易得到:
(其中第二行用了上面的引理 5)
然而我们还有 \(\frac{n}{2}\) 以上的点值表示还没有求.
又设 \(1 \le k < \frac{n}{2}\).
有:
(其中第二行用到了引理 2, 第三行用到了 \(\omega_n^k = \omega_n^{k\bmod n}\), 第四行用到了消去引理, 第五行用到了折半引理)
我们发现, 两个式子都包含了 \(f_0\left(\omega_{\frac{n}{2}}^k\right), \omega_n^k f_1\left(\omega_{\frac{n}{2}}^k\right)\) , 原问题被分解成两个更小的子问题, 采用分治即可, 时间复杂度为 \(\Theta(n \log n)\).
代码实现
void DFT(complex f[], int len) {
if(!len) return;
complex fl[len + 5], fr[len + 5];
for(int i = 0; i < len; i++) {
fl[i] = f[i << 1];
fr[i] = f[i << 1 | 1];
}
DFT(fl, len >> 1, flag);
DFT(fr, len >> 1, flag);
complex ur(cos(pi / len), sin(pi / len)), tmp(1, 0);
for(int i = 0; i < len; i++) {
f[i] = fl[i] + tmp * fr[i];
f[i + len] = fl[i] - tmp * fr[i];
tmp *= ur;
}
}
离散傅里叶逆变换(IDFT)
将插值转化成矩阵的形式
(根据点值表示求系数表示叫做插值)
根据暴力求值的公式, 我们可以推出:
求该矩阵的逆
可以发现这是个范德蒙矩阵
设该矩阵为 \(D\).
前人已经告诉我们了逆矩阵:
虽然我不会找逆矩阵, 但是我可以证明一下这个逆矩阵: (我太菜了)
于是设 \(E = DD^{-1}\)
当 \(j \not =k\) 时, 根据等比数列求和公式:
当 \(j = k\) 时:
所以, \(E\) 为单位矩阵.
代码实现
void IDFT(complex f[], int len, int flag) {
if(!len) return;
complex fl[len + 5], fr[len + 5];
for(int i = 0; i < len; i++) {
fl[i] = f[i << 1];
fr[i] = f[i << 1 | 1];
}
IDFT(fl, len >> 1, flag);
IDFT(fr, len >> 1, flag);
complex ur(cos(pi / len), -sin(pi / len)), tmp(1, 0);
for(int i = 0; i < len; i++) {
f[i] = fl[i] + tmp * fr[i];
f[i + len] = fl[i] - tmp * fr[i];
tmp *= ur;
}
}
简化代码
发现 IDFT 的代码其实和 DFT 相差不大.
发现:
我们需要求:
发现 DFT 为:
然后发现这个式子和 DFT 要求的式子很像, 只不过是把 \(y\) 当成系数, 重新做一遍求值.
可以发现 \(\omega_n^{-1}\) 和 \(\omega_n^1\) 在各方面区别不大, 传一个参数来区分即可.
所以我们就只需要写一个 DFT 就可以了.
递归版代码实现
#include <bits/stdc++.h>
using std::cin;
using std::cout;
typedef double f64;
const f64 pi = acos(-1);
const f64 eps = 1e-5;
const int MAXN = 1e5 + 5;
struct complex {
f64 real, imag;
complex(f64 rpt = 0, f64 ipt = 0) {real = rpt; imag = ipt;}
complex operator + (const complex &x) const {
complex ans(real + x.real, imag + x.imag);
return ans;
}
complex operator - (const complex &x) const {
complex ans(real - x.real, imag - x.imag);
return ans;
}
complex operator * (const complex &x) const {
complex ans(real * x.real - imag * x.imag, real * x.imag + imag * x.real);
return ans;
}
complex operator += (const complex &x) {return *this = *this + x;}
complex operator -= (const complex &x) {return *this = *this - x;}
complex operator *= (const complex &x) {return *this = *this * x;}
} fa[MAXN], fb[MAXN];
int n, m;
void FFT(complex f[], int len, int flag) {
if(!len) return;
complex fl[len + 5], fr[len + 5];
for(int i = 0; i < len; i++) {
fl[i] = f[i << 1];
fr[i] = f[i << 1 | 1];
}
FFT(fl, len >> 1, flag);
FFT(fr, len >> 1, flag);
complex ur(cos(pi / len), sin(pi / len) * flag), tmp(1, 0);
for(int i = 0; i < len; i++) {
f[i] = fl[i] + tmp * fr[i];
f[i + len] = fl[i] - tmp * fr[i];
tmp *= ur;
}
}
int main() {
scanf("%d%d", &n, &m);
for(int i = 0; i <= n; i++)
scanf("%lf", &fa[i].real);
for(int i = 0; i <= m; i++)
scanf("%lf", &fb[i].real);
int len = 1;
for(n += m; len <= n; len <<= 1);
FFT(fa, len >> 1, 1); FFT(fb, len >> 1, 1);
for(int i = 0; i < len; i++)
fa[i] *= fb[i];
FFT(fa, len >> 1, -1);
for(int i = 0; i <= n; i++) printf("%.0lf ", fabs(fa[i].real) / len);
return 0;
}
蝴蝶算法
发现规律
我们发现, 直接用递归写, 常数非常大.
$ 0, 1, 2, 3, 4, 5, 6, 7 $
$ 0, 2, 4, 6|1, 3, 5, 7 $
$ 0, 4|2, 6|1, 5|3, 7 $
$ 0|4|2|6|1|5|3|7 $
发现最后的数的二进制表示是:
\(000, 100, 110, 001, 101, 011, 111\)
是原来的序号的二进制反序.
所以先求出最后的数组, 在原数组上迭代就可以了.
迭代实现
void FFT(complex f[], int len, int flag) {
for(int i = 0; i < len; i++)
if(i < rev[i]) std::swap(f[i], f[rev[i]]);
for(int i = 1; i < len; i <<= 1) {
complex ur(cos(PI / i), flag * sin(PI / i));
for(int j = 0; j < len; j += (i << 1)) {
complex tmp(1, 0);
for(int k = 0; k < i; k++, tmp *= ur) {
complex fr = f[i + j + k], fl = f[j + k];
f[j + k] = fl + tmp * fr;
f[i + j + k] = fl - tmp * fr;
}
}
}
}
代码实现
#include <bits/stdc++.h>
using std::cin;
using std::cout;
typedef double f64;
const f64 PI = acos(-1);
const int MAXN = 4e6 + 5;
struct complex {
f64 real, imag;
complex(f64 rpt = 0, f64 ipt = 0) {real = rpt; imag = ipt;}
complex operator + (const complex &x) const {
return complex(real + x.real, imag + x.imag);
}
complex operator - (const complex &x) const {
return complex(real - x.real, imag - x.imag);
}
complex operator * (const complex &x) const {
return complex(real * x.real - imag * x.imag, real * x.imag + imag * x.real);
}
complex operator += (const complex &x) {return *this = *this + x;}
complex operator -= (const complex &x) {return *this = *this - x;}
complex operator *= (const complex &x) {return *this = *this * x;}
} fa[MAXN], fb[MAXN];
int n, m, rev[MAXN];
void FFT(complex f[], int len, int flag) {
for(int i = 0; i < len; i++)
if(i < rev[i]) std::swap(f[i], f[rev[i]]);
for(int i = 1; i < len; i <<= 1) {
complex ur(cos(PI / i), flag * sin(PI / i));
for(int j = 0; j < len; j += (i << 1)) {
complex tmp(1, 0);
for(int k = 0; k < i; k++, tmp *= ur) {
complex fr = f[i + j + k], fl = f[j + k];
f[j + k] = fl + tmp * fr;
f[i + j + k] = fl - tmp * fr;
}
}
}
}
int read() {
int x = 0; char c = getchar();
while(c < '0' || c > '9') c = getchar();
while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
return x;
}
void write(int x) {
if(x / 10) write(x / 10);
putchar(x % 10 + '0');
}
int main() {
n = read(); m = read();
for(int i = 0; i <= n; i++) fa[i].real = (f64)read();
for(int i = 0; i <= m; i++) fb[i].real = (f64)read();
int len = 1, maxBit = 0;
for(n += m; len <= n; len <<= 1, maxBit++);
for(int i = 0; i < len; i++)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (maxBit - 1));
FFT(fa, len, 1); FFT(fb, len, 1);
for(int i = 0; i < len; i++)
fa[i] *= fb[i];
FFT(fa, len, -1);
for(int i = 0; i <= n; i++)
write((int)(fabs(fa[i].real) / len + 0.5)), putchar(' ');
return 0;
}
快速数论变换
原根与阶
阶
最小的 $ t $ 使得 $ a^t \equiv 1 \pmod p$, 则 \(\delta_p (a) = t\), 称 $ t $ 为 \(a\) 对模 $ p $ 的阶.
原根
若 $ \delta_p(a) = \varphi(p) $, 则 \(a\) 是 \(p\) 的一个原根.
求原根
原根通常很小, 模数通常不会超过 \(10^9\), 暴力枚举即可.
常见模数 \(998244353 = 7 \times 17\times 2^{13} + 1\) 的原根为 \(3\).
快速数论变换
根据相关定理, 可以证明, 模 \(p\) 意义下, \(e^{\frac{-2\pi i}{n}} \equiv g^{\frac{p - 1}{n}} \pmod p\).
怎么证明? 我也不会.
代码实现
#include <bits/stdc++.h>
typedef long long ll;
const ll mod = 7 * 17 * (1 << 23) + 1;
const ll g = 3, gInv = 332748118;
const ll MAXN = 4e6 + 5;
ll n, m, rev[MAXN], fa[MAXN], fb[MAXN];
ll qPow(ll a, ll b) {
ll ans = 1, base = a;
while(b) {
if(b & 1) ans = ans * base % mod;
base = base * base % mod;
b >>= 1;
}
return ans;
}
void FFT(ll f[], int len, int flag) {
for(int i = 0; i < len; i++)
if(i < rev[i]) std::swap(f[i], f[rev[i]]);
for(int i = 1; i < len; i <<= 1) {
ll ur = qPow(flag ? g : gInv, (mod - 1) / (i << 1));
for(int j = 0; j < len; j += (i << 1)) {
ll tmp = 1;
for(int k = 0; k < i; k++, tmp = tmp * ur % mod) {
ll fr = f[i + j + k], fl = f[j + k];
f[j + k] = (fl + tmp * fr % mod) % mod;
f[i + j + k] = (fl - tmp * fr % mod + mod) % mod;
}
}
}
}
ll read() {
ll x = 0; char c = getchar();
while(c < '0' || c > '9') c = getchar();
while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
return x;
}
void write(ll x) {
if(x / 10) write(x / 10);
putchar(x % 10 + '0');
}
int main() {
n = read(); m = read();
for(int i = 0; i <= n; i++) fa[i] = read();
for(int i = 0; i <= m; i++) fb[i] = read();
int len = 1, maxBit = 0;
for(n += m; len <= n; len <<= 1, maxBit++);
for(int i = 0; i < len; i++)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (maxBit - 1));
FFT(fa, len, 1); FFT(fb, len, 1);
for(int i = 0; i < len; i++)
fa[i] = fa[i] * fb[i] % mod;
FFT(fa, len, 0);
ll lenInv = qPow(len, mod - 2);
for(int i = 0; i <= n; i++)
write(fa[i] * lenInv % mod), putchar(' ');
return 0;
}
例题
题目
输入两个数 \(a, b\) 求 \(a\times b\). \(0\le a,b \le 10^{100000}\).
分析
发现普通高精乘的复杂度为 \(\Theta((\lg n)^2)\) 显然过不了这题.
可以把两个数看成两个多项式相乘, 最后处理一下进位即可.