FFT学习笔记
零、前置知识
复数
考虑最简单的无解二次方程:\(x^2+1=0\)。现在,思考一个问题:它为什么无解?
其实,有一种可能,它是有解的,但这个解不在我们所知的范围内,即不在 \(\mathbb R\) 中。
现在,我们规定:该方程的一个根为 \(i\),称为虚数单位。(一件显然的事是另一根为 \(-i\))
此外,我们规定原先的四则运算规则不变。类似于实数中出现 \(1\),我们可以得到一个新的数系。
可以发现,任意一元二次方程都有解。如果将 \(i\) 记作 \(\sqrt{-1}\),那么原先的求根公式也适用于目前。
此时,所有数都可以表示为 \(a+bi (a,b\in \mathbb R)\) 的形式,称为复数,复数集合记为 \(\mathbb C\)
然后稍微列一下运算法则:
复平面
就是 $$x+yi\rightarrow \vec{a}=(x,y)$$
定义复数 \(z=a+bi\) 的模长 \(|z|=\sqrt{a^2+b^2}\)
可以由几何意义看出,任意复数 \(a+bi\) 可以唯一写成 \(r(\cos x+i\sin x)=r\cdot e^{ix}\) 的形式。
加法的几何意义显然就是向量相加。
乘法的话
复数相乘,模相乘,幅角相加。
一、理论部分
0. 什么是多项式乘法
众所周知,多项式本质是一种特殊的函数,可以表示为自变量的若干次幂之和,即
其中 \(c_i\) 被称为 \(x^i\) 的系数。
已知 \(F,G\) 是两个多项式函数,考虑定义一个新的函数 \(H(x)=F(x)G(x)\)。我们称 \(H\) 为 \(F\) 与 \(G\) 的乘积,记作 \(H=F\cdot G\),这种通过已知的两个多项式函数,以乘积形式生成另一个函数的运算,称为多项式乘法。
本文讨论的就是计算多项式乘法,即求出 \(H\) 解析式的过程。
根据定义,有
我们发现,两个多项式的乘积仍是多项式。
由于多项式的特性,我们只需知道该多项式的每一项系数即可,即使用一个\(n\) 维向量去描述一个\(n-1\)次多项式。
显然,有一种暴力的方法 \(H[i]=\sum_{j=0}^iF[j]* G[i-j]\)。我们称这种形如\(c_n=\sum_{i\otimes j=n}a_i\cdot b_j\)运算为卷积。(其实只是卷积的一种,即加法卷积)
显然,这样做复杂度是 \(\mathcal{O}(n^2)\)的,实在太慢了。
\(\frac{1}{2}\). 多项式的点值表示法
既然多项式是一个函数,那我么可以画出它的函数图象,并在上面取几个点。
显然,由待定系数法可知,给定 \(n+1\) 个不同的点,可以唯一确定一个次数不超过 \(n\) 的多项式。
我们可以考虑用点值反推解析式,即如果已知函数 \(F,G\) 的\(n\) 个对应点,由定义式 \(H(x)=F(x)G(x)\) 可以直接 \(\mathcal O(n)\) 算出 \(H\) 对应的点!
但如果随意带入点,用待定系数法高斯消元暴力反推解析式,复杂度 \(\mathcal O(n^3)\),直接爆炸。
当然我们可以用拉格朗日插值公式做到 \(\mathcal O(n^2)\),然并卵,还不如暴力乘法……
\(\frac{9}{10}\). 单位根的引入及其性质
既然带一些普通的东西进去没啥用,那我们直接躺平带一些奇怪的东西进去,比如复数。考虑带入 \(n\) 次单位根。
定义:方程 \(x^n-1=0\) 在复数域的全部解称为 \(n\) 次单位根。由代数基本定理知,\(n\) 次单位根有 \(n\) 个,分别记为 \(\omega_n^0 \cdots \omega_n^{n-1}\) 。由复数乘法模相乘,幅角相加可知,\(n\) 次单位根全部在单位圆上,且幅角为 \(\frac{2\pi}{n}\) 的整数倍。显然,\(\omega_n^k=\cos \frac{2\pi k}{n}+i\sin \frac{2\pi k}{n}\)。当然,其它单位根是 \(\omega_n^1\) 的整数次幂,因此我们称其为主\(n\)次单位根,简记为 \(\omega_n\)。
根据欧拉公式 \(e^{ix}=\cos x+i\sin x\),有\(\omega_n^k=e^{\frac{2\pi k}{n}i}\)
一些性质:
-
\(\omega_n^k=(\omega_n^1)^k\)
-
\(\omega_n^i\cdot\omega_n^j=\omega_n^{i+j}\)
推论:\((\omega_n^k)^d=\omega_n^{kd}\)
- \(\omega_{dn}^{dk}=\omega_n^k\) (消去引理)
证明:\(\large\omega_{dn}^{dk}=e^{\frac{2\pi dk}{dn}i}=e^{\frac{2\pi k}{n}i}=\omega_n^k\)
- 若 \(2|n\),\(\omega_n^{k+\frac{n}{2}}=-\omega_n^k\)
证明:\(\omega_n^{k+\frac{n}{2}}=\omega_n^{k}\omega_n^{\frac{n}{2}}=\omega^1_2\omega_n^k=-\omega_n^k\)
推论:\((\omega_n^k)^2=(\omega_n^{k+\frac{n}{2}})^2=\omega_{n/2}^k\)(折半引理)
- 对于任意不为 \(n\) 倍数的 \(k\),有:
证明:
知道这些结论,就可以学习 FFT 了!
1. DFT
考虑求出一个多项式在 \(\omega_n^0\cdots\omega_n^{n-1}\) 的值,我们称这种运算为 DFT。
首先,我们将该多项式的次数扩充为 \(2^k-1\)(此时一共有 \(2^k\) 项)。
将 \(A(x)\) 写成系数向量形式 \(a=[a_0,a_1,a_2,\cdots,a_{n-1}]\)
考虑将其按奇数项和偶数项分为
分别记其对应的多项式为 \(A^{[0]}(x)\) 和 \(A^{[1]}(x)\)。
换成 \(x^2\):
\(A^{[1]}\) 乘上 \(x\):
由此得到 \(A(x)=A^{[0]}(x^2)+xA^{[1]}(x^2)\)
说人话:如果知道 \(A^{[0]}(x^2)\) 和 \(A^{[1]}(x^2)\) ,就可以求出 \(A(x)\)。
问题转化为求出所有 \(A^{[0]}((\omega_n^k)^2)\) 和 \(A^{[1]}((\omega_n^k)^2)\)。
由折半引理知,求出所有 \(A^{[0]}((\omega_n^k)^2)\) 和 \(A^{[1]}((\omega_n^k)^2)\) 就是求出所有 \(A^{[0]}(\omega_{n/2}^k)\) 和 \(A^{[1]}(\omega_{n/2}^k)\)。
注意到这是原问题的子问题,可以直接分治处理。递归边界为 \(n=1\),此时对应的多项式为常量,直接返回即可。
考虑合并。
\(A(\omega_n^k)=A^{[0]}(\omega_{n/2}^k)+\omega_n^kA^{[1]}(\omega_{n/2}^k)\)
\(A(\omega_n^{k+n/2})=A^{[0]}(\omega_{n/2}^k)-\omega_n^kA^{[1]}(\omega_{n/2}^k)\)
显然,合并复杂度为 \(\mathcal O(n)\)。
则 FFT 的复杂度 \(T(n)=2T(\frac{n}{2})+\mathcal O(n)\),故 \(T(n)=\mathcal O(n\log n)\)
全剧终。
我们发现,我们得到的只是一些点值罢了,根本不是系数表示。
2. IDFT
即 DFT 的逆运算。
如何将点值转换为系数?注意到上文提到 DFT 时,说的是将 \(A(x)\) 看成系数向量,即一维矩阵。
如果我们将得到的点值记为一维向量 \(b\),则有:\(a=IDFT(b)\),即 \(b=DFT(a)\)。
我们现在知道 \(b\) 求 \(a\)。
由定义知 \(b[k]=\sum_{i=0}^{n-1}(\omega_n^k)^i\cdot a[i]\)
如果你会单位根反演,你可能能够推出 \(a\) 关于 \(b\) 的式子。这里给出单位根反演的原理。
还记得求和引理吗?
对于任意不为 \(n\) 倍数的 \(k\),有:
此外,若 \(k\) 为 \(n\) 的倍数,显然上个式子的值为 \(n\)。
即
考虑如下式子:
感觉是句废话……
考虑把 \([k=i]\) 往 \([n|k]\) 上靠。
我们发现,\(i,k\) 的取值范围是 \([0,n-1]\),即 \(k=k\bmod n,i=i\bmod n\)
那么,\([k=i]\) 可以写成 $[k\bmod n=i\bmod n] $,即 $ [k\equiv i\pmod n]$ 或 \([n|(k-i)]\)
如果直接将 \([n|(k-i)]\) 带入,可能无法得到想要的结果。
但 \([n|(k-i)]\) 与 \([n|(i-k)]\) 是等价的,因此尝试带入 \([n|(i-k)]\)。
交换和式:
换个下标
这次成功了。变形一下:
我们发现右边和 DFT 过程几乎完全一致,只是将 \(\omega_n^k\) 改为 \(\omega_n^{-k}\)。
但是,有一个细节:\(n+1\) 个点只能确定一个 \(n\) 次多项式,因此点数应足够多,以完成最后的 IDFT。
到这里,FFT 的理论部分就讲完了。
二、代码实现
1. 递归版本(无卡常)
定义一个常量:
const double pi=acos(-1);
首先实现一个复数类:
struct cp{
double a,b;
cp(double aa=0,double bb=0){a=aa;b=bb;}
cp operator +(const cp &oth)const{return cp(a+oth.a,b+oth.b);}
cp operator -(const cp &oth)const{return cp(a-oth.a,b-oth.b);}
cp operator *(const cp &oth)const{return cp(a*oth.a-b*oth.b,a*oth.b+b*oth.a);}
};
完整代码:
#include <cstdio>
#include <cmath>
using namespace std;
const int N=1024*2048+5;
const double pi=acos(-1);
struct cp{
double a,b;
cp(double aa=0,double bb=0){a=aa;b=bb;}
cp operator +(const cp &oth)const{return cp(a+oth.a,b+oth.b);}
cp operator -(const cp &oth)const{return cp(a-oth.a,b-oth.b);}
cp operator *(const cp &oth)const{return cp(a*oth.a-b*oth.b,a*oth.b+b*oth.a);}
}f[N],g[N],save[N];
void fft(cp *f,int len,int flag){
if(len==1) return ;
cp *f0=f,*f1=f+len/2;
for(int i=0;i<len;++i) save[i]=f[i];
for(int i=0;i<len/2;++i) f0[i]=save[i<<1],f1[i]=save[i<<1|1];
fft(f0,len/2,flag);fft(f1,len/2,flag);
cp Wn(cos(2*pi/len),flag*sin(2*pi/len)),W(1,0);
for(int i=0;i<len/2;++i,W=W*Wn){
save[i]=f0[i]+W*f1[i];
save[i+(len>>1)]=f0[i]-W*f1[i];
}
for(int i=0;i<len;++i) f[i]=save[i];
}
int main(){
int n,m;scanf("%d%d",&n,&m);
for(int i=0;i<=n;++i) scanf("%lf",&f[i].a);
for(int i=0;i<=m;++i) scanf("%lf",&g[i].a);
for(m+=n,n=1;n<=m;n<<=1);
fft(f,n,1);fft(g,n,1);
for(int i=0;i<n;++i) f[i]=f[i]*g[i];
fft(f,n,-1);
for(int i=0;i<=m;++i){
printf("%d ",(int)(f[i].a/n+0.5));
}
return 0;
}
一些细节:
- 数组不要开 \(2\) 的方幂。
- \(2^k\) 必须严格大于 \(n+m\)(\(n+m\) 是最高次项,而 \(2^k\) 位最多表示 \(2^k-1\) 次)
- 答案要取整。
且慢!如果你提交了上述代码,你会发现时间卡得很紧,最大点在 \(0.98s\) 左右,完全不能接受。
怎么办?卡常!
2. FFT 高效实现(位逆序置换)
在说位逆序置换之前,先来简单地卡卡常吧。
首先,去除构造函数。
考虑如下代码:
for(int i=0;i<len/2;++i,W=W*Wn){
save[i]=f0[i]+W*f1[i];
save[i+(len>>1)]=f0[i]-W*f1[i];
}
W*f1[i]
被算了两遍!实数乘法多慢大家都清楚,而复数乘法更慢!我们可以用一个变量记录一下。
for(int i=0;i<len/2;++i,W=W*Wn){
cp tmp=W*f1[i];
save[i]=f0[i]+tmp;
save[i+(len>>1)]=f0[i]-tmp;
}
接下来才是主角!
很容易发现,对效率影响最大的是分治过程中的数组拷贝。
为什么要拷贝数组?因为奇偶分开这一要求。
我们模拟一组数据看看。
0 1 2 3 4 5 6 7
0 2 4 6|1 3 5 7
0 4|2 6|1 5|3 7
我们考虑求出最后的的序列。
经过观察,我们发现在第 \(i\) 的数为 \(i\) 的二进制反转。
如何二进制反转?可以 \(\mathcal O(n)\) 求出。
for(int i=0;i<n;++i){
rev[i]=(rev[i>>1]>>1)|((i&1)?(n>>1):0);
}
大概意思是:右移一位后,再翻转,去除末尾零后将末位翻转至高位。
例子:
\(0101\rightarrow 0010\rightarrow 0100\rightarrow 010\rightarrow 1010\)
这就是位逆序置换(也有叫蝴蝶变换的。)
然后就可以迭代实现了。可以发现,这避免了所有数组拷贝。
#include <cstdio>
#include <cctype>
#include <cmath>
#include <algorithm>
using namespace std;
char buf[1<<21],*p1=buf,*p2=buf;
#define GetC() ((p1==p2)&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
struct Ios{}io;
template <typename _tp>
Ios &operator >>(Ios &in,_tp &x){
x=0;int w=0;char c=GetC();
for(;!isdigit(c);w|=c=='-',c=GetC());
for(;isdigit(c);x=x*10+(c^'0'),c=GetC());
if(w) x=-x;
return in;
}
const int N=1024*2048+5;
const double pi=acos(-1);
struct cp{
double a,b;
cp operator +(const cp &oth)const{return (cp){a+oth.a,b+oth.b};}
cp operator -(const cp &oth)const{return (cp){a-oth.a,b-oth.b};}
cp operator *(const cp &oth)const{return (cp){a*oth.a-b*oth.b,a*oth.b+b*oth.a};}
}f[N],g[N];
int rev[N];
void fft(cp *f,int n,int type){
for(int i=0;i<n;++i) if(i<rev[i]) swap(f[i],f[rev[i]]);
for(int p=2;p<=n;p<<=1){
int len=p>>1;
cp Wn=(cp){cos(2*pi/p),type*sin(2*pi/p)};
for(int k=0;k<n;k+=p){
cp W=(cp){1,0};
for(int i=k;i<k+len;++i){
cp tmp=W*f[i+len];
f[i+len]=f[i]-tmp;
f[i]=f[i]+tmp;
W=W*Wn;
}
}
}
}
int main(){
int n,m;io>>n>>m;
for(int i=0;i<=n;++i){
int x;io>>x;f[i].a=x;
}
for(int i=0;i<=m;++i){
int x;io>>x;g[i].a=x;
}
for(m+=n,n=1;n<=m;n<<=1);
for(int i=0;i<n;++i) rev[i]=(rev[i>>1]>>1)|((i&1)?n>>1:0);
fft(f,n,1);fft(g,n,1);
for(int i=0;i<n;++i) f[i]=f[i]*g[i];
fft(f,n,-1);
for(int i=0;i<=m;++i) printf("%d ",(int)(f[i].a/n+0.5));
return 0;
}
进行卡常后,最大时间变为 \(4.9s\),效率是原来的 \(2\) 倍!
一些常见错误:
- 最后结果忘 \(\div n\) 。
- 运算中的单位根是 \(\omega_p^k\) 而非 \(\omega_n^k\) !
- 位逆序置换中 \(i<rev_i\) 才交换!
附:三次变两次优化
在读入时,我们只用了系数的实部,浪费了虚部。
考虑 \(H(x)=\sum_{k=0} (a_k+i\cdot b_k)x^k=\sum_{k=0} a_kx^k+\sum_{k=0} i\cdot b_kx^k\)
记 \(F(x)=\sum_{k=0} a_kx^k , G(x)=\sum_{k=0} b_kx^k\),则 \(H(x)=F(x)+i\cdot G(x)\)。
\(H^2(x)=F^2(x)-G^2(x)+2i\cdot F(x)G(x)\)
记 \(f(x)=F^2(x)-G^2(x) , g(x)=F(x)G(x)\)
\(H(x)=f(x)+2i\cdot g(x)=\sum_{k=0} (f_k+2i\cdot g_k)x^k\)
我们发现,\(g(x)\) 就是我们要求的东西,其各项系数为 \(H\) 的各项系数的虚部 \(\div 2\)。
然后只需要两次 FFT 即可。
#include <cstdio>
#include <cctype>
#include <cmath>
#include <algorithm>
using namespace std;
char buf[1<<21],*p1=buf,*p2=buf;
#define GetC() ((p1==p2)&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
struct Ios{}io;
template <typename _tp>
Ios &operator >>(Ios &in,_tp &x){
x=0;int w=0;char c=GetC();
for(;!isdigit(c);w|=c=='-',c=GetC());
for(;isdigit(c);x=x*10+(c^'0'),c=GetC());
if(w) x=-x;
return in;
}
const double pi=acos(-1);
const int N=1024*2048+5;
struct cp{
double a,b;
cp operator +(const cp &oth)const{return (cp){a+oth.a,b+oth.b};}
cp operator -(const cp &oth)const{return (cp){a-oth.a,b-oth.b};}
cp operator *(const cp &oth)const{return (cp){a*oth.a-b*oth.b,a*oth.b+b*oth.a};}
}f[N];
int rev[N];
void fft(cp *f,int n,int type){
for(int i=0;i<n;++i) if(i<rev[i]) swap(f[i],f[rev[i]]);
for(int p=2;p<=n;p<<=1){
int len=p>>1;
cp Wn=(cp){cos(2*pi/p),type*sin(2*pi/p)};
for(int k=0;k<n;k+=p){
cp W=(cp){1,0};
for(int i=k;i<k+len;++i){
cp tmp=W*f[i+len];
f[i+len]=f[i]-tmp;
f[i]=f[i]+tmp;
W=W*Wn;
}
}
}
}
int main(){
int n,m;io>>n>>m;
for(int i=0;i<=n;++i){
int x;io>>x;f[i].a=x;
}
for(int i=0;i<=m;++i){
int x;io>>x;f[i].b=x;
}
for(m+=n,n=1;n<=m;n<<=1);
for(int i=0;i<n;++i) rev[i]=(rev[i>>1]>>1)|((i&1)?n>>1:0);
fft(f,n,1);
for(int i=0;i<n;++i) f[i]=f[i]*f[i];
fft(f,n,-1);
for(int i=0;i<=m;++i){
printf("%d ",(int)(f[i].b/(2*n)+0.5));
}
return 0;
}
此时最大点变为 \(0.33s\)。
到这里,本文就结束了,想了解更多多项式科技的话,可以去学习 NTT。