FFT迭代版(非递归) & NTT 多项式求逆
NTT板子
又重温了一遍,大佬说把板子背过就好。
具体看代码
想要看懂NTT板子,先看懂FFT迭代版(非递归)模板:
FFT迭代版本(非递归)
#include<iostream>
#include<cstdio>
#include<cmath>
using namespace std;
// Capacity of the coefficient arrays and of the bit-reversal table.
const int N = 1e7 + 7;

// Minimal complex number used by the FFT: x is the real part, y the imaginary part.
struct complex {
    double x, y;
    complex(double re = 0, double im = 0) : x(re), y(im) {}
};

// Coefficients of the two input polynomials. After the transforms a[] holds the
// product. Both arrays must hold `limit` entries, where limit is the smallest
// power of two greater than n + m, so n + m must stay below 2^23 for this N.
complex a[N], b[N];
const double pi = acos(-1.0);

// Component-wise sum of two complex numbers.
complex operator +(const complex a, complex b) {
    const double re = a.x + b.x;
    const double im = a.y + b.y;
    return complex(re, im);
}

// Component-wise difference a - b.
complex operator -(const complex a, complex b) {
    const double re = a.x - b.x;
    const double im = a.y - b.y;
    return complex(re, im);
}

// Complex product (a.x + i*a.y) * (b.x + i*b.y).
complex operator *(const complex a, complex b) {
    const double re = a.x * b.x - a.y * b.y;
    const double im = a.x * b.y + a.y * b.x;
    return complex(re, im);
}

// limit: FFT length (a power of two); n, m: degrees of the two inputs;
// l: log2(limit).
int limit = 1, n, m, l;
// r[i]: i with its l low bits reversed (bit-reversal permutation for the FFT).
int r[N];
// In-place iterative radix-2 FFT over a[0 .. limit).
// f == 1 gives the forward transform and f == -1 the inverse. The inverse does
// NOT divide by `limit`; that scaling is left to the caller. Relies on the
// globals `limit` and `r` (bit-reversal table) being set up for this length.
void FFT(complex *a, int f) {
    // Move every element to its bit-reversed position.
    for (int idx = 0; idx < limit; idx++) {
        if (idx < r[idx]) swap(a[idx], a[r[idx]]);
    }
    // Each pass merges blocks of length `half` into blocks of length 2*half.
    for (int half = 1; half < limit; half <<= 1) {
        // Primitive (2*half)-th root of unity, conjugated when f == -1.
        const complex step(cos(pi / half), f * sin(pi / half));
        const int span = half << 1;
        for (int base = 0; base < limit; base += span) {
            complex w(1, 0);
            for (int off = 0; off < half; off++) {
                const complex lo = a[base + off];
                const complex hi = w * a[base + half + off];
                a[base + off] = lo + hi;
                a[base + half + off] = lo - hi;
                w = w * step;
            }
        }
    }
}
int main(){
scanf("%d%d",&n,&m);
for(int i=0;i<=n;i++) scanf("%lf",&a[i].x);
for(int i=0;i<=m;i++) scanf("%lf",&b[i].x);
while(limit<=n+m) limit<<=1,l++;
for(int i=0;i<limit;i++){
r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
}
NTT(a,1);
NTT(b,1);
for(int i=0;i<=limit;i++) a[i]=a[i]*b[i];
NTT(a,-1);
for(int i=0;i<=n+m;i++){
cout<<(int)(a[i].x/(limit)+0.5)<<" ";
}
}
多项式求逆
#include<iostream>
#include<cstdio>
using namespace std;
// Every `int` below becomes a 64-bit integer, so a product of two residues
// below p (< 2^60) fits without overflow.
#define int long long
// Array capacity. INV's largest NTT length is the smallest power of two >= 2n,
// and it must not exceed N; with N = 1e6+7 this supports n <= 2^18 = 262144.
const int N=1e6+7;
// Prime modulus used by the NTT.
const int p=998244353;//primitive root is 3
// Number of input coefficients, which is also the length of the computed inverse.
int n;
// a: input coefficients; b: computed inverse (also transformed in place inside
// INV); c: scratch copy of a used by INV; r: bit-reversal table used by NTT.
int a[N],b[N],c[N],r[N];
// Modular exponentiation by repeated squaring: returns a^b mod p in [0, p).
// `int` is the `long long` macro here. b must be non-negative.
int ksm(int a,int b){
    // Reduce the base first. Without this, a*a overflows (signed overflow is
    // undefined) once |a| exceeds ~3e9, and a negative base yields a
    // negative result.
    a %= p;
    if (a < 0) a += p;
    int res = 1;
    for (; b; b >>= 1) {
        if (b & 1) res = res * a % p;
        a = a * a % p;
    }
    return res;
}
// In-place iterative number-theoretic transform modulo p over a[0 .. len).
// opt == 1: forward transform. opt == -1: inverse transform, including the
// final multiplication by len^{-1} mod p. Assumes len is a power of two and
// that the global bit-reversal table r has been built for this len.
void NTT(int *a,int len,int opt){
    // Move every element to its bit-reversed position.
    for (int i = 0; i < len; i++) {
        if (i < r[i]) swap(a[i], a[r[i]]);
    }
    for (int half = 1; half < len; half <<= 1) {
        const int width = half << 1;
        // Root of unity of order `width`: 3^((p-1)/width). The inverse
        // transform uses its modular inverse, computed as x^(p-2) by Fermat.
        int unit = ksm(3, (p - 1) / width);
        if (opt == -1) unit = ksm(unit, p - 2);
        for (int start = 0; start < len; start += width) {
            int w = 1;
            for (int k = 0; k < half; k++) {
                const int u = a[start + k];
                const int v = w * a[start + half + k] % p;
                a[start + k] = (u + v) % p;
                a[start + half + k] = (u - v + p) % p;
                w = w * unit % p;
            }
        }
    }
    if (opt != -1) return;
    // Inverse transform: scale by 1/len.
    const int inv_len = ksm(len, p - 2);
    for (int i = 0; i < len; i++) a[i] = a[i] * inv_len % p;
}
// Power-series inverse: writes b[0 .. n) such that a(x) * b(x) == 1 (mod x^n, mod p).
// It uses Newton iteration b <- b * (2 - a * b), doubling the precision at each
// recursion level. a[0] must be non-zero mod p; otherwise ksm returns 0 and the
// result is meaningless. The global scratch arrays c and r are overwritten.
// Assumes b is zero-initialized on entry (main passes the global array).
// On return, b[n .. limit) has been cleared.
void INV(int n,int *a,int *b){
    // Nothing to invert for an empty prefix. Without this guard, n == 0 would
    // recurse on (0 + 1) >> 1 == 0 forever.
    if (n <= 0) return;
    if (n == 1) {
        b[0] = ksm(a[0], p - 2);
        return;
    }
    // b now holds the inverse modulo x^ceil(n/2).
    INV((n + 1) >> 1, a, b);
    // NTT length: smallest power of two >= 2n. The product b*b*c has degree
    // at most 2n-2, so it does not wrap around.
    int limit = 1, l = 0;
    while (limit < (n << 1)) limit <<= 1, l++;
    for (int i = 0; i < limit; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
    // a must not be modified, so transform a copy truncated to n terms. Higher
    // terms cannot affect the result modulo x^n.
    for (int i = 0; i < n; i++) c[i] = a[i];
    for (int i = n; i < limit; i++) c[i] = 0;
    NTT(c, limit, 1);
    NTT(b, limit, 1);
    // Point-wise Newton step: b = 2b - b^2 * c.
    for (int i = 0; i < limit; i++) {
        b[i] = (1LL * 2 * b[i] % p - 1LL * b[i] * b[i] % p * c[i] % p + p) % p;
    }
    NTT(b, limit, -1);
    // Keep only the first n coefficients, so the next level sees zeros above n.
    for (int i = n; i < limit; i++) b[i] = 0;
}
// Reads n and the first n coefficients of a(x), then prints the first n
// coefficients of its inverse modulo x^n, each normalized to [0, p).
signed main(){
    scanf("%lld", &n);
    int idx = 0;
    while (idx < n) scanf("%lld", &a[idx++]);
    INV(n, a, b);
    for (int k = 0; k < n; ++k) {
        const int val = (b[k] % p + p) % p;
        cout << val << " ";
    }
}