FFT/NTT学习笔记
非常推荐这一篇,挺详细的
拉格朗日插值:https://www.zhihu.com/question/58333118
FFT: https://www.luogu.com.cn/blog/command-block/fft-xue-xi-bi-ji
NTT: https://www.luogu.com.cn/blog/command-block/ntt-yu-duo-xiang-shi-quan-jia-tong
原根表(NTT要用,还要背。。。):http://blog.miskcoo.com/2014/07/fft-prime-table
注意一点:复数(虚数)的运算千万不要写错!!!
// Minimal complex-number type: x is the real part, y the imaginary part.
// All operators are const member functions so they also work on const operands.
struct fu{
    fu(double xx=0,double yy=0):x(xx),y(yy){}
    double x,y;
    fu operator + (fu const & tmp) const {return fu(x+tmp.x,y+tmp.y);}
    fu operator - (fu const & tmp) const {return fu(x-tmp.x,y-tmp.y);}
    // (x+yi)(c+di) = (xc-yd) + (xd+yc)i
    fu operator * (fu const & tmp) const {return fu(x*tmp.x-y*tmp.y,x*tmp.y+y*tmp.x);}
    // (x+yi)/(c+di) = ((xc+yd) + (yc-xd)i) / (c^2+d^2); no guard against a zero divisor.
    fu operator / (fu const & tmp) const {double t=tmp.x*tmp.x+tmp.y*tmp.y;return fu((x*tmp.x+y*tmp.y)/t,(y*tmp.x-x*tmp.y)/t);}
};
FFT最终代码:
#include<bits/stdc++.h>
#define N 7700000
using namespace std;
// pi, obtained as arccos(-1).
const double Pi = acos(-1.0);
// On input n and m are the two polynomial degrees; main later reuses m as the product
// degree and n as the padded transform length (a power of two).
int n, m;
// Bit-reversal permutation table, filled in main before the transforms run.
int tr[N];
// Minimal complex-number type: x is the real part, y the imaginary part.
// All operators are const member functions so they also work on const operands.
struct fu{
    fu(double xx=0,double yy=0):x(xx),y(yy){}
    double x,y;
    fu operator + (fu const & tmp) const {return fu(x+tmp.x,y+tmp.y);}
    fu operator - (fu const & tmp) const {return fu(x-tmp.x,y-tmp.y);}
    // (x+yi)(c+di) = (xc-yd) + (xd+yc)i
    fu operator * (fu const & tmp) const {return fu(x*tmp.x-y*tmp.y,x*tmp.y+y*tmp.x);}
    // (x+yi)/(c+di) = ((xc+yd) + (yc-xd)i) / (c^2+d^2); no guard against a zero divisor.
    fu operator / (fu const & tmp) const {double t=tmp.x*tmp.x+tmp.y*tmp.y;return fu((x*tmp.x+y*tmp.y)/t,(y*tmp.x-x*tmp.y)/t);}
}a[N],b[N],tmp[N];  // a, b: coefficient arrays transformed in place.
                    // NOTE(review): tmp is not used by FFT()/main() as written; it only reserves memory.
// In-place iterative radix-2 FFT over f[0..n-1], using the global length n (a power of two)
// and the bit-reversal table tr, both prepared in main.
// flag == true multiplies by roots e^{+2*pi*i/len}; flag == false uses the conjugate roots,
// giving the inverse transform without the 1/n factor (main divides by n afterwards).
void FFT(fu *f,bool flag){
    // Move every element to its bit-reversed position so butterflies can run in place.
    for(int i=0;i<n;++i){
        if(i<tr[i]) swap(f[i],f[tr[i]]);
    }
    for(int len=2;len<=n;len<<=1){
        const int half=len>>1;
        // Primitive len-th root of unity; conjugated for the inverse transform.
        fu step(cos(2*Pi/len),sin(2*Pi/len));
        if(!flag) step.y=-step.y;
        for(int start=0;start<n;start+=len){
            fu w(1,0);  // current twiddle factor step^j
            for(int j=0;j<half;++j){
                fu &lo=f[start+j];
                fu &hi=f[start+j+half];
                fu t=w*hi;
                hi=lo-t;
                lo=lo+t;
                w=w*step;
            }
        }
    }
}
// Reads two polynomials (degree n with coefficients a[0..n], degree m with b[0..m]),
// multiplies them with FFT and prints the n+m+1 coefficients of the product.
int main(){
    scanf("%d%d",&n,&m);
    for(int i=0;i<=n;i++)scanf("%lf",&a[i].x);
    for(int i=0;i<=m;i++)scanf("%lf",&b[i].x);
    // m becomes the product degree; n becomes the smallest power of two greater than it.
    for(m+=n,n=1;n<=m;n<<=1);
    // tr[i] is i with its log2(n) low bits reversed.
    for(int i=0;i<n;i++)tr[i]=(tr[i>>1]>>1)|((i&1)?n>>1:0);
    FFT(a,1);FFT(b,1);
    // Pointwise product of the transforms corresponds to convolution of the coefficients.
    for(int i=0;i<n;i++)a[i]=a[i]*b[i];
    FFT(a,0);
    // The inverse transform is unnormalised, hence /n. llround rounds to nearest for
    // negative results as well. (int)(x+0.49) truncates toward zero, so -3.0 would print
    // as -2. The long long result also avoids int overflow on large coefficients.
    for(int i=0;i<=m;i++)printf("%lld ",llround(a[i].x/n));
}
NTT最终代码:
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const ll mod=998244353;  // NTT-friendly prime: 998244353 = 119 * 2^23 + 1
const ll G=3;            // primitive root of mod
const ll N=2200000;      // array capacity; must be at least the padded transform length
// n, m: input degrees; main then reuses m as the product degree and n as the transform length.
// iN: modular inverse of the transform length, set in main.
ll n,m,iN;
ll a[N],b[N];  // coefficient arrays, transformed in place
ll tr[N];      // bit-reversal permutation table
// Modular add/subtract helpers. Each applies at most one correction by mod, so the
// result lies in [0, mod) only when both operands already do.
// _inc/_dec return the result; inc/dec store it back into a.
ll _inc(ll a,ll b){
    const ll s=a+b;
    return s>=mod?s-mod:s;
}
ll _dec(ll a,ll b){
    const ll d=a-b;
    return d<0?d+mod:d;
}
void inc(ll &a,ll b){a=_inc(a,b);}
void dec(ll &a,ll b){a=_dec(a,b);}
// Modular exponentiation: returns a^k mod mod by square-and-multiply.
// With the default exponent mod-2 this is the modular inverse of a (Fermat's little
// theorem; 998244353 is prime). k must be non-negative. a is assumed to be in [0, mod)
// so that base*base fits in long long.
ll mi(ll a,ll k=mod-2){
    ll result=1;
    ll base=a;
    for(;k!=0;k>>=1){
        if(k&1) result=result*base%mod;
        base=base*base%mod;
    }
    return result;
}
const ll iG=mi(G,mod-2);  // G^{-1} mod mod, used for the inverse transform
// In-place iterative radix-2 number-theoretic transform over f[0..n-1] modulo mod, using
// the global length n (a power of two) and the bit-reversal table tr prepared in main.
// flag == true uses powers of G; flag == false uses powers of iG = G^{-1}. This yields the
// inverse transform without the 1/n factor (main multiplies by iN afterwards).
void NTT(ll *f,bool flag){
    // Move every element to its bit-reversed position so butterflies can run in place.
    for(ll i=0;i<n;++i){
        if(i<tr[i]) swap(f[i],f[tr[i]]);
    }
    const ll root=flag?G:iG;
    for(ll len=2;len<=n;len<<=1){
        const ll half=len>>1;
        // root^((mod-1)/len) is a primitive len-th root of unity modulo mod.
        const ll step=mi(root,(mod-1)/len);
        for(ll start=0;start<n;start+=len){
            ll w=1;  // current twiddle factor step^j
            for(ll j=0;j<half;++j){
                ll &lo=f[start+j];
                ll &hi=f[start+j+half];
                const ll t=hi*w%mod;
                hi=_dec(lo,t);
                inc(lo,t);
                w=w*step%mod;
            }
        }
    }
}
// Reads two polynomials (degree n with coefficients a[0..n], degree m with b[0..m]),
// multiplies them with NTT and prints the n+m+1 product coefficients modulo mod.
int main(){
    scanf("%lld%lld",&n,&m);
    // NTT and its helpers (inc/_dec, the %mod products) assume every value is already in
    // [0, mod). Reduce each input coefficient so negative or >= mod input is handled too,
    // instead of producing wrong or negative output.
    for(ll i=0;i<=n;i++){scanf("%lld",&a[i]);a[i]=(a[i]%mod+mod)%mod;}
    for(ll i=0;i<=m;i++){scanf("%lld",&b[i]);b[i]=(b[i]%mod+mod)%mod;}
    // m becomes the product degree; n becomes the smallest power of two greater than it.
    for(m+=n,n=1;n<=m;n<<=1);
    // tr[i] is i with its log2(n) low bits reversed.
    for(ll i=0;i<n;i++)tr[i]=(tr[i>>1]>>1)|((i&1)?n>>1:0);
    NTT(a,1);NTT(b,1);
    // Pointwise product of the transforms corresponds to convolution of the coefficients.
    for(ll i=0;i<n;i++)a[i]=a[i]*b[i]%mod;
    NTT(a,0);
    // The inverse transform leaves a factor of n; multiply by n^{-1} mod mod.
    iN=mi(n);
    for(ll i=0;i<=m;i++)printf("%lld ",a[i]*iN%mod);
}