【模板】多项式乘法(FFT)

link

我竟然搞懂了FFT。好感动。说不定我的数学有救了呢。

首先它的思想就是求出 \(2\times m\) 个值,分别是 \(f(\omega_m^0),f(\omega_m^1)\dots f(\omega_m^{m-1})\) ,以及 \(g(\omega_m^0),g(\omega_m^1)\dots g(\omega_m^{m-1})\) 的值。由于满足 \(\forall i\in R,f(i)\times g(i)=r(i)\) ,所以有了这些值我们就可以求出 \(r\) 函数上的 \(2\times m\) 个点,于是就可以唯一地确定这个函数。FFT是在优化求值和确定函数这两个过程以使得程序复杂度降低到 \(O(N\log N)\) 的。

首先 \(\omega_n^k\) 满足一些性质。

\[\omega_n^k=\cos\frac{2\pi k}{n}+\sin\frac{2\pi k}{n}\times i \]

由于单位根可以看成是以原点为圆心的一个圆上点的集合,所以求得夹角 (\(\frac{2\pi k}{n}\))和半径 (1) 之后就可以知道这个三角形的横纵坐标,也就可以对应唯一的一个复数了,

\[\omega_n^k=\omega_{2n}^{2k} \]

用几何法会更好理解一点。大不了把圆再均分一次。

\[\omega_n^{kn}=1,k\in Z \]

也挺好理解,圆上第一个单位根是在 \(1+0i\) 的位置。

理解了这些就可以考虑进行FFT了。首先求点值。

\[A(x)=a_0+a_1x+a_2x^2+\dots+a_{n-2}x^{n-2}+a_{n-1}x^{n-1} \]

进行奇偶分类。创造两个新函数:

\[A1(x)=a_0+a_2x+a_4x^2+\dots+a_{n-2}x^{n-2}\\A2(x)=a_1+a_3x+a_5x^2+\dots+a_{n-1}x^{n-1} \]

会发现

\[A(x)=A1(x^2)+xA2(x^2) \]

而单位根满足一些绝妙的性质,上文提到 \(\omega_n^k=\omega_{2n}^{2k}\) ,所以假如当前求解的是 \(A(\omega_{n}^k)\) ,那么有:

\[A(\omega_n^k)=A1(\omega_{n}^{2k})+\omega_n^kA2(\omega_{n}^{2k})=A1(\omega_{n/2}^{k})+\omega_n^kA2(\omega_{n/2}^{k}) \]

所以每一次要求的自变量的值上标是不变的,变的只有下标,于是递归的时候把逐层减半的 n 代成一个参数就可以了。而n恰好和序列长度的变化方式一样,所以带一个参数就可以了。

好了,我们完成了 \(O(N\log N)\) 求出 \(2N\) 个点值,接下来考虑怎么把这些点值还原成为一个多项式。假设结果多项式是 \(C\)

\[\begin{bmatrix}1&1&\dots&1\\1&\omega_n^1&\dots&\omega_n^{n-1}\\\dots&\dots&\dots&\dots\\1&\omega_n^{n-1}&\dots&\omega_n^{(n-1)^2}\end{bmatrix}\times\begin{bmatrix}c_0\\c_1\\\dots\\c_{n-1}\end{bmatrix}=\begin{bmatrix}f(\omega_n^0)\\f(\omega_n^1)\\\dots\\f(\omega_n^{n-1})\end{bmatrix} \]

可以看成是 \(B\times C=A\) ,所以我们要求的矩阵 \(C=A\times B^{-1}\) 。有结论是:

\[B^{-1}=\begin{bmatrix}1&1&\dots&1\\1&\omega_n^{-1}&\dots&\omega_n^{1-n}\\\dots&\dots&\dots&\dots\\1&\omega_n^{1-n}&\dots&\omega_n^{-(1-n)^2}\end{bmatrix} \]

鬼知道谁想出来的。反正会发现最后把点值转化成参数的过程也可以用相似的分治过程进行计算。

写法上,有朴素递归的写法,也从小到大进行合并,合并之前要把所有元素进行重新排列,递推式是 cnt[i]=(cnt[i>>1]>>1)+((i&1)<<scnt-1);。后者常数小一点,空间消耗也能小一点。差距能在30%左右。

有几个注意事项。复数尽量手写,C++自带的咱也不会用啊对吧。这个算法有精度问题,不能处理太大的数据,在取模场景下也不适用。

递归版:

#include<bits/stdc++.h>
//#define zczc
const int N=4000010;
const double Pi=acos(-1.0);
using namespace std;
inline void read(int &wh){
    wh=0;int f=1;char w=getchar();
    while(w<'0'||w>'9'){if(w=='-')f=-1;w=getchar();}
    while(w<='9'&&w>='0'){wh=wh*10+w-'0';w=getchar();}
    wh*=f;return;
}

struct node{double a,b;}a[N],b[N];
inline node operator +(node s1,node s2){return (node){s1.a+s2.a,s1.b+s2.b};}
inline node operator -(node s1,node s2){return (node){s1.a-s2.a,s1.b-s2.b};}
inline node operator *(node s1,node s2){return (node){s1.a*s2.a-s1.b*s2.b,s1.a*s2.b+s1.b*s2.a};}

void FFT(node a[],int limit,int type){
	if(limit==1)return;
	node a1[limit>>1],a2[limit>>1];
	int len=limit>>1;
	for(int i=0;i<limit;i+=2){
		a1[i>>1]=a[i],a2[i>>1]=a[i+1];
	}
	FFT(a1,len,type);FFT(a2,len,type);
	node ww=(node){cos(2.0*Pi/limit),type*sin(2.0*Pi/limit)};
	node w=(node){1.0,0};
	for(int i=0;i<len;i++,w=w*ww){
		a[i]=a1[i]+w*a2[i];
		a[i+len]=a1[i]-w*a2[i];
	}
}

signed main(){
	
	#ifdef zczc
	freopen("in.txt","r",stdin);
	#endif
	
	int m,n;read(m);read(n);
	for(int i=0;i<=m;i++)scanf("%lf",&a[i].a);
	for(int i=0;i<=n;i++)scanf("%lf",&b[i].a);
	int limit=1;
	while(limit<=m+n)limit<<=1;
	FFT(a,limit,1);
	FFT(b,limit,1);
	for(int i=0;i<limit;i++)a[i]=a[i]*b[i];
	
	
	
	FFT(a,limit,-1);
	for(int i=0;i<=m+n;i++)printf("%d ",(int)(a[i].a/limit+0.5));
	
	
	return 0;
}
#include<bits/stdc++.h>
//#define zczc
const int N=4000010;
const double Pi=acos(-1.0);
using namespace std;
inline void read(int &wh){
    wh=0;int f=1;char w=getchar();
    while(w<'0'||w>'9'){if(w=='-')f=-1;w=getchar();}
    while(w<='9'&&w>='0'){wh=wh*10+w-'0';w=getchar();}
    wh*=f;return;
}

struct node{double a,b;}a[N],b[N];
inline node operator +(node s1,node s2){return (node){s1.a+s2.a,s1.b+s2.b};}
inline node operator -(node s1,node s2){return (node){s1.a-s2.a,s1.b-s2.b};}
inline node operator *(node s1,node s2){return (node){s1.a*s2.a-s1.b*s2.b,s1.a*s2.b+s1.b*s2.a};}

int cnt[N];

node s[N];
void FFT(node a[],int limit,int type){
	for(int i=0;i<limit;i++)s[cnt[i]]=a[i];
	for(int len=1;len<limit;len<<=1){//序列长度的一半 
		node ww=(node){cos(Pi/len),type*sin(Pi/len)};
		for(int l=0;l+len<limit;l+=(len<<1)){//枚举序列左端点 
			node w=(node){1,0};int mid=l+len;
			for(int i=0;i<len;i++,w=w*ww){
				node A=s[l+i],B=s[mid+i];
				s[l+i]=A+w*B,s[mid+i]=A-w*B;
			}
		}
	}
	for(int i=0;i<limit;i++)a[i]=s[i];
}

signed main(){
	
	#ifdef zczc
	freopen("in.txt","r",stdin);
	#endif
	
	int m,n;read(m);read(n);
	for(int i=0;i<=m;i++)scanf("%lf",&a[i].a);
	for(int i=0;i<=n;i++)scanf("%lf",&b[i].a);
	int limit=1;
	while(limit<=m+n)limit<<=1;
	
	int scnt=0;
	while((1<<scnt)!=limit)scnt++;
	for(int i=0;i<limit;i++){
		cnt[i]=(cnt[i>>1]>>1)+((i&1)<<scnt-1);
		//printf("%d\n",cnt[i]);
	}
	
	
	FFT(a,limit,1);
	FFT(b,limit,1);
	for(int i=0;i<limit;i++)a[i]=a[i]*b[i];
	FFT(a,limit,-1);
	for(int i=0;i<=m+n;i++)printf("%d ",(int)(a[i].a/limit+0.5));
	
	
	return 0;
}
posted @ 2022-06-29 11:13  Feyn618  阅读(50)  评论(0编辑  收藏  举报