FFT&NTT学习笔记
具体原理就不讲了qwq,毕竟证明我也不太懂
FFT(快速傅立叶变换)&NTT(快速数论变换)
FFT
1 //求多项式乘积
2 //要求多项式A和多项式B的积多项式C
3 //具体操作就是
4 //DFT(A),DFT(B)->逐点相乘->插值(即IDFT(C))->C
5 //其中DFT表示离散傅里叶变换
6 //通俗的来说就是用点值表示多项式
7 //利用单位复数根的折半性质分治,将时间复杂度降至O(nlogn)
8 //ps:但是常数巨大
9 //pps:应用非常广泛,非常多题目都要fft or ntt优化,板子一定要背熟
10 #include<iostream>
11 #include<cstring>
12 #include<cstdio>
13 #include<cmath>
// pw(n) = 2^n. The argument is parenthesised so that compound arguments such as
// pw(x|y) expand to 1<<(x|y) instead of (1<<x)|y.
#define pw(n) (1<<(n))
using namespace std;
// pi, used by FFT to build the complex roots of unity cos(pi/i) +- i*sin(pi/i).
const double pi=acos(-1);
// Minimal complex number used by the FFT: a is the real part, b the imaginary part.
struct complex{
    double a,b;
    complex(double _a=0,double _b=0):a(_a),b(_b){}
    friend complex operator +(complex x,complex y){return complex(x.a+y.a,x.b+y.b);}
    friend complex operator -(complex x,complex y){return complex(x.a-y.a,x.b-y.b);}
    // (x.a + i*x.b)(y.a + i*y.b) = (x.a*y.a - x.b*y.b) + i*(x.a*y.b + x.b*y.a)
    friend complex operator *(complex x,complex y){return complex(x.a*y.a-x.b*y.b,x.a*y.b+x.b*y.a);}
    // Scaling by a real number (used for the 1/bit normalisation of the inverse transform).
    friend complex operator *(complex x,double y){return complex(x.a*y,x.b*y);}
    friend complex operator /(complex x,double y){return complex(x.a/y,x.b/y);}
};
// Shared capacity of every transform buffer. The transform length bit is the
// smallest power of two above n+m, so n+m must stay below MAXLEN. 2^21 covers
// n, m up to about 1e6 (assumed input limit -- adjust to the problem).
const int MAXLEN=1<<21;
// Coefficient / point-value buffers for A and B; main leaves the product in a.
// They are sized to match rev, plus one spare slot so that a loop which also
// touches index bit cannot run past the end.
complex a[MAXLEN+1],b[MAXLEN+1];
// n, m: input degrees (main turns m into n+m = deg C); bit: transform length;
// bitnum: log2(bit); rev: bit-reversal permutation filled by getrev.
int n,m,bit,bitnum=0,rev[MAXLEN];
30 void getrev(int l){//Reverse
31 for(int i=0;i<pw(l);i++){
32 rev[i]=(rev[i>>1]>>1)|((i&1)<<(l-1));
33 }
34 }
35 void FFT(complex *s,int op){
36 for(int i=0;i<bit;i++)if(i<rev[i])swap(s[i],s[rev[i]]);
37 for(int i=1;i<bit;i<<=1){
38 complex w(cos(pi/i),op*sin(pi/i));
39 for(int p=i<<1,j=0;j<bit;j+=p){//Butterfly
40 complex wk(1,0);
41 for(int k=j;k<i+j;k++,wk=wk*w){
42 complex x=s[k],y=wk*s[k+i];
43 s[k]=x+y;
44 s[k+i]=x-y;
45 }
46 }
47 }
48 if(op==-1){
49 for(int i=0;i<=bit;i++){
50 s[i]=s[i]/(double)bit;
51 }
52 }
53 }
54 int main(){
55 scanf("%d%d",&n,&m);
56 for(int i=0;i<=n;i++)scanf("%lf",&a[i].a);
57 for(int i=0;i<=m;i++)scanf("%lf",&b[i].a);
58 m+=n;
59 for(bit=1;bit<=m;bit<<=1)bitnum++;
60 getrev(bitnum);
61 FFT(a,1);
62 FFT(b,1);
63 for(int i=0;i<=bit;i++)a[i]=a[i]*b[i];
64 FFT(a,-1);
65 for(int i=0;i<=m;i++)printf("%d ",(int)(a[i].a+0.5));
66 return 0;
67 }
NTT
1 //快速数论变换(NTT):用模质数P(形如c·2^k+1)的原根代替单位复数根
2 //大家觉得998244353好还是1004535809好?^_^
3 #include<algorithm>
4 #include<iostream>
5 #include<cstring>
6 #include<cstdio>
7 #include<cmath>
// pw(n) = 2^n. The argument is parenthesised so that compound arguments such as
// pw(x|y) expand to 1<<(x|y) instead of (1<<x)|y.
#define pw(n) (1<<(n))
using namespace std;
// N: largest supported transform length (2^18), so n+m must stay below N.
// P: NTT-friendly prime 998244353 = 119*2^23+1, whose multiplicative group has
//    power-of-two subgroups up to 2^23, so (P-1)/len is exact for every level.
// g: primitive root modulo P.
const int N=262144,P=998244353,g=3;// or P=1004535809 (= 479*2^21+1, primitive root 3)
// n, m: input degrees (main turns m into n+m); bit: transform length;
// bitnum: log2(bit); a, b: coefficient / point-value buffers; rev: bit-reversal table.
int n,m,bit,bitnum=0,a[N+5],b[N+5],rev[N+5];
12 void getrev(int l){
13 for(int i=0;i<pw(l);i++){
14 rev[i]=(rev[i>>1]>>1)|((i&1)<<(l-1));
15 }
16 }
17 int fastpow(int a,int b){
18 int ans=1;
19 for(;b;b>>=1,a=1LL*a*a%P){
20 if(b&1)ans=1LL*ans*a%P;
21 }
22 return ans;
23 }
24 void NTT(int *s,int op){
25 for(int i=0;i<bit;i++)if(i<rev[i])swap(s[i],s[rev[i]]);
26 for(int i=1;i<bit;i<<=1){
27 int w=fastpow(g,(P-1)/(i<<1));
28 for(int p=i<<1,j=0;j<bit;j+=p){
29 int wk=1;
30 for(int k=j;k<i+j;k++,wk=1LL*wk*w%P){
31 int x=s[k],y=1LL*s[k+i]*wk%P;
32 s[k]=(x+y)%P;
33 s[k+i]=(x-y+P)%P;
34 }
35 }
36 }
37 if(op==-1){
38 reverse(s+1,s+bit);
39 int inv=fastpow(bit,P-2);
40 for(int i=0;i<bit;i++)a[i]=1LL*a[i]*inv%P;
41 }
42 }
43 int main(){
44 scanf("%d%d",&n,&m);
45 for(int i=0;i<=n;i++)scanf("%d",&a[i]);
46 for(int i=0;i<=m;i++)scanf("%d",&b[i]);
47 m+=n;
48 for(bit=1;bit<=m;bit<<=1)bitnum++;
49 getrev(bitnum);
50 NTT(a,1);
51 NTT(b,1);
52 for(int i=0;i<bit;i++)a[i]=1LL*a[i]*b[i]%P;
53 NTT(a,-1);
54 for(int i=m;i>=0;i--)printf("%d ",a[i]);
55 return 0;
56 }