NTT 学习笔记
前言
ntt和fft一样,都是用来处理卷积的,但用处不一样
fft因为浮点数的性质,系数的大小没有限制,但是会丢失精度
ntt是通过整数运算在剩余系下计算卷积,卷积后的系数不能超过整形的范围,但是速度较快,而且不掉精
如果系数不大,一般用ntt
如果系数大大,且不能取模,则用fft
理论
原根:一个数g为p的原根,当且仅当$g^{ \phi (p)} \equiv 1 (mod \ p)$
发现原根也满足fft中单位根的四条性质,所以可以用原根代替单位根
我们设$G^i_n = g^{\frac{p-1}{n}*i}$
可以将$w^i_n$ 替换为$G^i_n$
ntt要满足$p=g*2^x+1$,这样$(p-1)/n$才是整数
然后就是fft的过程了
代码
#include <iostream> #include <cstdio> #include <cmath> #define N 4000001 using namespace std; #define mod 998244353 #define int long long int lim,rev[N],len; int inv[N],a[N],b[N]; int read() { char c=getchar(); int x=0,f=1; while(c<'0'||c>'9') { if(c=='-')f=-1;c=getchar(); } while(c>='0'&&c<='9') { x=x*10+c-'0';c=getchar(); } return x*f; } int qpow(int base,int index) { int ans=1; while(index) { if(index&1) ans*=base,ans%=mod; base*=base,base%=mod; index>>=1; } return ans; } void ntt(int arr[],int gen) { for(int i=0;i<lim;i++) if(rev[i]>i) swap(arr[i],arr[rev[i]]); for(int i=1;i<lim;i*=2)//枚举区间长度的一般(方便合并) { int val=qpow(gen,(mod-1)/(i<<1));//相当于计算G(1,i*2) 即相邻根之间的增量 for(int j=0;j<lim;j+=(i<<1))//枚举每个区间 { int val2=1; //每个区间的根要从头开始代入 for(int k=0;k<i;k++,val2*=val,val2%=mod)//计算 { int t=arr[j+k],t2=val2*arr[j+k+i]%mod;//蝴蝶变换 arr[j+k]=(t+t2)%mod; arr[j+k+i]=(t-t2+mod)%mod; } } } } signed main() { int n,m; cin>>n>>m; for(int i=0;i<=n;i++) a[i]=read(); for(int i=0;i<=m;i++) b[i]=read(); lim=1; while(lim<=n+m) len++,lim<<=1; for(int i=0;i<lim;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(len-1)); ntt(a,3); ntt(b,3); for(int i=0;i<lim;i++) a[i]=a[i]*b[i]%mod; ntt(a,qpow(3,mod-2));//idft带入的是单位根的逆元,这里也相应地带入3的逆元 for(int i=0;i<=n+m;i++) printf("%lld ", a[i]*qpow(lim,mod-2)%mod);//idft最后要除以项数 }
看都看了,顺手点个推荐呗 :)