Luogu4245 【模板】任意模数NTT

https://www.luogu.com.cn/problem/P4245

三模数\(NTT\)

乘积每一项系数的真实值不超过 \(p^2 \max\{n,m\} \approx 10^{23}\),小于三个 NTT 模数之积 \(p_1 p_2 p_3 \approx 4.7 \times 10^{26}\)。因此可以先在三个模数下分别求卷积,再用中国剩余定理唯一确定每一项的真实值。

具体操作流程:

\[\begin{cases} x \equiv x_1 \pmod{p_1}\\ x \equiv x_2 \pmod{p_2}\\ x \equiv x_3 \pmod{p_3} \end{cases}\]

\[M = p_1 p_2,\qquad A \equiv x_1\, p_2 \left(p_2^{-1} \bmod p_1\right) + x_2\, p_1 \left(p_1^{-1} \bmod p_2\right) \pmod{M},\qquad 0 \le A < M\]

\[\begin{cases} x \equiv A \pmod{M}\\ x \equiv x_3 \pmod{p_3} \end{cases} \;\Longrightarrow\; x = kM + A,\qquad kM + A \equiv x_3 \pmod{p_3},\qquad k \equiv (x_3 - A)\,M^{-1} \pmod{p_3}\]

取 \(0 \le k < p_3\),则 \(x = kM + A\) 就是 \([0, p_1 p_2 p_3)\) 中唯一满足三个同余式的值,最后对 \(p\) 取模输出即可。

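第一步合并时,\(x_1 p_2\)(不超过 \(M \approx 4.7 \times 10^{17}\))再乘以逆元会超出 64 位整数范围。因此要用龟速乘(倍增加法)在模 \(M\) 意义下计算。求出 \(k\) 后,在模 \(p\) 意义下计算 \(kM + A\),即得最终答案。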

NTT模数

\(Code:\)

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#define N 400005
#define ll long long
using namespace std;
// Input coefficient arrays, read by main() and copied into each NTT instance via Set().
int f[N],g[N];
// n, m: coefficient counts (the degrees read in main, each incremented by one); p: output modulus.
int n,m,p;
struct ntt
{
    // Supported NTT-friendly primes (3 is used as the primitive root for each):
    //   469762049  = 7   * 2^26 + 1
    //   998244353  = 119 * 2^23 + 1
    //   1004535809 = 479 * 2^21 + 1
    int f[N],g[N];   // operands; after solve(), f holds the product (length s) mod p
    int n,m,p,tw,gi; // operand lengths, prime modulus, log2 of max transform size (2^tw | p-1), primitive root
    int G[2][35];    // G[0][i]: 2^i-th root of unity mod p; G[1][i]: its inverse (i = 1..tw)
    int s,l,rev[N];  // transform length (power of two), log2(s), bit-reversal permutation

    // Records the modulus parameters and copies cx coefficients of a and cy of b.
    // Entries beyond cx / cy are left untouched.
    void Set(int cx,int cy,int x,int y,int Gi,int *a,int *b)
    {
        n=cx;
        m=cy;
        p=x;
        tw=y;
        gi=Gi;
        memcpy(f,a,cx*sizeof(int));
        memcpy(g,b,cy*sizeof(int));
    }

    // In-place modular helpers, expressed through the value-returning ones below.
    void Add(int &x,int y)
    {
        x=add(x,y);
    }
    void Del(int &x,int y)
    {
        x=del(x,y);
    }
    void Mul(int &x,int y)
    {
        x=mul(x,y);
    }

    // Value-returning modular helpers (x + y and x - y + p must fit in int).
    int add(int x,int y)
    {
        return (x+y)%p;
    }
    int del(int x,int y)
    {
        return (x-y+p)%p;
    }
    int mul(int x,int y)
    {
        return (ll)x*y%p;
    }

    // Binary exponentiation: returns x^y mod p for y >= 0.
    int ksm(int x,int y)
    {
        int res=1;
        for (;y;y>>=1)
        {
            if (y&1)
                Mul(res,x);
            Mul(x,x);
        }
        return res;
    }

    // Fills G: the 2^tw-th root gi^((p-1)/2^tw) and its inverse at index tw,
    // then each lower index is the square of the one above it.
    void Pre()
    {
        G[0][tw]=ksm(gi,(p-1)/(1<<tw));
        G[1][tw]=ksm(G[0][tw],p-2);
        for (int i=tw;i>1;--i)
        {
            G[0][i-1]=mul(G[0][i],G[0][i]);
            G[1][i-1]=mul(G[1][i],G[1][i]);
        }
    }

    // Iterative in-place transform of a[0..s). t == 0: forward, t == 1: inverse
    // (without the 1/s factor, which solve() applies).
    void NTT(int *a,int t)
    {
        for (int i=0;i<s;++i)
            if (i<rev[i])
                swap(a[i],a[rev[i]]);
        for (int half=1,lv=1;half<s;half<<=1,++lv)
        {
            const int wn=G[t][lv];  // root of unity of order 2*half
            for (int st=0;st<s;st+=half<<1)
            {
                int w=1;
                for (int k=0;k<half;++k)
                {
                    const int u=a[st+k];
                    const int v=mul(w,a[st+k+half]);
                    a[st+k]=add(u,v);
                    a[st+k+half]=del(u,v);
                    Mul(w,wn);
                }
            }
        }
    }

    // Computes f <- f * g (mod p) via forward transforms, pointwise product,
    // inverse transform and scaling by s^{-1}.
    void solve()
    {
        Pre();
        s=1;
        l=0;
        while (s<n+m)
        {
            s<<=1;
            ++l;
        }
        for (int i=0;i<s;++i)
            rev[i]=(rev[i>>1]>>1)|((i&1)<<(l-1));
        NTT(f,0);
        NTT(g,0);
        for (int i=0;i<s;++i)
            Mul(f[i],g[i]);
        NTT(f,1);
        const int invS=ksm(s,p-2);
        for (int i=0;i<s;++i)
            Mul(f[i],invS);
    }
}NTT1,NTT2,NTT3;
/// "Slow multiplication": computes (x * y) mod p by binary doubling-and-adding,
/// so no intermediate ever needs more than 64 bits.
///
/// @param x  multiplicand, x >= 0 (reduced modulo p first, so it may exceed p)
/// @param y  multiplier, y >= 0 (consumed bit by bit; O(log y) iterations)
/// @param p  modulus, 1 <= p <= LLONG_MAX
/// @return   (x * y) mod p, in [0, p)
///
/// Both running values stay in [0, p), and every modular addition is written as
/// a comparison against p - x instead of forming ans + x. That keeps the function
/// overflow-free even for p close to 2^63 or an unreduced x. The naive
/// "(ans + x) % p" / "(x << 1) % p" form overflows there.
long long low_mul(long long x,long long y,long long p)
{
    x%=p;
    long long ans=0;
    while (y)
    {
        if (y & 1)
            ans=(ans>=p-x)?ans-(p-x):ans+x;  // (ans + x) mod p without overflow
        x=(x>=p-x)?x-(p-x):x+x;              // (2 * x) mod p without overflow
        y>>=1;
    }
    return ans;
}
int main()
{
    scanf("%d%d%d",&n,&m,&p),++n,++m;
    for (int i=0;i<n;++i)
        scanf("%d",&f[i]);
    for (int i=0;i<m;++i)
        scanf("%d",&g[i]);
    int p1=469762049,p2=998244353,p3=1004535809;
    NTT1.Set(n,m,p1,26,3,f,g);
    NTT2.Set(n,m,p2,23,3,f,g);
    NTT3.Set(n,m,p3,21,3,f,g);
    NTT1.solve(),NTT2.solve(),NTT3.solve();
    int inv1=NTT1.ksm(p2,p1-2),inv2=NTT2.ksm(p1,p2-2);
    ll M=(ll)p1*p2;
    int inv3=NTT3.ksm(M%p3,p3-2);
    for (int i=0;i<n+m-1;++i)
    {
        ll x=(low_mul((ll)NTT1.f[i]*p2%M,inv1,M)+low_mul((ll)NTT2.f[i]*p1%M,inv2,M))%M;
        ll k=(ll)inv3*(NTT3.f[i]-x%p3+p3)%p3;
        int ans=(M%p*k%p+x%p+p)%p;
        printf("%d ",ans);
    }
    putchar('\n');
    return 0;
}

拆系数\(FFT\)

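把每个系数按二进制拆成低 \(15\) 位和其余高位:第一个多项式的系数写成 \(2^{15}a_i + b_i\),第二个写成 \(2^{15}c_i + d_i\)(其中 \(0 \le b_i, d_i < 2^{15}\))。对拆出的多项式分别做 \(FFT\),最后再合并。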

\[(2^{15} a_i+b_i)(2^{15}c_i+d_i)=\\ 2^{30}a_i c_i+2^{15} (a_i d_i +b_i c_i)+b_i d_i \]

于是只需分别求出 \(a \ast c\)、\(a \ast d + b \ast c\)、\(b \ast d\) 三个卷积(\(\ast\) 表示卷积),四舍五入取整后,按上式在模 \(p\) 意义下合并即可。

注意,本题对精度要求很高,尽量减少计算次数(一开始爆\(0\),还以为\(FFT\)炸了,傻了好久\(QAQ\))。

最好使用 \(\texttt{long double}\)。用 \(\texttt{double}\) 虽然也能通过且运行更快,但精度余量很小,非常容易 \(WA\)。

\(Code:\)

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cmath>
#define N 400005
#define D long double
#define ll long long
using namespace std;
// n, m: polynomial degrees; p: output modulus; q: scratch for reading one coefficient;
// l, s: log2 and length of the FFT; rev: bit-reversal permutation (filled in main).
int n,m,p,q,l,s,rev[N];
// Use the long double overload of acos. acos(-1.0) yields only a double-precision
// pi (error ~1e-16). That would cap the accuracy of every root of unity in W,
// even though the FFT arithmetic itself is done in long double.
const D Pi=std::acos(static_cast<D>(-1));
// g1, g2, g3: rounded convolution values of the 2^30, 2^15 and 2^0 parts; ans: output value.
ll g1,g2,g3,ans;
struct virt
{
    D x,y;
    virt (D xx=0.0,D yy=0.0)
    {
        x=xx,y=yy;
    }
    virt operator + (virt b)
    {
        return virt(x+b.x,y+b.y);
    }
    virt operator - (virt b)
    {
        return virt(x-b.x,y-b.y);
    }
    virt operator * (virt b)
    {
        return virt(x*b.x-y*b.y,x*b.y+y*b.x);
    }
}q1,q2,q3,q4,a[N],b[N],c[N],d[N],g[2][25],W[N];
/// In-place iterative FFT over a[0..s), using the globals s, rev and W prepared by main.
/// t == 1: forward transform.
/// t == -1: inverse transform, done as the same forward pass followed by reversing
/// a[1..s-1] and scaling by 1/s.
void FFT(virt *a,int t)
{
    for (int i=0;i<s;i++)
        if (i<rev[i])
            swap(a[i],a[rev[i]]);
    for (int half=1;half<s;half<<=1)
    {
        virt *root=W+half;  // twiddles W[half + k], k in [0, half)
        for (int st=0;st<s;st+=half<<1)
        {
            virt *lo=a+st;
            virt *hi=lo+half;
            for (int k=0;k<half;k++)
            {
                virt u=lo[k];
                virt v=root[k]*hi[k];
                lo[k]=u+v;
                hi[k]=u-v;
            }
        }
    }
    if (t!=-1)
        return;
    reverse(a+1,a+s);
    D scale=1.0/s;
    for (int i=0;i<s;i++)
        a[i]=a[i]*scale;
}
int main()
{
    scanf("%d%d%d",&n,&m,&p);
    for (int i=0;i<=n;i++)
        scanf("%d",&q),a[i].x=q >> 15,b[i].x=q & ((1 << 15) - 1);
    for (int i=0;i<=m;i++)
        scanf("%d",&q),c[i].x=q >> 15,d[i].x=q & ((1 << 15) - 1);
    l=0,s=1;
    while (s<=n+m)
    {
        s <<=1;
        l++;
    }
    for (int i=0;i<s;i++)
        rev[i]=(rev[i >> 1] >> 1) | ((i & 1) << (l - 1));
    for(int i=1;i<s;i<<=1)
        for (int k=0;k<i;k++)
            W[i+k]=virt(cos(Pi*k/i),sin(Pi*k/i));
    FFT(a,1);
    FFT(b,1);
    FFT(c,1);
    FFT(d,1);
    for (int i=0;i<s;i++)
    {
        q1=a[i],q2=b[i],q3=c[i],q4=d[i];
        a[i]=q1*q3;
        b[i]=q1*q4+q2*q3;
        c[i]=q2*q4;
    }
    FFT(a,-1);
    FFT(b,-1);
    FFT(c,-1);
    for (int i=0;i<=n+m;i++)
    {
        g1=(ll)(a[i].x+0.5);
        g2=(ll)(b[i].x+0.5);
        g3=(ll)(c[i].x+0.5);
        g1=((g1%p) << 30)%p;
        g2=((g2%p) << 15)%p;
        g3=g3%p;
        ans=((g1+g2+g3)%p+p)%p;
        printf("%lld ",ans);
    }
    putchar('\n');
    return 0;
}
posted @ 2020-08-03 16:49  GK0328  阅读(99)  评论(0编辑  收藏  举报