「学习笔记」FFT 之优化——NTT

「学习笔记」FFT 之优化——NTT

前言

\(NTT\) 在某种意义上说,应该属于 \(FFT\) 的一种优化。

——因而必备知识肯定要有 \(FFT\) 啦...

如果不知道 \(FFT\) 的大佬可以走这里

引入

\(FFT\) 中,为了能计算单位原根 \(\omega\) ,我们使用了 \(\text{C++}\)math 库中的 \(cos、sin\) 函数,所以我们无法避免地使用了 double 以及其运算。

但是,众所周知的, double 的运算很慢,并且,我们的虚数乘法是类似于下面这种打法:

cplx operator * (const cplx a)const{return cplx(vr*a.vr-vi*a.vi,vr*a.vi+a.vr*vi);}

显然,一次虚数乘法涉及四次 double 的乘法。

并且在运算过程中,会有大量的精度丢失,这都是我们不可接受的。

然而问题来了:我们多项式乘法都是整数在那里搞来搞去,为什么一定要扯到浮点数。是否存在一个在模意义下的,只使用整数的方法?——Tiw_Air_OAO

快速数论变换——NTT

想一想我们使用了单位复根的哪些特性:

  1. \(w_{n}^{i}*w_{n}^{j}=w_{n}^{i+j}\)
  2. \(w_{dn}^{dk}=w_n^k\)
  3. \(w_{2n}^k=-w_{2n}^{k+n}\)
  4. \(n\) 个单位根互不相同,且 \(w_n^0=1\)

那么我们能否在 模意义 下找到一个性质相同的数?

这里有一个同样也是 某某根 的东西,叫做 原根

对于素数 \(p\)\(p\) 的原根 \(G\) 定义为使得 \(G^0,G^1,...,G^{p−2}(mod\space p)\) 互不相同的数。

仔细思考一下,发现 原根单位复根 很像。

同理,我们再定义 \(g_n^k = (G^{\frac{p-1}{n}})^k\) ,这样 \(g_n^k\) 就与 \(\omega_n^k\) 长得更像了...

但是,必须在 \(g_n^k\) 满足与 \(\omega_n^k\) 同样的性质时,我们才能等价替换。

现在,我们检验原根在模意义下是否满足与单位复根同样的性质:

  1. 由幂的运算立即可得
  2. 由幂的运算立即可得
  3. \(g_{2n}^{k+n}=(G^{\frac{p-1}{2n}})^{k+n}=(G^{\frac{p-1}{2n}})^k*(G^{\frac{p-1}{2n}})^n=G^{\frac{p-1}{2}}*g_{2n}^k=-g_{2n}^k(mod\space p)\) ,因为 \((G^{p-1}=1(mod\space p)\) 且由原根定义 \(G^{\frac{p-1}{2}}\not=G^{p-1}(mod\space p)\) ,故 \(G^{\frac{p-1}{2}}=-1(mod\space p)\)
  4. 由原根的定义立即可得

发现原根可以在模意义下 完全替换 单位复根。

这就是 \(NTT\) 了。

但是,这样的方法对模数会有一定的限制

\(m = 2^p*k+1\)\(k\) 为奇数,则多项式长度必须 \(n \le 2^p\)

至于模数以及其原根,没有必要来死记,为什么?

我们程序员就应该干我们经常干的事情——打表可得...

以下是参考代码:

#include<cstdio>
#include<algorithm>
using namespace std;

#define rep(i,__l,__r) for(register int i=__l,i##_end_=__r;i<=i##_end_;++i)
#define fep(i,__l,__r) for(register int i=__l,i##_end_=__r;i>=i##_end_;--i)
#define writc(a,b) fwrit(a),putchar(b)
#define mp(a,b) make_pair(a,b)
#define ft first
#define sd second
#define LL long long
#define ull unsigned long long
#define pii pair<int,int>
#define Endl putchar('\n')
// #define FILEOI
// #define int long long

#ifdef FILEOI
    #define MAXBUFFERSIZE 500000
    inline char fgetc(){
        static char buf[MAXBUFFERSIZE+5],*p1=buf,*p2=buf;
        return p1==p2&&(p2=(p1=buf)+fread(buf,1,MAXBUFFERSIZE,stdin),p1==p2)?EOF:*p1++;
    }
    #undef MAXBUFFERSIZE
    #define cg (c=fgetc())
#else
    #define cg (c=getchar())
#endif
template<class T>inline void qread(T& x){
    char c;bool f=0;
    while(cg<'0'||'9'<c)f|=(c=='-');
    for(x=(c^48);'0'<=cg&&c<='9';x=(x<<1)+(x<<3)+(c^48));
    if(f)x=-x;
}
inline int qread(){
    int x=0;char c;bool f=0;
    while(cg<'0'||'9'<c)f|=(c=='-');
    for(x=(c^48);'0'<=cg&&c<='9';x=(x<<1)+(x<<3)+(c^48));
    return f?-x:x;
}
template<class T,class... Args>inline void qread(T& x,Args&... args){qread(x),qread(args...);}
template<class T>inline T Max(const T x,const T y){return x>y?x:y;}
template<class T>inline T Min(const T x,const T y){return x<y?x:y;}
template<class T>inline T fab(const T x){return x>0?x:-x;}
inline int gcd(const int a,const int b){return b?gcd(b,a%b):a;}
inline void getInv(int inv[],const int lim,const int MOD){
    inv[0]=inv[1]=1;for(int i=2;i<=lim;++i)inv[i]=1ll*inv[MOD%i]*(MOD-MOD/i)%MOD;
}
template<class T>void fwrit(const T x){
    if(x<0)return (void)(putchar('-'),fwrit(-x));
    if(x>9)fwrit(x/10);putchar(x%10^48);
}
inline LL mulMod(const LL a,const LL b,const LL mod){//long long multiplie_mod
    return ((a*b-(LL)((long double)a/mod*b+1e-8)*mod)%mod+mod)%mod;
}

const int MAXN=3e6;
const int MOD=998244353,g=3,gi=332748118;
int n,m;
int a[MAXN+5],b[MAXN+5],revi[MAXN+5];
inline int qkpow(int a,int x){
    int ret=1;
    for(;x>0;x>>=1){
         if(x&1)ret=1ll*ret*a%MOD;
        a=1ll*a*a%MOD;
    }
    return ret;
}
inline void ntt(int* f,const short opt=1){
    for(int i=0;i<n;++i)if(i<revi[i])swap(f[i],f[revi[i]]);
    for(int p=2,len,gn,Pow,tmp;p<=n;p<<=1){
        len=p>>1,gn=qkpow(opt==1?g:gi,(MOD-1)/p);
        for(int k=0;k<n;k+=p){Pow=1;
            for(int l=k;l<k+len;++l,Pow=1ll*Pow*gn%MOD){
                tmp=1ll*Pow*f[len+l]%MOD;
                if(f[l]-tmp<0)f[len+l]=f[l]-tmp+MOD;
                else f[len+l]=f[l]-tmp;
                if(f[l]-MOD+tmp>0)f[l]=f[l]-MOD+tmp;
                else f[l]+=tmp;
            }
        }
    }
    if(opt==-1){
        int inv=qkpow(n,MOD-2);
        for(int i=0;i<n;++i)f[i]=1ll*f[i]*inv%MOD;
    }
}
inline void launch(){
    qread(n,m);
    rep(i,0,n)qread(a[i]);
    rep(i,0,m)qread(b[i]);
    for(m+=n,n=1;n<=m;n<<=1);
    for(int i=0;i<n;++i)revi[i]=(revi[i>>1]>>1)|((i&1)?n>>1:0);
    ntt(a),ntt(b);
    for(int i=0;i<n;++i)a[i]=1ll*a[i]*b[i]%MOD;
    ntt(a,-1);
    rep(i,0,m)writc(a[i],' ');
    Endl;
}

signed main(){
#ifdef FILEOI
    freopen("file.in","r",stdin);
    freopen("file.out","w",stdout);
#endif
    launch();
    return 0;
}

一些引申问题及解决方法

假如题目中规定了模数怎么办?还卡 FFT 的精度怎么办?

有两种方法:三模数 NTT 以及 拆系数 FFT (MTT)

三模数 NTT

我们可以选取三个适用于 \(NTT\) 的模数 \(M1,M2,M3\) 进行 \(NTT\) ,用中国剩余定理合并得到 \(x\space mod\space (M1*M2*M3)\) 的值。只要保证 \(x < M1*M2*M3\) 就可以直接输出这个值。

之所以是三模数,因为用三个大小在 \(10^9\) 左右模数对于大部分题目来说就足够了。

但是 \(M1*M2*M3\) 可能非常大怎么办呢?难不成我还要写高精度?其实也可以。

我们列出同余方程组:

\[\begin{cases} x \equiv a_1&\mod m_1\\ x \equiv a_2&\mod m_2\\ x \equiv a_3&\mod m_3\\ \end{cases} \]

用中国剩余定理合并前两个方程组,得到:

\[\begin{cases} x \equiv A&\mod M\\ x \equiv a_3&\mod m_3\\ \end{cases} \]

其中的 \(M\) 满足 \(M = m1*m2 < 10^{18}\)

然后将第一个方程变形得到 \(x = kM + A\) ,代入第二个方程,得到:

\[kM+A \equiv a_3\mod m_3\\ k \equiv (a_3-A)*M^{-1} \mod m_3\\ \]

\(Q = (a_3-A)*M^{-1}\) ,则 \(k = Pm_3 + Q\)

再将上式代入回 \(x = kM + A\) ,得 \(x = (Pm_3 + Q)M+ A = Pm_3M+QM+A\)

又因为 \(M = m_1m_2\) ,所以 \(x = Pm_1m_2m_3 + QM + A\)

也就是说 \(x \equiv QM + A \mod m_1m_2m_3\)

然后,我们完美地解决了这个东西。

接下来是代码:

#include<cstdio>
#include<algorithm>
using namespace std;

#define rep(i,__l,__r) for(register int i=__l,i##_end_=__r;i<=i##_end_;++i)
#define fep(i,__l,__r) for(register int i=__l,i##_end_=__r;i>=i##_end_;--i)
#define writc(a,b) fwrit(a),putchar(b)
#define mp(a,b) make_pair(a,b)
#define ft first
#define sd second
#define LL long long
#define ull unsigned long long
#define pii pair<int,int>
#define Endl putchar('\n')
// #define FILEOI
#define int long long

#ifdef FILEOI
    #define MAXBUFFERSIZE 500000
    inline char fgetc(){
        static char buf[MAXBUFFERSIZE+5],*p1=buf,*p2=buf;
        return p1==p2&&(p2=(p1=buf)+fread(buf,1,MAXBUFFERSIZE,stdin),p1==p2)?EOF:*p1++;
    }
    #undef MAXBUFFERSIZE
    #define cg (c=fgetc())
#else
    #define cg (c=getchar())
#endif
template<class T>inline void qread(T& x){
    char c;bool f=0;
    while(cg<'0'||'9'<c)f|=(c=='-');
    for(x=(c^48);'0'<=cg&&c<='9';x=(x<<1)+(x<<3)+(c^48));
    if(f)x=-x;
}
inline int qread(){
    int x=0;char c;bool f=0;
    while(cg<'0'||'9'<c)f|=(c=='-');
    for(x=(c^48);'0'<=cg&&c<='9';x=(x<<1)+(x<<3)+(c^48));
    return f?-x:x;
}
template<class T,class... Args>inline void qread(T& x,Args&... args){qread(x),qread(args...);}
template<class T>inline T Max(const T x,const T y){return x>y?x:y;}
template<class T>inline T Min(const T x,const T y){return x<y?x:y;}
template<class T>inline T fab(const T x){return x>0?x:-x;}
inline int gcd(const int a,const int b){return b?gcd(b,a%b):a;}
inline void getInv(int inv[],const int lim,const int MOD){
    inv[0]=inv[1]=1;for(int i=2;i<=lim;++i)inv[i]=1ll*inv[MOD%i]*(MOD-MOD/i)%MOD;
}
template<class T>void fwrit(const T x){
    if(x<0)return (void)(putchar('-'),fwrit(-x));
    if(x>9)fwrit(x/10);putchar(x%10^48);
}
inline LL mulMod(const LL a,const LL b,const LL mod){//long long multiplie_mod
    return ((a*b-(LL)((long double)a/mod*b+1e-8)*mod)%mod+mod)%mod;
}

inline int qkpow(int a,int n,const int mod){
    int ret=1;
    for(;n>0;n>>=1){
        if(n&1)ret=ret*a%mod;
        a=a*a%mod;
    }
    return ret;
}

const int MAXN=3e5;
const int MOD[3]={469762049ll,998244353ll,1004535809ll};//三模数
const int G=3;//共用的原根
int inv[3][3],k1,k2,Inv,M;
//inv[i][j] : MOD[i] 在 (mod MOD[j]) 下的逆元

int h[3][MAXN+5],g[3][MAXN+5];
//h/g[i][j] : 原函数的第 j 位在 (mod MOD[i]) 的情况下的值

int revi[MAXN+5];//反转数组

inline void init(){
    rep(i,0,2)rep(j,0,2)if(i!=j)//处理 inv 数组, 主要用到费马小定理
        inv[i][j]=qkpow(MOD[i],MOD[j]-2,MOD[j]);
    M=MOD[0]*MOD[1];
    k1=mulMod(MOD[1],inv[1][0],M);
    k2=mulMod(MOD[0],inv[0][1],M);
    Inv=inv[0][2]*inv[1][2]%MOD[2];
}

inline int crt(const int a1,const int a2,const int a3,const int mod){
    int A=(mulMod(a1,k1,M)+mulMod(a2,k2,M))%M;
    int K=(a3+MOD[2]-A%MOD[2])%MOD[2]*Inv%MOD[2];
    return ((M%mod)*K%mod+A)%mod;
}

inline void ntt(int* f,const int n,const int m,const short opt=1){
    /*
        和普通的 ntt 没啥区别, 如果有什么问题, 可以去查查 fft 的资料
        唯一有区别的地方在于取模的时候, 要根据我们目前计算的模数下的运算来取模
    */
    for(int i=0;i<n;++i)if(i<revi[i])swap(f[i],f[revi[i]]);
    for(int s=2;s<=n;s<<=1){
        int t=s>>1,u=(opt==-1)?qkpow(G,(MOD[m]-1)/s,MOD[m]):qkpow(G,MOD[m]-1-(MOD[m]-1)/s,MOD[m]);
        for(int i=0;i<n;i+=s){int w=1;
            for(int j=i;j<i+t;++j,w=w*u%MOD[m]){
                int x=f[j],y=w*f[j+t]%MOD[m];
                f[j]=(x+y)%MOD[m];
                f[j+t]=(x-y+MOD[m])%MOD[m];
            }
        }
    }
    if(opt==-1){
        int inv=qkpow(n,MOD[m]-2,MOD[m]);
        rep(i,0,n-1)f[i]=f[i]*inv%MOD[m];
    }
}

int n,m,p;

inline void launch(){
    init();
    qread(n,m,p);
    rep(i,0,n){//这里我输入的最大的一个模数, 因为其已经超过 1e9 的范围, 刚刚输入时不用取模
        qread(h[2][i]);
        h[1][i]=h[2][i]%MOD[1];
        h[0][i]=h[2][i]%MOD[0];
    }
    rep(i,0,m){
        qread(g[2][i]);
        g[1][i]=g[2][i]%MOD[1];
        g[0][i]=g[2][i]%MOD[0];
    }
    for(m+=n,n=1;n<=m;n<<=1);
    for(int i=0;i<n;++i)revi[i]=(revi[i>>1]>>1)|((i&1)?n>>1:0);
    rep(i,0,2){
        ntt(h[i],n,i),ntt(g[i],n,i);
        rep(j,0,n-1)h[i][j]=h[i][j]*g[i][j]%MOD[i];
        ntt(h[i],n,i,-1);
    }
    for(int i=0;i<=m;++i)
        writc(crt(h[0][i],h[1][i],h[2][i],p),' ');
    //使用 crt(我国剩余定理) 来还原答案
    // rep(i,0,m)printf("%lld %lld %lld\n",h[0][i],h[1][i],h[2][i]);
}

signed main(){
#ifdef FILEOI
    freopen("file.in","r",stdin);
    freopen("file.out","w",stdout);
#endif
    launch();
    return 0;
}

拆系数 FFT (MTT)

我太菜了,还不会...等我更新吧...

posted @ 2019-12-19 20:18  南枙向暖  阅读(529)  评论(0编辑  收藏  举报