多项式乘法 快速傅里叶变换(FFT)
前置知识1
1.多项式:一个以\(x\)为变量的多项式定义在一个代数域\(F\)上,将函数\(A(x)\)表示为形式和:
2.多项式的系数表示法;即由多项式的系数组成的向量 \(a\) $ = (a_0, \ a_1, \ ... \ a_{n - 1})$
2.多项式的点值表示法:即一个由\(n\)的点组成的集合
其中满足对于\(\forall \ i \ \in [0, \ n - 1]\)
3.点值表示法的多项式乘法:
已知两个多项式的点值表示:
则两式相乘后的点值表示为:
4.;离散傅里叶变换(DFT):我们称向量 \(y\) \(\ = \ (y_0, \ y_1, \ ... \ y_{n - 1})\) 为向量 \(a\) \(= (a_0, \ a_1, \ ... \ a_{n - 1})\)的离散傅里叶变换,记为 \(y\) \(= DFT(\) \(a\) \()\), 同时,\(a\) \(= DFT^{-1}(\) \(y\) \()\)
快速傅里叶变换(FFT):
可以在\(O(n \ log \ n)\)的时间复杂度内求解两个次数界为\(n\)的多项式的乘法。算法主要由三步构成:
1.求值\(O(n \ log \ n)\):将多项式由系数表示法转化为点值表示法,即\(DFT\)。
2.点值乘法\(O(n)\):将两个多项式的点值相乘,得到所求多项式的点值表达式。
3.插值\(O(n \ log \ n)\):将所求多项式由点值表示法转变为系数表示法,即\(DFT^{-1}\)。
一个次数界为\(n\)和\(m\)的多项式相乘,需要\(n + m\)个点的点值表达式,朴素的求法为\(O(n^2)\),显然\(x\)的取值是任意的,我们可以通过选取特殊的\(x\)来降低求值的复杂度。
n次单位复数根:
是指满足\(w^n = 1\)的复数\(w\),共有n个,分别为
根据欧拉公式可得:
n次单位复数根有几个性质:
性质1: \(w_{pk}^{pn} = w_{k}^{n}\)
性质2: \(w_n^{k + \frac{n}{2}} = -w_n^k\)
性质3: 对任意\(n \geq 1\)和不能被\(n\)整除的非负整数\(k\),有:
FFT的过程:
1.求值
(以下假定n为2的整数次幂)
对于多项式
将其按照奇偶项拆分:
则:
求\(A(x)\)在\(w_n^0, w_n^1, ..., w_n^{n - 1}\)处的值,即为求\(A_1(x)\)和\(A_2(x)\)在\((w_n^0)^2, (w_n^1)^2, ..., (w_n^{n - 1})^2\)处的值。
根据性质1可得:\((w_n^k)^2 = w_{n/2}^k\),因此\((w_n^0)^2, (w_n^1)^2, ..., (w_n^{n - 1})^2\)这n个单位复数根仅是由\(n / 2\)个不同的值组成的,我们可以先求出\(w_{n / 2}^0, ..., w_{n / 2}^{n / 2 - 1}\)这\(n / 2\)个值,再根据性质2,求出后一半的值,于是问题的规模缩小了一半,使用分治的方法即可在\(O(n \ log \ n)\)的时间内求值。
2.插值
将\(DFT\)的过程写成矩阵乘法的形式,则可得
其中\(V_n\)为\(x = w_n\)时的范德蒙德矩阵。则:
根据范德蒙德矩阵的性质,易得\(V_n^{-1}\)的\((i, j)\)处的元素为\(w_n^{-ij} \ / \ n\) 则可知:
式子与上面求值的\(DFT\)相似,通过类似方法也可以在\(O(n \ log \ n)\)的时间内插值。
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long lld;
const int N = 1000005;
const double PI = acos(-1);
typedef complex <double> comp;
void FFT(int M, vector <comp> &a, int typ) {
if(M == 1) return ;
vector <comp> a1(M / 2 + 2), a2(M / 2 + 2);
int cnt = 0;
for(int i = 0; i <= M; i += 2) {
a1[cnt] = a[i]; a2[cnt++] = a[i + 1];
}
FFT(M >> 1, a1, typ); FFT(M >> 1, a2, typ);
comp wn(cos(PI * 2 / M), typ * sin(PI * 2 / M)), w(1, 0);
for(int i = 0; i < (M >> 1); i++) {
a[i] = a1[i] + w * a2[i];
a[i + (M >> 1)] = a1[i] - w * a2[i];
w = w * wn;
}
}
int n, m, M;
int main() {
// freopen("data.in", "r", stdin);
cin >> n >> m;
M = 1; while(M <= n + m) M <<= 1;
vector <comp> a(M + 2), b(M + 2);
for(int i = 0; i <= n; i++) {
double x; cin >> x;
a[i] = complex <double> (x, 0);
}
for(int i = 0; i <= m; i++) {
double x; cin >> x;
b[i] = complex <double> (x, 0);
}
int M = 1; while(M <= n + m) M <<= 1;
FFT(M, a, 1); FFT(M, b, 1);
vector <comp> c(M + 2);
for(int i = 0; i <= M; i++) {
c[i] = a[i] * b[i];
}
FFT(M, c, -1);
for(int i = 0; i <= n + m; i++) {
cout << (int)(c[i].real() / M + 0.5) << " ";
}
return 0;
}