【总结】对FFT的理解 / 【洛谷 P3803】 【模板】多项式乘法(FFT)
题目链接
\(\Huge\text{无图,慎入}\)
\(FFT\)即快速傅里叶变换,用于加速多项式乘法。
如果暴力做卷积的话就是一个多项式的每个单项式去乘另一个多项式然后加起来,时间复杂度为\(O(n^2)\)。
\(FFT\)算法基本思想是把系数表达式转换成点值表达式,求出卷积的点值表达式,再转换回系数表达式。
何为点值表达式?
把多项式看成一个函数,比如\(n\)次多项式\(F\)可以看成一个\(n\)次函数\(F(x)=a_0+a_1x+a_2x^2+\cdots +a_nx^n\)
众所周知,知道\(n\)次函数上\(n+1\)个点的坐标一定可以求出这个\(n\)次函数的解析式。
用我们学过的知识一次函数、二次函数都可以验证。
硬要扩展到任意次函数的话也好解释,可以得到\(n+1\)个方程,用高斯消元就能解出来,当然肯定不会在\(FFT\)算法里出现,因为算法的目的是加速。
系数表达式->点值表达式的过程叫\(DFT\),点值表达式->系数表达式的过程叫\(IDFT\)。
先讲\(DFT\)。
怎么系数->点值?
代\(n\)个点是最直接的办法,但显然求一个点的值就是\(O(n)\)的,总时间复杂度为\(O(n^2)\)
这里要引入单位根
概念,请确保了解复数的概念。
\(n\)次单位根记作\(\omega_n\),定义为\(n\)次方等于\(1\)的复数。
来推导一下性质。
首先\(n\)次方等于\(1\),这个复数的模长肯定是等于\(1\)的,所以在单位圆上。
其次,幅角\(\times n=2k\pi,k\in Z\)
脑补一下可以发现,\(n\)次单位根\(n\)等分单位圆,且\(1\)是一条等分线。
\(\omega_n^k\)表示从\(1\)开始逆时针旋转第\(i\)个\(n\)次单位根(从\(0\)开始)。
例如\(\omega_3^2\)就是把单位圆三等分,原点向正方向的射线是一条等分线,位于\(x\)轴下方的那条等分线。
不难发现,\(\forall n,\omega_n^0=1\) \(\forall n=2k,k\in Z, \omega_n^{\frac{n}{2}}=-1\)
理论上\(k\in [0,n)\),但类似于角度,也会出现超过\(360°\)或者负数的情况,同理也有\(\omega_n^k=\omega_n^{k\%n}\)
单位根的性质:
\(\omega_n^{a+b}=\omega_n^a\times \omega_n^b\),这个的解释就是复数相乘模长相乘幅角相加的法则。
有了这条,就能推出其他性质了。
\(\omega_n^k=\omega_{dn}^{dk}\)
\((\omega_n^k)^j=\omega_n^{jk}\)
步入正题:
\(DFT:\)
我们需要将\(n\)项多项式\(F(x)=a_0+a_1x+a_2x^2+\cdots +a_{n-1}x^{n-1}\)转成点值表达式。
假设\(n\)是\(2\)的正整数次幂。
设
\(FL(x)=a_0+a_2x+a_4x^2+\cdots+a_{n-2}x^{\frac{n}{2}-1}\)
\(FR(x)=a_1+a_3x+a_5x^2+\cdots+a_{n-1}x^{\frac{n}{2}-1}\)
易得\(F(x)=FL(x^2)+xFR(x^2)\)(自己代进去算一边就行了)
用\(\omega_n^k\)(\(k<\frac{n}{2}\))代入这个式子
\(\begin{align}F(\omega_n^k)&=FL(\omega_n^{2k})+\omega_n^kFR(\omega_n^{2k})\\&=FL(\omega_{\frac{n}{2}}^{k})+\omega_n^kFR(\omega_{\frac{n}{2}}^{k})\end{align}\)
这是\(k<\frac{n}{2}\)的情况,那如果\(k>=\frac{n}{2}\)呢?
\(\begin{align}F(\omega_n^{k+\frac{n}{2}})&=FL(\omega_n^{2k+n})+\omega_n^{k+\frac{n}{2}}FR(\omega_n^{2k+n})\\&=FL(\omega_{\frac{n}{2}}^{k})-\omega_n^kFR(\omega_{\frac{n}{2}}^{k})\end{align}\)
\(P.S:\omega_n^{k+\frac{n}{2}}=\omega_n^k\times \omega_n^{\frac{n}{2}}=-\omega_n^k\)
所以,如果我们知道了\(FL(\omega_{\frac{n}{2}}^{k})和FR(\omega_{\frac{n}{2}}^{k})\),就能求出\(F(\omega_n^k)\)和\(F(\omega_n^{k+\frac{n}{2}})\)
但是,我们怎么知道\(FL(\omega_{\frac{n}{2}}^{k})和FR(\omega_{\frac{n}{2}}^{k})\)呢?
难道这不像归并排序吗,分治啊!
实现过程中可能遇到的问题\(FAQ\)
1、怎么求\(\omega_n^k\)
只需要求出\(\omega_n^1\)他的\(k\)次方即\(\omega_n^k\)
2、怎么求\(\omega_n^1\)
\(\omega_n^1\)与\(1\)的夹角我们知道:\(\frac{2\pi}{n}\),然后又在单位圆上,解三角形啊!
告诉你结论就是\(\omega_n^1=cos(\frac{2\pi}{n})+sin(\frac{2\pi}{n})i\)
3、\(FL,FR\)数组从哪来
如果在递归中定义这两个数组显然会炸空间,于是蝴蝶操作诞生了。
咕咕咕。
但是,实际中\(n\)不一定是\(2\)的正整数次幂,我们只需要在最高位补\(0\)就行了,不影响结果。
现在我们求出了多项式的点值表达式,但这和他们的卷积有什么关系呢?
设\(H=F\times G\),\(H,F,G\)均为多项式
则\(H(x)=F(x)\times G(x)\)
所以如果我们知道两个多项式在\(x=\omega_n^0,\omega_n^1,\cdots,\omega_n^{n-1}\)的值,就能求出他们的卷积在\(x=\omega_n^0,\omega_n^1,\cdots,\omega_n^{n-1}\)的值,于是,第一步完成。
\(IDFT:\)
咕咕咕。
#include <cstdio>
#include <cmath>
#include <algorithm>
#define re register
using namespace std;
const int MAXN = 3000010;
const double PI = M_PI;
struct complex{
double x, y;
complex(double xx = 0, double yy = 0){ x = xx; y = yy; }
}a[MAXN], b[MAXN];
inline complex operator + (complex a, complex b){
return complex(a.x + b.x, a.y + b.y);
}
inline complex operator - (complex a, complex b){
return complex(a.x - b.x, a.y - b.y);
}
inline complex operator * (complex a, complex b){
return complex(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x);
}
inline int read(){
re int s = 0, w = 1;
re char ch = getchar();
while(ch < '0' || ch > '9'){ ch = getchar(); if(ch == '-') w = -1; }
while(ch >= '0' && ch <= '9'){ s = s * 10 + ch - '0'; ch = getchar(); }
return s * w;
}
int r[MAXN], n, m;
void FFT(complex *f, int mode){
for(re int i = 0; i < n; ++i) if(i < r[i]) swap(f[i], f[r[i]]);
for(re int p = 2; p <= n; p <<= 1){
re int len = p >> 1;
re complex tmp(cos(PI / len), mode * sin(PI / len));
for(re int l = 0; l < n; l += p){
re complex w(1, 0);
for(re int k = l; k < l + len; ++k){
re complex t = w * f[len + k];
f[len + k] = f[k] - t;
f[k] = f[k] + t;
w = w * tmp;
}
}
}
}
inline double d(double x){
if(fabs(x) < 1e-9) return 0;
return x;
}
int main(){
n = read(); m = read();
for(re int i = 0; i <= n; ++i) a[i].x = read();
for(re int i = 0; i <= m; ++i) b[i].x = read();
for(m += n, n = 1; n <= m; n <<= 1);
for(re int i = 1; i < n; ++i) r[i] = r[i >> 1] >> 1 | ((i & 1) * (n >> 1));
FFT(a, 1); FFT(b, 1);
for(re int i = 0; i < n; ++i) a[i] = a[i] * b[i];
FFT(a, -1);
for(int i = 0; i <= m; ++i) printf("%.0f ", d(a[i].x / n));
return 0;
}