FFT/NTT学习笔记

非常推荐这一篇,挺详细的
拉格朗日插值:https://www.zhihu.com/question/58333118

FFT: https://www.luogu.com.cn/blog/command-block/fft-xue-xi-bi-ji

NTT: https://www.luogu.com.cn/blog/command-block/ntt-yu-duo-xiang-shi-quan-jia-tong

原根表(NTT要用,还要背……):http://blog.miskcoo.com/2014/07/fft-prime-table

注意一点:复数运算(尤其是乘法和除法的实部、虚部公式)千万不要写错!!!

// Minimal complex-number type for FFT (the name presumably abbreviates 复数).
// x is the real part, y is the imaginary part.
// The operators are const so they also work on const objects; each one
// states the formula it implements, since this is the easiest place to slip.
struct fu{
	// constexpr: lets global arrays of fu be constant-initialized (no startup loop).
	constexpr fu(double xx=0,double yy=0):x(xx),y(yy){}
	double x,y;
	// (x + iy) + (a + ib) = (x + a) + i(y + b)
	fu operator + (fu const & tmp) const {return fu(x+tmp.x,y+tmp.y);}
	// (x + iy) - (a + ib) = (x - a) + i(y - b)
	fu operator - (fu const & tmp) const {return fu(x-tmp.x,y-tmp.y);}
	// (x + iy)(a + ib) = (xa - yb) + i(xb + ya)
	fu operator * (fu const & tmp) const {return fu(x*tmp.x-y*tmp.y,x*tmp.y+y*tmp.x);}
	// (x + iy)/(a + ib) = ((xa + yb) + i(ya - xb)) / (a^2 + b^2)
	// No zero check: dividing by 0+0i produces inf/NaN components.
	fu operator / (fu const & tmp) const {double t=tmp.x*tmp.x+tmp.y*tmp.y;return fu((x*tmp.x+y*tmp.y)/t,(y*tmp.x-x*tmp.y)/t);}
};

FFT最终代码:

#include<bits/stdc++.h>
#define N 7700000
using namespace std;
const double Pi=acos(-1);
// n, m: degrees read by main; main then reuses n as the transform length
// (a power of two) and m as the degree of the product.
// tr[i]: bit-reversed index of i, used for the in-place FFT permutation.
int n,m,tr[N];
// Complex number: x = real part, y = imaginary part.
struct fu{
	// constexpr so the large global arrays below are constant-initialized
	// (zero-filled at load time) instead of running N constructor calls at startup.
	constexpr fu(double xx=0,double yy=0):x(xx),y(yy){}
	double x,y;
	// (x + iy) + (a + ib) = (x + a) + i(y + b)
	fu operator + (fu const & tmp) const {return fu(x+tmp.x,y+tmp.y);}
	// (x + iy) - (a + ib) = (x - a) + i(y - b)
	fu operator - (fu const & tmp) const {return fu(x-tmp.x,y-tmp.y);}
	// (x + iy)(a + ib) = (xa - yb) + i(xb + ya)
	fu operator * (fu const & tmp) const {return fu(x*tmp.x-y*tmp.y,x*tmp.y+y*tmp.x);}
	// (x + iy)/(a + ib) = ((xa + yb) + i(ya - xb)) / (a^2 + b^2)
	fu operator / (fu const & tmp) const {double t=tmp.x*tmp.x+tmp.y*tmp.y;return fu((x*tmp.x+y*tmp.y)/t,(y*tmp.x-x*tmp.y)/t);}
}a[N],b[N]; // coefficient / point-value arrays of the two polynomials (unused tmp[N] removed)
// In-place iterative radix-2 FFT over f[0..n).
// flag == true : forward transform, twiddle e^{+2*pi*i/len}.
// flag == false: inverse transform (conjugate twiddle) WITHOUT the 1/n
//                scaling; the caller divides by n afterwards.
// Uses the globals n (assumed a power of two) and tr (bit-reversal table).
void FFT(fu *f,bool flag){
	// Bit-reversal permutation; the idx < tr[idx] test swaps each pair once.
	for(int idx=0;idx<n;++idx){
		if(idx<tr[idx])swap(f[idx],f[tr[idx]]);
	}
	for(int len=2;len<=n;len<<=1){
		const int half=len>>1;
		const double theta=2*Pi/len;
		// Primitive len-th root of unity; conjugated for the inverse transform.
		fu step(cos(theta),sin(theta));
		if(!flag)step.y=-step.y;
		for(int base=0;base<n;base+=len){
			fu w(1,0);
			for(int j=0;j<half;++j){
				fu &lo=f[base+j];
				fu &hi=f[base+j+half];
				// Butterfly: (lo, hi) -> (lo + w*hi, lo - w*hi).
				fu t=w*hi;
				hi=lo-t;
				lo=lo+t;
				w=w*step;
			}
		}
	}
}
int main(){
	scanf("%d%d",&n,&m);
	for(int i=0;i<=n;i++)scanf("%lf",&a[i].x);
	for(int i=0;i<=m;i++)scanf("%lf",&b[i].x);
	for(m+=n,n=1;n<=m;n<<=1);
	for(int i=0;i<n;i++)tr[i]=(tr[i>>1]>>1)|((i&1)?n>>1:0);
	FFT(a,1);FFT(b,1);
	for(int i=0;i<n;i++)a[i]=a[i]*b[i];
	FFT(a,0);
	for(int i=0;i<=m;i++)printf("%d ",(int)(a[i].x/n+0.49));
}

NTT最终代码:

#include<bits/stdc++.h>
#define ll long long
using namespace std;
// NTT-friendly prime 998244353 = 119 * 2^23 + 1 with primitive root G = 3.
// N bounds the array sizes (the transform length must not exceed N).
const ll mod=998244353,G=3,N=2200000;
// n, m: degrees read by main (main later reuses n as the transform length and
// m as the product degree); iN: inverse of n mod p; a, b: coefficient arrays;
// tr: bit-reversal table.
ll n,m,iN,a[N],b[N],tr[N];
// Modular add/sub. They correct by a single +/- mod, so the result stays in
// [0, mod) only when both operands are already in [0, mod).
void inc(ll &a,ll b){a+=b;if(a>=mod)a-=mod;}
void dec(ll &a,ll b){a-=b;if(a<0)a+=mod;}
ll _inc(ll a,ll b){a+=b;if(a>=mod)a-=mod;return a;}
ll _dec(ll a,ll b){a-=b;if(a<0)a+=mod;return a;}
// Fast modular exponentiation: returns a^k mod p, always in [0, mod).
// The default k = mod-2 yields the modular inverse by Fermat's little theorem
// (mod is prime); mi(0) returns 0 because 0 has no inverse.
// The base is reduced into [0, mod) first, so negative or huge bases neither
// give negative results nor overflow a*a. A negative k means (a^-1)^|k|
// (previously a negative k never reached 0 under >>= and looped forever).
ll mi(ll a,ll k=mod-2){
	a%=mod;
	if(a<0)a+=mod;
	if(k<0){a=mi(a);k=-k;}
	ll sum=1;
	while(k){
		if(k&1)sum=sum*a%mod;
		a=a*a%mod;
		k>>=1;
	}
	return sum;
}
// Inverse of the primitive root, used for the inverse NTT.
const ll iG=mi(G);
// In-place iterative radix-2 number-theoretic transform over f[0..n) mod p.
// flag == true : forward transform using the primitive root G.
// flag == false: inverse transform using iG, WITHOUT the final multiplication
//                by n^{-1}; the caller applies it.
// Uses the globals n (assumed a power of two) and tr (bit-reversal table).
void NTT(ll *f,bool flag){
	// Bit-reversal permutation; the pos < tr[pos] test swaps each pair once.
	for(ll pos=0;pos<n;++pos){
		if(pos<tr[pos])swap(f[pos],f[tr[pos]]);
	}
	const ll root=flag?G:iG;
	for(ll len=2;len<=n;len<<=1){
		const ll half=len/2;
		// Primitive len-th root of unity modulo mod.
		const ll step=mi(root,(mod-1)/len);
		for(ll base=0;base<n;base+=len){
			ll w=1;
			for(ll j=base;j<base+half;++j){
				// Butterfly: (f[j], f[j+half]) -> (f[j] + w*f[j+half], f[j] - w*f[j+half]).
				const ll t=f[j+half]*w%mod;
				f[j+half]=_dec(f[j],t);
				inc(f[j],t);
				w=w*step%mod;
			}
		}
	}
}
// Reads two polynomials (degrees n and m, coefficients from low to high
// degree), multiplies them with NTT modulo 998244353 and prints the n+m+1
// product coefficients reduced modulo mod. Returns 1 on malformed input.
int main(){
	if(scanf("%lld%lld",&n,&m)!=2)return 1;
	// Normalize inputs into [0, mod): inc/_dec and the 64-bit products inside
	// NTT are only correct for operands already in range, so negative or
	// >= mod coefficients would otherwise corrupt the result.
	for(ll i=0;i<=n;i++){
		if(scanf("%lld",&a[i])!=1)return 1;
		a[i]=(a[i]%mod+mod)%mod;
	}
	for(ll i=0;i<=m;i++){
		if(scanf("%lld",&b[i])!=1)return 1;
		b[i]=(b[i]%mod+mod)%mod;
	}
	// m becomes the product degree; n becomes the smallest power of two > m.
	for(m+=n,n=1;n<=m;n<<=1);
	// tr[i] = bit reversal of i over log2(n) bits, built from tr[i>>1].
	for(ll i=0;i<n;i++)tr[i]=(tr[i>>1]>>1)|((i&1)?n>>1:0);
	NTT(a,1);NTT(b,1);
	// Pointwise product in the transformed domain.
	for(ll i=0;i<n;i++)a[i]=a[i]*b[i]%mod;
	NTT(a,0);
	// The inverse NTT is unscaled: multiply by n^{-1} mod p.
	iN=mi(n);
	for(ll i=0;i<=m;i++)printf("%lld ",a[i]*iN%mod);
}
posted @ 2020-07-11 18:42  ZTC_ZTC  阅读(157)  评论(0)  编辑  收藏  举报