fft,ntt

在被两题卡了常数之后,花了很久优化了自己的模板

现在的一般来说任意模数求逆1s跑3e5,exp跑1e5是没啥问题的(自己电脑,可能比luogu慢一倍)

当模数是$998244353,1004535809,9985661441$的时候(这$3$个的原根都是$3$)

我们会使用$ntt$来求解

$ntt$的模板本身常数不大 优化效果不明显

const int mo=998244353;
const int G=3;
IL int fsp(int x,int y)
{
    ll now=1;
    while (y)
    {
        if (y&1) now=now*x%mo;
        x=1ll*x*x%mo;
        y>>=1;
    }
    return now;
}
IL void ntt_init()
{
    l=0; for (n=1;n<=m;n<<=1) l++;
    for (int i=0;i<n;i++) r[i]=(r[i/2]/2)|((i&1)<<(l-1));
}
IL void clear()
{
    for (int i=0;i<=n;i++) a[i]=b[i]=0;
}
void ntt(int *a,int o)
{
    for (int i=0;i<n;i++) if (i>r[i]) swap(a[i],a[r[i]]);
    for (int i=1;i<n;i<<=1)
    {
        int wn=fsp(G,(mo-1)/(i*2)); w[0]=1;
        rep(j,1,i-1) w[j]=(1ll*w[j-1]*wn)%mo;
        for (int j=0;j<n;j+=(i*2))
          for (int k=0;k<i;k++)
          {
              int x=a[j+k],y=1ll*a[i+j+k]*w[k]%mo;
              a[j+k]=(x+y)%mo; a[i+j+k]=(x-y)%mo;
          }
    }
    if (o==-1)
    {
      reverse(&a[1],&a[n]);
      for (int i=0,inv=fsp(n,mo-2);i<n;i++)
        a[i]=1ll*a[i]*inv%mo;
    }
}
IL void getcj(int *A,int *B,int len)
{
    m=len*2; ntt_init();
    for (int i=0;i<len;i++) a[i]=A[i],b[i]=B[i];
    ntt(a,1); ntt(b,1);
    for(int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mo;
    ntt(a,-1);
    for (int i=0;i<len;i++) B[i]=a[i];
    clear();
}
 

当模数不为这$3$个,我们就需要$mtt$来实现

而$mtt$的实现为用$mx$的方法将数的实部虚部分别放$x \& 65536,x(>>15)$

另外一个重要的地方是要预处理出$w$,我们采用指针来存,避免使用vector

代码$p$的初始值为$2*n$

所有数组大小为$4*n$

$getcj$的时候要先把数组中的负数变正

IL void clear()
{
    for (int i=0;i<=n;i++) a[i].a=a[i].b=b[i].a=b[i].b=c[i].a=c[i].b=d[i].a=d[i].b=0;
}
cp *w[N],tmp[N*2];
int p;
IL void init()
{
    cp *now=tmp;
    for (int i=1;i<=p;i<<=1)
    {
        w[i]=now;
        for (int j=0;j<i;j++) w[i][j]=(cp){cos(pi*j/i),sin(pi*j/i)};
        now+=i;
    }
}
IL void fft_init()
{
    l=0; for (n=1;n<=m;n<<=1) l++;
    for (int i=0;i<n;i++) r[i]=(r[i/2]/2)|((i&1)<<(l-1));
}
void fft(cp *a,int o)
{
    for (int i=0;i<n;i++) if (i>r[i]) swap(a[i],a[r[i]]);
    for (int i=1;i<n;i<<=1)
        for (int j=0;j<n;j+=(i*2))
        {
          cp *x1=a+j,*x2=a+i+j,*W=w[i];
          for (int k=0;k<i;k++,x1++,x2++,W++)
          {
              cp x=*x1,y=(cp){(*W).a,(*W).b*o}*(*x2); 
              *x1=x+y,*x2=x-y;
          }
        }
    if (o==-1) for(int i=0;i<n;i++) a[i].a/=n;
}
IL void getcj(int *A,int *B,int len)
{
    rep(i,0,len)
    {
        A[i]=(A[i]+mo)%mo,B[i]=(B[i]+mo)%mo;
    }
    for (int i=0;i<len;i++)
    {
       a[i]=(cp){A[i]&32767,A[i]>>15};
       b[i]=(cp){B[i]&32767,B[i]>>15};
    }
    m=len*2; fft_init();
    fft(a,1); fft(b,1);
    for (int i=0;i<n;i++)
    {
        int j=(n-1)&(n-i);
        c[j]=(cp){0.5*(a[i].a+a[j].a),0.5*(a[i].b-a[j].b)}*b[i];
        d[j]=(cp){0.5*(a[i].b+a[j].b),0.5*(a[j].a-a[i].a)}*b[i];
    }
    fft(c,1); fft(d,1);
    double inv=ee/n;
    rep(i,0,n) c[i].a*=inv,c[i].b*=inv;
    rep(i,0,n) d[i].a*=inv,d[i].b*=inv;
    rep(i,0,len)
    {
        ll a1=c[i].a+0.5,a2=c[i].b+0.5;
        ll a3=d[i].a+0.5,a4=d[i].b+0.5;
        B[i]=(a1+((a2+a3)%mo<<15)+((a4%mo)<<30))%mo;
    }
    clear();
}

对于其他的多项式函数

用$fft$还是$ntt$是差不多的(除了数组类型)

posted @ 2018-12-14 09:01  尹吴潇  阅读(297)  评论(0编辑  收藏  举报