FFT&NTT学习笔记
前置芝士:单位根
复数
定义
众所周知实数分布在一维的实数轴上,单位是1。类比实数轴,我们有虚数轴,单位是\(i\)。\(i\)是什么呢?简单地说就是\(\sqrt{-1}\)。类比于平面直角坐标系的x,y轴,我们有复平面,竖轴是虚数轴,横轴是实数轴,两个轴互相垂直。类比于平面直角坐标系上的每个坐标,复数轴上的每个点就是一个复数。平面直角坐标系的点\((x,y)\),对应的复数就是\(a+bi\)。其中称\(a\)为实部,\(b\)为虚部。
复数的模
即复数到原点的距离,\(|a+bi|=\sqrt{a^2+b^2}\)
复数的辐角
图中的\(θ\)。复数有无限多个辐角,一般取\([-\pi,\pi]\)之间那个。
复数的运算
复数的加减法满足平行四边形法则;复数的乘法为:模相乘,辐角相加。对应到代数也很简单:
注意\(i^2=-1\),可以推出上面的乘法公式。
欧拉公式
\(e\)的定义式
有人这么想过:如果我往银行里存了笔钱,我吃一期利息,再取出来,把利息加本金作为新的本金,再存进去,再吃一期利息,再取出来,再存进去......我的钱会不会无限多了呢?然而经过实验得出,随着存取的次数增多,现在拥有的钱和最初的本金的比值趋近于一个数,大概是二点七几。于是就有了\(e\)的定义式:
欧拉公式
可以发现\(e\)有这样的性质:
类似地可以发现:
归纳证明得到:
对应到复数可以得到:
也就是无限多个模为\(\sqrt{1^2+(\frac{x}{n})^2}=1\),并且辐角为\(\arctan{\frac{x}{n}}=\frac{x}{n}\)(注意这里的\(n\)为正无穷)的复数乘起来。根据复数乘法的定义,可以知道\(e^{ix}\)是个复数,模长为\(1\),辐角为\(x\)。那么把这个复数表示出来就是:\(\cos{x}+i·\sin{x}\)。
于是就有了欧拉公式:\(e^{ix}=\cos{x}+i·\sin{x}\)
单位根
定义
在复平面上做单位圆,以原点为起点,圆的\(n\)等分点为终点,做\(n\)个向量。其中辐角为正且最小的一个向量所对应的复数叫做\(n\)次单位根,记为\(w_n\)
根据复数乘法,圆上剩下的\(n-1\)个向量所对应的复数就是:\(w_n^2,w_n^3...w_n^n\)
易知\(w_n^k\)的辐角为\(\frac{2\pi}{n}\times k\),模为\(1\),那么根据欧拉公式:
性质
1.\(w_n^k=\cos{(k\times \frac{2\pi}{n})}+i·\sin{(k\times \frac{2\pi}{n})}\)
2.\(w_{2n}^{2k}=\cos{(2k\times \frac{2\pi}{2n})}+i·\sin{(2k\times \frac{2\pi}{2n})}=w_n^k\)
3.\(w_n^{k+\frac{n}{2}}=w_n^k\times w_n^{\frac{n}{2}}=w_n^k\times (\cos{\pi}+i·\sin{\pi})=-w_n^k\)
4.\(w_n^0=w_n^n=1\)
快速傅里叶变换FFT
多项式
定义
形如\(A(x)=\sum_{i=0}^{n}a_ix^i\)的\(A(x)\)称为多项式。
系数表示法
\(n+1\)个系数唯一确定一个n次多项式,所以可以用系数来表示这个多项式:\({\{}a_0,a_1,a_2,...,a_n{\}}\)
点值表示法
给\(n\)次多项式代\(n+1\)个不同的\(x\),可以得到\(n+1\)个不同的值\({\{}y_0,y_1,y_2,...,y_n{\}}\),如果这\(n+1\)个点\((x_0,y_0),(x_1,y_1),...,(x_n,y_n)\)线性无关,则这个多项式可以被这些点唯一确定。所以可以用点值来表示这个多项式。
快速傅里叶变换FFT
作用
已知一个多项式的所有系数,FFT可以在\(O(nlogn)\)的复杂度内得到一组点值,而朴素算法需要\(O(n^2)\)。对应的有快速傅里叶逆变换IDFFT,已知一个多项式的点指表示,可以在\(O(nlogn)\)的复杂度内求出多项式的系数。其中\(n\)是带入的\(x\)的数量。
一个\(n\)次多项式和一个\(m\)次多项式相乘会得到一个\(n+m\)次多项式,所以如果代入大于\(n+m+1\)个\(x\),把算出的每个\(A(x)\)和\(B(x)\)乘起来,得到\(n+m+1\)个点值\((x_0,A(x_0)\times B(x_0))\),\((x_1,A(x_1)\times B(x_1))\),...,\((x_{n+m},A(x_{n+m})\times B(x_{n+m}))\)可以唯一确定多项式\(A\times B\),再用IDFFT可以求出\(A\times B\)的系数。所以FFT可以用来计算多项式乘法,复杂度为\(O(nlogn)\),朴素算法需要\(O(n^2)\)。
公式推导
设多项式\(A(x)\)系数为\({\{}a_0,a_1,a_2,...,a_{n-1}{\}}\)。这里认为\(n\)可以表示为\(2^k\)的形式。实际的\(n\)若不足可以在后面补\(0\)。
设\(A_1(x)=(a_0+a_2x+a_4x^2+...+a_{n-2}x^{\frac{n}{2}-1})\);
设\(A_2(x)=(a_1+a_3x+a_5x^2+...+a_{n-1}x^{\frac{n}{2}-1})\);
则\(A(x)=A_1(x^2)+xA_2(x^2)\)。
设单位根为\(w_n\),代入\(w_n^k(k<\frac{n}{2})\)得:
再把\(w_n^{k+\frac{n}{2}}\)代入得
上面有点乱,我来整理一下:
可以发现只有中间的符号不同。又因为当\(k\)取遍\([0,\frac{n}{2})\)所有值时,\(k+\frac{n}{2}\)取遍\([\frac{n}{2},n)\)所有值。所以我们只需要计算当\(k\in [0,\frac{n}{2})\)时,长度为\(\frac{n}{2}\)的多项式\(A_1(w_{\frac{n}{2}}^{k})\)和\(A_2(w_{\frac{n}{2}}^{k})\)的值,就可以得到当\(k\in [0,n)\)时,长度为\(n\)的多项式\(A(w_{n}^{k})\)的值。递归计算即可,复杂度为\(O(nlogn)\)。
快速傅里叶逆变换IDFFT
作用
前面提到过,把点值表示转系数表示
公式推导
设\(y_0,y_1,...y_{n-1}\)为多项式\(a_0+a_1x+...+a_{n-1}x^{n-1}\)的点值表示。
设\(c_0\),\(c_1\),...,\(c_{n-1}\)满足\(c_k=\sum_{i=0}^{n-1}y_i(w_n^{-k})^i\),即多项式\(B(x)=y_0+y_1x+...+y_{n-1}x^{n-1}\)在\(w_n^0\),\(w_n^{-1}\),...,\(w_n^{-n+1}\)处的点值表示。
注意一下后面这个\(\sum\),设\(S(x)=\sum_{i=0}^{n-1}x^i\),代入\(w_n^k\)得
1.当\(k\neq0\)时,\(w_n^kS(w_n^k)=w_n^k +(w_n^k)^2+...+(w_n^k)^n\),相减得
2.当\(k=0\)时,\(S(w_n^k)=n\)
然后我们回到之前的式子\(c_k=\sum_{j=0}^{n-1}a_j \sum_{i=0}^{n-1}(w_n^{j-k})^i\),根据上面的结论可以知道,只有当\(j=k\)时,\(\sum_{i=0}^{n-1}(w_n^{j-k})^i=n\),否则等于\(0\)。
\(\therefore c_j=na_j,a_j=\frac{c_j}{n}\)
所以对于一个多项式\(a_0+a_1x+...+a_{n-1}x^{n-1}\),如果我们知道它的点值表示\(y_0\),\(y_1\),...,\(y_{n-1}\),就可以用FFT求出多项式\(B(x)=y_0+y_1x+...+y_{n-1}x^{n-1}\)在\(w_n^0\),\(w_n^{-1}\),...,\(w_n^{-n+1}\)处的点值表示\(c_0\),\(c_1\),...,\(c_{n-1}\),从而求出\(a_0\),\(a_1\),...,\(a_{n-1}\)
时间复杂度也是\(O(nlogn)\)
蝴蝶优化
你真的用递归去写?可以发现递归的方式每次都需要把\(A(x)\)的系数复制一遍,排个顺序得到\(A_1(x)\)和\(A_2(x)\)的系数。这是非常慢的。蝴蝶优化可以解决这个问题。
举个\(n=8\)的例子,我在这里列出每次递归的系数的顺序:
然后把最上面一排和最下面一排的下标的二进制写出来:
可以发现最后的顺序的二进制,就是对应位置的最初顺序的二进制的翻转。所以我们已开始就给它翻好,就不用递归了。给出代码:
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#define rg register
#define il inline
#define cn const
#define gc getchar()
#define fp(i,a,b) for(rg int i=a;i<=b;++i)
using std::swap;
inline int read(){
rg int x(0),f(1); rg char c(gc);
while(c<'0'||'9'<c){ if(c=='-') f=-1; c=gc; }
while('0'<=c&&c<='9') x=(x<<1)+(x<<3)+(c^48),c=gc;
return x*f;
}
#define maxn 10000010
const double pi=acos(-1.0);
int n,m,limit=1,l,r[maxn];
double Cos[maxn],Sin[maxn];
struct complex{ double x,y; }a[maxn],b[maxn];//复数
complex operator+(cn complex &x,cn complex &y){ return (complex){x.x+y.x,x.y+y.y}; }
complex operator-(cn complex &x,cn complex &y){ return (complex){x.x-y.x,x.y-y.y}; }
complex operator*(cn complex &x,cn complex &y){ return (complex){x.x*y.x-x.y*y.y,x.x*y.y+x.y*y.x}; }
inline void FastFourierTransform(complex *a,cn int &type){
fp(i,0,limit-1) if(i<r[i]) swap(a[i],a[r[i]]);//蝴蝶优化
for(rg int mid=1;mid<limit;mid<<=1){
rg int len=mid<<1; complex Wn=(complex){Cos[len],type*Sin[len]};//根据上面的推导,这个type很灵性
for(rg int j=0;j<limit;j+=len){
complex Pow=(complex){1,0};
for(rg int k=0;k<mid;++k,Pow=Pow*Wn){
complex x=a[j+k],y=Pow*a[j+mid+k];
a[j+k]=x+y,a[j+mid+k]=x-y;
}
}
}
}
int main(){
n=read(),m=read(); fp(i,0,n) a[i].x=read(); fp(i,0,m) b[i].x=read();
while(limit<=n+m) limit<<=1,++l; fp(i,0,limit-1) r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));//求出翻转的下标
fp(i,0,limit) Cos[i]=cos(pi*2/i),Sin[i]=sin(pi*2/i);
FastFourierTransform(a,1),FastFourierTransform(b,1);
fp(i,0,limit) a[i]=a[i]*b[i]; FastFourierTransform(a,-1);
fp(i,0,n+m) printf("%d ",(int)(a[i].x/limit+0.5)); return 0;
}
快速数论变换NTT
可以发现单位根有的性质原根都有......所以可以用原根代替单位根。
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#define rg register
#define il inline
#define LL long long
#define cn const
#define gc getchar()
#define fp(i,a,b) for(rg int i=a;i<=b;++i)
using std::swap;
il int read(){
rg int x(0),f(1); rg char c(gc);
while(c<'0'||'9'<c){ if(c=='-') f=-1; c=gc; }
while('0'<=c&&c<='9') x=(x<<1)+(x<<3)+(c^48),c=gc;
return x*f;
}
#define maxn 3000010
const int G=3,invG=332748118,P=998244353;//G是原根,invG是原根的逆元
int n,m,limit=1,l,r[maxn];
LL a[maxn],b[maxn];
inline LL FastPow(LL a,int b){
LL ans=1;
for(;b;b>>=1,a=a*a%P) if(b&1) ans=ans*a%P;
return ans;
}
inline void NumberTheoreticTransform(LL *a,cn int &type){
fp(i,0,limit-1) if(i<r[i]) swap(a[i],a[r[i]]);
for(rg int mid=1;mid<limit;mid<<=1){
rg int len=mid<<1; rg LL Gn=FastPow(type?G:invG,(P-1)/len);
for(rg int j=0;j<limit;j+=len){
LL Pow=1;
for(rg int k=0;k<mid;++k,Pow=Pow*Gn%P){
LL x=a[j+k],y=Pow*a[j+mid+k]%P;
a[j+k]=(x+y)%P,a[j+mid+k]=(x-y+P)%P;
}
}
}
}
int main(){
n=read(),m=read(); fp(i,0,n) a[i]=(read()+P)%P; fp(i,0,m) b[i]=(read()+P)%P;
while(limit<=n+m) limit<<=1,++l; fp(i,0,limit-1) r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
NumberTheoreticTransform(a,1),NumberTheoreticTransform(b,1);
fp(i,0,limit) a[i]=a[i]*b[i]%P; NumberTheoreticTransform(a,0);
rg LL invN=FastPow(limit,P-2); fp(i,0,n+m) printf("%d ",a[i]*invN%P);
return 0;
}
模板题:【模板】多项式乘法(FFT)