洛谷4245:【模板】任意模数NTT——题解

https://www.luogu.org/problemnew/show/P4245

给两个多项式,求其乘积,每个系数对p取模。

参考:

代码与部分理解参考https://www.luogu.org/blog/yhzq/solution-p4245

NTT常用模数https://blog.csdn.net/hnust_xx/article/details/76572828

一些有关NTT讲解的东西。

————————————

NTT作用和DFT相同,只是NTT可以取模,且精度误差小。

我们的唯一限制就是取模的质数p=k*2^n+1,因此998244353应运而生。

对于如何构造使得每次变换都会减少一半的长度这个问题和p的原根有关,在这里就不讲了。

然而对于p不确定的时候,我们也可以使用中国剩余定理。

具体来说,找到一些p1,p2……pk满足NTT条件,然后计算结果,最后用中国剩余定理依次消即可。

然而这题很恶心的是很有可能爆longlong,且在模数大于int的情况下也没法快速乘,这时候就要使用骆克强提出的快速乘了(具体可以前往参考处第一篇博客。)

#include<cstdio>
#include<cctype>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<iostream>
using namespace std;
typedef long long ll;
typedef long double dl;
const int N=5e5+5;
const ll p1=469762049,p2=998244353,p3=1004535809,g=3;
const ll M=p1*p2;
inline int read(){
    int X=0,w=0;char ch=0;
    while(!isdigit(ch)){w|=ch=='-';ch=getchar();}
    while(isdigit(ch))X=(X<<3)+(X<<1)+(ch^48),ch=getchar();
    return w?-X:X;
}
ll qpow(ll a,ll n,ll p){
    ll res=1;
    while(n){
    if(n&1)res=res*a%p;
    a=a*a%p;n>>=1;
    }
    return res;
}
ll qmulti(ll a,ll b,ll p){
    a%=p,b%=p;
    return ((a*b-(ll)((ll)((dl)a/p*b+0.5)*p))%p+p)%p;
}
void FNT(ll a[],int n,int on,ll p){
    for(int i=1,j=n>>1;i<n-1;i++){
        if(i<j)swap(a[i],a[j]);
        int k=n>>1;
        while(j>=k){j-=k;k>>=1;}
        if(j<k)j+=k;
    }
    for(int i=2;i<=n;i<<=1){
    ll res=qpow(g,(p-1)/i,p);
        for(int j=0;j<n;j+=i){
        ll w=1;
            for(int k=j;k<j+i/2;k++){
                ll u=a[k],t=w*a[k+i/2]%p;
                a[k]=(u+t)%p;
                a[k+i/2]=(u-t+p)%p;
                w=w*res%p;
            }
        }
    }
    if(on==-1){
    ll inv=qpow(n,p-2,p);
    a[0]=a[0]*inv%p;
    for(int i=1;i<=n/2;i++){
        a[i]=a[i]*inv%p;
        if(i!=n-i)a[n-i]=a[n-i]*inv%p;
        swap(a[i],a[n-i]);
    }
    }
}
int n,m,p;
ll a[N],b[N],c[N],d[N],ans[3][N];
int main(){
    n=read(),m=read(),p=read();
    for(int i=0;i<=n;i++)a[i]=read();
    for(int i=0;i<=m;i++)b[i]=read();
    int nn=1;
    while(nn<=n+m)nn<<=1;
    
    memcpy(c,a,sizeof(a));memcpy(d,b,sizeof(b));
    FNT(c,nn,1,p1);FNT(d,nn,1,p1);
    for(int i=0;i<nn;i++)ans[0][i]=c[i]*d[i]%p1;
    memset(c,0,sizeof(c));memset(d,0,sizeof(d));

    memcpy(c,a,sizeof(a));memcpy(d,b,sizeof(b));
    FNT(c,nn,1,p2);FNT(d,nn,1,p2);
    for(int i=0;i<nn;i++)ans[1][i]=c[i]*d[i]%p2;
    memset(c,0,sizeof(c));memset(d,0,sizeof(d));


    memcpy(c,a,sizeof(a));memcpy(d,b,sizeof(b));
    FNT(c,nn,1,p3);FNT(d,nn,1,p3);
    for(int i=0;i<nn;i++)ans[2][i]=c[i]*d[i]%p3;
    memset(c,0,sizeof(c));memset(d,0,sizeof(d));

    FNT(ans[0],nn,-1,p1);
    FNT(ans[1],nn,-1,p2);
    FNT(ans[2],nn,-1,p3);

    for(int i=0;i<=n+m;i++){
    ll A=(qmulti(ans[0][i]*p2%M,qpow(p2%p1,p1-2,p1),M)+
          qmulti(ans[1][i]*p1%M,qpow(p1%p2,p2-2,p2),M))%M;
    ll k=((ans[2][i]-A)%p3+p3)%p3*qpow(M%p3,p3-2,p3)%p3;
    printf("%lld ",((k%p)*(M%p)%p+A%p)%p);
    }
    puts("");
    return 0;
}

+++++++++++++++++++++++++++++++++++++++++++

+本文作者:luyouqi233。               +

+欢迎访问我的博客:http://www.cnblogs.com/luyouqi233/+

+++++++++++++++++++++++++++++++++++++++++++

posted @ 2018-05-07 09:05  luyouqi233  阅读(918)  评论(0编辑  收藏  举报