Loading

【学习笔记】任意模数多项式乘法

Page Views Count

三模数 NTT

由于多数 NTT 的操作对应值域 \(10^9\),规模 \(10^5\),所以选取三个常用 NTT 模数 \(p_1=998244353\)\(p_2=1004535809\)\(p_3=469702049\) 做三次乘法也就是九次 NTT。

三个模数的乘积大于结果的理论最大值,所以可以 CRT 合并得到原数再取模。使用 EXCRT 可以不开 __int128

EXCRT 具体过程是先把前两个结果 \(h_{1,i}\)\(h_{2,i}\)\(p_1\)\(p_2\) 下合并,解得一个 \(k\),使得 \(h_i\equiv kp_1+h_{1,i}\pmod {p_1p_2}\),之后拿这个值去和 \(h_{3,i}\)\(p_1p_2\)\(p_3\) 下合并,解得一个 \(k'\) 使得 \(h_i\equiv k'p_1p_2+kp_1+h_{1,i}\),这个数对给定的 \(p\) 取模即可。

常数极大。

点击查看代码
inline int q_pow(int A,int B,int P){
    int res=1;
    while(B){
        if(B&1) res=1ll*res*A%P;
        A=1ll*A*A%P;
        B>>=1;
    }
    return res;
}
inline ll exgcd(ll A,ll B,ll &X,ll &Y){
    if(!B){
        X=1,Y=0;
        return A;
    }
    ll D=exgcd(B,A%B,Y,X);
    Y-=A/B*X;
    return D;
}

int rev[maxn];
int base,w[maxn];
struct Poly{
    const static int g=3;
    int deg;
    vector<ull> f;
    ull& operator[](const int &i){return f[i];}
    ull operator[](const int &i)const{return f[i];}
    inline void set(int L){deg=L;f.resize(L);}
    inline void clear(int L,int R){for(int i=L;i<=R;++i)f[i]=0;}
    inline void output(int L){for(int i=0;i<L;++i)printf("%llu ",f[i]);printf("\n");}
    inline void NTT(int L,bool type,int P){
        set(L);
        int inv_g=q_pow(g,P-2,P);
        for(int i=1;i<L;++i){
            rev[i]=(rev[i>>1]>>1)+(i&1?L>>1:0);
            if(i<rev[i]) swap(f[i],f[rev[i]]);
        }
        for(int d=1;d<L;d<<=1){
            base=q_pow(type?g:inv_g,(P-1)/(2*d),P);
            w[0]=1;
            for(int i=1;i<d;++i) w[i]=1ll*w[i-1]*base%P;
            for(int i=0;i<L;i+=d<<1){
                for(int j=0;j<d;++j){
                    ull x=f[i+j],y=f[i+d+j]*w[j]%P;
                    f[i+j]=x+y,f[i+d+j]=x-y+P;
                }
            }
        }
        for(int i=0;i<L;++i) f[i]%=P;
        if(!type){
            int inv_L=q_pow(L,P-2,P);
            for(int i=0;i<L;++i) f[i]=f[i]*inv_L%P;
        }
    }
}F,G,H[3];

int n,m,p;
int a[maxn],b[maxn],c[maxn];
ll mod[3]={998244353,1004535809,469762049};

inline int solve(ll A,ll B,ll C){
    ll X1,Y1,X2,Y2;
    exgcd(mod[0],mod[1],X1,Y1);
    X1=(X1%mod[1]+mod[1])%mod[1];
    X1=((B-A)%mod[1]+mod[1])%mod[1]*X1%mod[1];
    exgcd(mod[0]*mod[1],mod[2],X2,Y2);
    X2=(X2%mod[2]+mod[2])%mod[2];
    X2=((C-(X1*mod[0]+A)%(mod[0]*mod[1]))%mod[2]+mod[2])%mod[2]*X2%mod[2];
    return (X2%p*mod[0]%p*mod[1]%p+X1%p*mod[0]%p+A%p)%p;

}

int main(){
    n=read(),m=read(),p=read();
    for(int i=0;i<=n;++i) a[i]=read();
    for(int i=0;i<=m;++i) b[i]=read();
    int L=1;
    while(L<n+m+1) L<<=1;
    F.set(L),G.set(L);
    for(int i=0;i<=2;++i){
        H[i].set(L);
        F.clear(0,L-1),G.clear(0,L-1);
        for(int j=0;j<=n;++j) F[j]=a[j];
        for(int j=0;j<=m;++j) G[j]=b[j];
        F.NTT(L,1,mod[i]),G.NTT(L,1,mod[i]);
        for(int j=0;j<L;++j) H[i][j]=F[j]*G[j]%mod[i];
        H[i].NTT(L,0,mod[i]);
    }
    for(int i=0;i<=n+m;++i) c[i]=solve((ll)H[0][i],(ll)H[1][i],(ll)H[2][i]);
    for(int i=0;i<=n+m;++i) printf("%d ",c[i]);
    printf("\n");
    return 0;
}

拆系数 FFT

通过把原多项式系数拆开来保证 FFT 的精度。

有一个好写且不掉精度的 \(5\) 次 FFT 做法。

\(B=\sqrt{p}\),把两个多项式拆成:\(F(x)=B\times F_1(x)+F_2(x),G(x)=B\times G_1(x)+G_2(x)\)

这样卷积的结果是 \(H(x)=B^2\times F_1(x)G_1(x)+B\times (F_1(x)G_2(x)+F_2(x)G_1(x))+F_2(x)G_2(x)\)

\(T(x)=G_1(x)+i\times G_2(x)\),那么 \(F_1(x)T(x)\)\(F_2(x)T(x)\) 的实部虚部就分别对应上面的四个卷积结果。

这样只需要对 \(F_1(x),F_2(x),T(x)\) 做 FFT,对 \(F_1(x)T(x)\)\(F_2(x)T(x)\) 做 IFFT,\(5\) 次就可以了。

为了保证精度可以预处理单位根配合 long double

点击查看代码
int rev[maxn];
struct Complex{
    db a,b;
    Complex()=default;
    Complex(db a_,db b_):a(a_),b(b_){}
    Complex operator+(const Complex &rhs)const{return Complex(a+rhs.a,b+rhs.b);}
    Complex operator-(const Complex &rhs)const{return Complex(a-rhs.a,b-rhs.b);}
    Complex operator*(const Complex &rhs)const{return Complex(a*rhs.a-b*rhs.b,a*rhs.b+b*rhs.a);}
}base,W[maxn],w[maxn];
struct Poly{
    int deg;
    vector<Complex> f;
    Complex& operator[](const int &i){return f[i];}
    Complex operator[](const int &i)const{return f[i];}
    inline void set(int L){deg=L;f.resize(L);}
    inline void clear(int L,int R){for(int i=L;i<=R;++i)f[i]=Complex(0,0);}
    inline void output(int L){for(int i=0;i<L;++i)printf("(%Lf,%Lf) ",f[i].a,f[i].b);printf("\n");}
    inline void FFT(int L,bool type){
        set(L);
        for(int i=1;i<L;++i){
            rev[i]=(rev[i>>1]>>1)+(i&1?L>>1:0);
            if(i<rev[i]) swap(f[i],f[rev[i]]);
        }
        for(int d=1;d<L;d<<=1){
            for(int i=0,j=0;i<d;++i,j+=L/(2*d)) w[i]=W[type?j:L-j];
            for(int i=0;i<L;i+=d<<1){
                for(int j=0;j<d;++j){
                    Complex x=f[i+j],y=w[j]*f[i+d+j];
                    f[i+j]=x+y,f[i+d+j]=x-y;
                }
            }
        }
        if(!type){
            for(int i=0;i<L;++i) f[i].a/=L,f[i].b/=L;
        }
    }
}A,B,T,F,G;

int n,m,p,C;

int main(){
    n=read(),m=read(),p=read(),C=sqrt(p);
    int L=1;
    while(L<n+m+1) L<<=1;
    for(int i=0;i<=L;++i) W[i]=Complex(cos(i*2*pi/L),sin(i*2*pi/L));
    A.set(L),B.set(L),T.set(L),F.set(L),G.set(L);
    for(int i=0;i<=n;++i){
        int x=read()%p;
        A[i]=Complex(1.0*(x/C),0),B[i]=Complex(1.0*(x%C),0);
    }
    for(int i=0;i<=m;++i){
        int x=read()%p;
        T[i]=Complex(1.0*(x/C),1.0*(x%C));
    }
    A.FFT(L,1),B.FFT(L,1),T.FFT(L,1);
    for(int i=0;i<L;++i) F[i]=A[i]*T[i],G[i]=B[i]*T[i];
    F.FFT(L,0),G.FFT(L,0);
    for(int i=0;i<=n+m;++i){
        int now=0;
        now=(now+1ll*C*C%p*((ll)(F[i].a+0.5)%p)%p)%p;
        now=(now+1ll*C*((ll)(F[i].b+0.5)%p+(ll)(G[i].a+0.5)%p)%p+p)%p;
        now=(now+(ll)(G[i].b+0.5)%p+p)%p;
        printf("%d ",now);
    }
    printf("\n");
    return 0;
}

参考资料

posted @ 2023-07-16 20:58  SoyTony  阅读(62)  评论(2编辑  收藏  举报