快速傅里叶变换复习笔记

  • .real()成员函数
  • FFT的本质是快速计算多项式的点值表示
  • 对负实数的四舍五入需要-0.5
  • 编写函数接收数组地址时,注意不能破坏原数组
  • FFT有较为严重的精度问题,double甚至难以准确计算两个\(10^9\)级别的整数相乘的结果,即使采用long double也时常无法得到准确的答案,这或许也是模板题中保证系数为个位数的原因
  • FFT要求项数严格为2的幂次,因此要多开一些空间
点击查看代码
#include <bits/stdc++.h>
using namespace std;
complex<long double>a[100005],b[100005],tmp[300005],ans[300005];
complex<long double>F[300005],G[300005];
const complex<long double> I(0,1);
const long double pi=acos(-1.0);
long long read1()
{
	char cc=getchar();
	while(!(cc>=48&&cc<=57))
	{
		if(cc=='-')
		{
			break;
		}
		cc=getchar();
	}
	bool f=false;
	long long s=0;
	if(cc=='-')
	{
		f=true;
	}
	else
	{
		s=cc-48;
	}
	while(1)
	{
		cc=getchar();
		if(cc>=48&&cc<=57)
		{
			s=s*10+cc-48;
		}
		else
		{
			break;
		}
	}
	if(f==true)
	{
		s=-s;
	}
	return s;
}
void FFT(complex<long double>*f,long long n,long long opt)
{
	if(n==1)
	{
		return;
	}
	for(long long i=0;i<n;i++)
	{
		tmp[i]=f[i];
	}
	for(long long i=0;i<n;i++)
	{
		if(i%2==0)
		{
			f[i/2]=tmp[i];
		}
		else
		{
			f[i/2+n/2]=tmp[i];
		}
	}
	complex<long double>*g=f,*h=f+n/2;
	FFT(g,n/2,opt);
	FFT(h,n/2,opt);
	complex<long double>cur(1,0),step=exp(I*(2*pi/n*opt));
	for(long long i=0;i<n/2;i++)
	{
		tmp[i]=g[i]+h[i]*cur;
		tmp[i+n/2]=g[i]-h[i]*cur;
		cur*=step;
	}
	for(long long i=0;i<n;i++)
	{
		f[i]=tmp[i];
	}
}
void mul(complex<long double>*f,long long n,complex<long double>*g,long long m,complex<long double>*ans,long long maxn)
{
	/*
	for(long long i=0;i<n;i++)
	{
		for(long long j=0;j<m;j++)
		{
			if(i+j<maxn)
			{
				ans[i+j]+=(f[i]*g[j]);
			}
		}
	}
	*/
	long long tmp=ceil(log(n+m+2)/log(2));
	for(long long i=0;i<(1<<tmp);i++)
	{
		F[i]=0;
		if(i<n)
		{
			F[i]=f[i];
		}
	}
	for(long long i=0;i<(1<<tmp);i++)
	{
		G[i]=0;
		if(i<m)
		{
			G[i]=g[i];
		}
	}
	FFT(F,(1<<tmp),1);
	FFT(G,(1<<tmp),1);
	for(long long i=0;i<(1<<tmp);i++)
	{
		F[i]*=G[i];
	}
	FFT(F,(1<<tmp),-1);
	for(long long i=0;i<maxn;i++)
	{
		ans[i]=F[i].real()/(1<<tmp);
	}
}
int main()
{
	int n,m;
	cin>>n>>m;
	for(int i=0;i<=n;i++)
	{
		a[i]=read1();
	}
	for(int i=0;i<=m;i++)
	{
		b[i]=read1();
	}
	mul(a,n+1,b,m+1,ans,n+m+1);
	for(int i=0;i<n+m+1;i++)
	{
		printf("%lld ",(long long)(ans[i].real()+0.5));
	}
	cout<<endl;
	return 0;
}
posted @ 2024-07-09 19:13  D06  阅读(1)  评论(0编辑  收藏  举报