任意模数NTT

这并不是真正的任意模数NTT,只是一种奇技淫巧,但是由于码量小而且有效,所以写在这里

在卷积问题中,如果我们要求对答案取模,而且答案不取模会爆long long,但模数原根并不好甚至不是质数,这该怎么办呢?

直接提出一种方法:取一个阈值M,将原本的一个多项式拆分成两个多项式,系数分别为$A_{i}/M$和$A_{i}%M$,然后将两个多项式变成四个多项式互相卷积即可

为了确保精度使用long double

模板题在这里

详见代码

#include <cstdio>
#include <cmath>
#include <cstring>
#include <cstdlib>
#include <iostream>
#include <algorithm>
#include <queue>
#include <stack>
#define ld long double
#define ll long long
using namespace std;
const ld pi=acos(-1.0);
const int siz=(1<<21)+5;
const ll M=32768;
struct cp
{
    ld x,y;
};
int to[siz];
ll p;
int n,m;
int lim=1,l;
ll A[siz],B[siz]; 
cp operator + (cp &a,cp &b)
{
    return (cp){a.x+b.x,a.y+b.y};
}
cp operator - (cp &a,cp &b)
{
    return (cp){a.x-b.x,a.y-b.y};
}
cp operator * (cp &a,cp &b)
{
    return (cp){a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x};
}
void FFT(cp *a,int len,int k)
{
    for(int i=0;i<len;i++)if(i<to[i])swap(a[i],a[to[i]]);
    for(int i=1;i<len;i<<=1)
    {
        cp w0=(cp){cos(pi/i),k*sin(pi/i)};
        for(int j=0;j<len;j+=(i<<1))
        {
            cp w=(cp){1,0};
            for(int o=0;o<i;o++,w=w*w0)
            {
                cp w1=a[j+o],w2=a[j+o+i]*w;
                a[j+o]=w1+w2,a[j+o+i]=w1-w2;
            }
        }
    }
}
cp a[siz],b[siz],c[siz],d[siz],e[siz],f[siz],g[siz],h[siz];
ll ret[siz];
void MTT()
{
    for(int i=0;i<=n;i++)a[i].x=A[i]/M,b[i].x=A[i]%M;
    for(int i=0;i<=m;i++)c[i].x=B[i]/M,d[i].x=B[i]%M;
    FFT(a,lim,1),FFT(b,lim,1),FFT(c,lim,1),FFT(d,lim,1);
    for(int i=0;i<lim;i++)e[i]=a[i]*c[i],f[i]=a[i]*d[i],g[i]=b[i]*c[i],h[i]=b[i]*d[i];
    FFT(e,lim,-1),FFT(f,lim,-1),FFT(g,lim,-1),FFT(h,lim,-1);
    for(int i=0;i<lim;i++)ret[i]=((ll)(e[i].x/lim+0.1)%p*M%p*M%p+(ll)(f[i].x/lim+0.1)%p*M%p+(ll)(g[i].x/lim+0.1)%p*M%p+(ll)(h[i].x/lim+0.1)%p)%p;
}
int main()
{
    scanf("%d%d%lld",&n,&m,&p);
    while(lim<=2*max(n,m))lim<<=1,l++;
    for(int i=1;i<lim;i++)to[i]=((to[i>>1]>>1)|((i&1)<<(l-1)));
    for(int i=0;i<=n;i++)scanf("%lld",&A[i]);
    for(int i=0;i<=m;i++)scanf("%lld",&B[i]);
    MTT();
    for(int i=0;i<=n+m;i++)printf("%lld ",ret[i]);
    printf("\n");
    return 0;
}

 

posted @ 2019-06-04 20:28  lleozhang  Views(592)  Comments(0Edit  收藏  举报
levels of contents