多项式乘法(FFT,NTT,MTT)

首先从多项式的概念说起。

多项式,就是形如\(\sum_{i=0}^na_ix^i\)的式子,\(a_i\)是系数,\(x\)是变量,\(n\)为多项式的阶/次数。

然后是重要的多项式卷积:定义多项式\(f\)\(g\)的卷积\(h\)为:

\[h_i=\sum_{j=0}^if_jg_{i-j} \]

其实就是把两个多项式乘起来然后起了个很高级的名字。然后我们拿定义模拟乘法加法运算来计算卷积显然是\(O(n^2)\)的(以下均假设\(n,m\)同阶)。太慢了,不能接受。

于是我们有方法来在\(O(n\log n)\)的复杂度内计算多项式的卷积。不过这个一会再说,请先食用一下前置芝士。

  1. 复数(请翻阅数学必修二课本)
  2. 多项式的系数表示和点值表示(这个就是字面意思)
  3. 单位根(这个得说说)

考虑\(n\)次方程\(x^n=1\)的解数。显然在复数域内它有\(n\)个解。我们把这\(n\)个解放到复平面内,可以发现它正好\(n\)等分单位圆。于是我们定义这\(n\)个解\(\omega_n^0,\omega_n^1,\cdots ,\omega_n^{n-1}\)就是\(n\)次单位根。

显然这个东西有通项公式\(w_n^i=\cos(\frac {2\pi i}{n})+\text i\sin(\frac {2\pi i}{n})\)(以后复数单位\(\text i=\sqrt {-1}\)就这么写了,和普通的\(i\)区分一下)。

然后是单位根的一些重要性质。

  1. \(\omega_{dn}^{dx}=\omega_n^x\)。显然。或者你把单位根的通项搞成\(e^{\frac {2\pi i}{n}}\)随便搞搞也行。
  2. \(\omega_n^x=\text {conj}(\omega_n^{-x})\)\(\text{conj}\)是共轭。这个也显然,复平面上关于\(x\)轴对称一下。
  3. \(\omega_n^{x+\frac n2}=-\omega_n^x\)。复平面上逆时针转半圈就是。

其实都是废话。接下来进入正题。

快速傅里叶变换(Fast Fourier Transform,FFT)

Fast Fast TLE

我们考虑到系数表示的极限也就是\(O(n^2)\)了,所以考虑一个更快的方法:点值表示。这个把要卷起来的两个多项式对应点的点值乘一下就可以了,是\(O(n)\)的。于是我们的问题就来到了如何在系数表示和点值表示之间快速转换。

先人们为我们造好了轮子:\(\text {DFT}\)和它的逆变换\(\text {IDFT}\)。前一个是把系数转化成点值,后一个是转化回来。它是什么原理?我们来推式子。(抄的学长PDF,不想看的可以直接跳过)

我们还是看原来式子\(h_x=\sum_{i+j=x}f_ig_j\)。这里换一下表示方法,好写。然后说一句废话变成:

\[h_x=\sum[i+j=x]f_ig_j \]

然后我们单位根的用处就来了,有这样一个柿子,叫单位根反演:

\[\frac 1n\sum_{x=0}^{n-1}\omega_n^{vx}=[n|v] \]

证明一下:就是分类讨论。\(v\mod n=0\)时显然,所有的\(w_n^{vx}\)都是\(1\)。然后剩下的情况直接套个等比数列求和公式变成

\[\frac 1n\frac {1-\omega_n^{nv}}{1-\omega_n^v} \]

显然是个\(0\)。于是我们可以把它带会上式:

\[h_x=\sum[i+j-x\mod n=0]f_ig_j \]

\(n\)是序列长度,取模先不要管。我们继续化:

\[\begin{aligned} =&\sum_{i,j}\frac 1n\sum_{k=0}^{n-1}\omega_n^{-xk}\omega_n^{ik}\omega_n^{jk}f_ig_j\\ =&\frac 1n\sum_{k=0}^{n-1}\omega_n^{-xk}\sum\omega_n^{ik}f_i\sum\omega_n^{jk}g_j \end{aligned} \]

然后我们发现后面两个东西不就可以\(O(n)\)搞吗。所以定义\(\text {DFT}\)

\[F_i=\sum_{j=1}^{n-1}\omega_n^{ij}f_j \]

然后定义\(\text{IDFT}\)

\[f_i=\frac 1n\sum_{j=1}^{n-1}\omega_n^{-ij}f_j \]

就是这个过程。当然它是个线性变换,就是把原来多项式的系数当成一个行向量,乘一个单位根组成的范德蒙德矩阵就行了。然后这个模数继续不管,先说明FFT是怎么做到\(O(n\log n)\)变换的。

我们首先设要变换的多项式\(F(x)=\sum_{i=1}^na_ix^i\)。然后对于奇数项和偶数项,我们分开考虑。

\(F_0(x)\)是原来偶次项的系数组成的多项式,\(F_1(x)\)是原来奇次项组成的多项式,也就是

\[F_0(x)=\sum_{i=0}^{\frac n2-1}a_{2i}x^i\ \ F_1(x)=\sum_{i=0}^{\frac n2-1}a_{2i+1}x^i \]

那么原先的多项式可以这样表示:

\[F(x)=F_0(x^2)+xF_1(x^2) \]

然后我们分治递归向下处理这个式子就变成了\(O(n\log n)\)。然后\(\text {IDFT}\)根据我们上面推导的式子,就是把所有\(\omega_n^i\)变成\(\omega_n^{-i}\),然后把结果除以\(n\)就行了。所以函数可以写成一个。因为我们是分治,而两边不一样长没法合并,所以要求长度是\(2^n\)的。次数不够怎么办?高位补\(0\)

递归版的我随便扒了一份,没自己写。所以没有注释。想看随便看看,不想看可以跳过。

void fft(int n, complex<double>* buffer, int offset, int step, complex<double>* epsilon)
{
    if(n == 1) return;
    int m = n >> 1;
    fft(m, buffer, offset, step << 1, epsilon);
    fft(m, buffer, offset + step, step << 1, epsilon);
    for(int k = 0; k != m; ++k)
    {
        int pos = 2 * step * k;
        temp[k] = buffer[pos + offset] + epsilon[k * step] * buffer[pos + offset + step];
        temp[k + m] = buffer[pos + offset] - epsilon[k * step] * buffer[pos + offset + step];
    }
 
    for(int i = 0; i != n; ++i)
        buffer[i * step + offset] = temp[i];
}

然后我们发现递归太慢了而且容易爆栈,所以我们需要一个非递归的写法。

我们观察一下递归的时候每个系数的位置情况(以\(8\)次多项式为例):

第一次:\(\{x_0,x_1,x_2,x_3,x_4,x_5,x_6,x_7\}\)

第二次:\(\{x_0,x_2,x_4,x_6\},\{x_1,x_3,x_5,x_7\}\)

第三次:\(\{x_0,x_4\},\{x_2,x_6\},\{x_1,x_5\},\{x_3,x_7\}\)

第四次:\(\{x_0\},\{x_4\},\{x_2\},\{x_6\},\{x_1\},\{x_5\},\{x_3\},\{x_7\}\)

我们列个表看一下它们的二进制位。

\[\begin{vmatrix}000&&001&&010&&011&&100&&101&&110&&111\\0&&1&&2&&3&&4&&5&&6&&7\\0&&2&&4&&6&|&1&&3&&5&&7\\0&&4&|&2&&6&|&1&&5&|&3&&7\\0&|&4&|&2&|&6&|&1&|&5&|&3&|&7\\000&&100&&010&&110&&001&&101&&011&&111\end{vmatrix} \]

发现最终会变成反转之后的数。所以我们可以预处理每个数二进制反转之后的结果,FFT开始前交换一下。

考虑如何预处理这个东西。设这个东西叫\(r(x)\),总长度\(len=2^k\)。假设我们处理\(r(x)\)的时候之前的已经都处理好了。我们发现如果把\(x\)右移一位,然后反转这个数,那么此时最低位是\(0\),剩下的位是\(x\)除个位以外的反转结果。于是我们得到了递推公式:

\[r(x)=\lfloor \frac{r(\lfloor \frac x2\rfloor)}{2}\rfloor+(x\mod 2)\times \frac {len}2 \]

然后是迭代版FFT的另一个重要操作:蝴蝶操作。

我们每次合并结果时,为了避免数组覆盖原值导致错误,我们用临时变量存储原值来进行操作(其实就两句代码,代码一眼出)。

当然还有单位根的问题。每次现算单位根太慢了,有没有什么快的方法?

当然你可以预处理。但是我们有更优秀的方法:每次只算一遍单位根,然后迭代出想要的单位根。具体的仍然看代码。

接下来这份代码在洛谷的板子里跑了1.44s。比较优秀了。

#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cmath>
using namespace std;
const double pi=acos(-1);
struct cp{
    double r,i;
    cp operator+(const cp &s)const{return cp{r+s.r,i+s.i};}
    cp operator-(const cp &s)const{return cp{r-s.r,i-s.i};}
    cp operator*(const cp &s)const{return cp{r*s.r-i*s.i,r*s.i+i*s.r};}
    cp conj(cp s){return cp{s.r,-s.i};}
}a[2100010],b[2100010];//加减乘 共轭的复数类 够用了 还有数组要开两倍
int n,m,wl=1,r[2100010];
void get(int n){
    while(n>=wl)wl<<=1;
    for(int i=0;i<=wl;i++)r[i]=(r[i>>1]>>1)|((i&1)<<(__lg(wl)-1));//预处理反转操作结果
}
void fft(cp a[],int n,int tp){
    for(int i=1;i<n;i++)if(i<r[i])swap(a[i],a[r[i]]);
    for(int mid=1;mid<n;mid<<=1){//枚举区间中点
        cp wn={cos(pi/mid),tp*sin(pi/mid)};//一个单位根
        for(int j=0;j<n;j+=mid<<1){//当前到哪个位置
            cp w={1,0};
            for(int k=0;k<mid;k++,w=w*wn){//左半部分 每次迭代出单位根
                cp x=a[j+k],y=w*a[j+mid+k];
                a[j+k]=x+y;a[j+mid+k]=x-y;
            }
        }
    }
    if(tp^1)for(int i=0;i<n;i++)a[i].r/=n;//idft最后除以n
}
int main(){
    scanf("%d%d",&n,&m);
    for(int i=0;i<=n;i++)scanf("%lf",&a[i].r);
    for(int i=0;i<=m;i++)scanf("%lf",&b[i].r);
    get(n+m);
    fft(a,wl,1);fft(b,wl,1);
    for(int i=0;i<wl;i++)a[i]=a[i]*b[i];
    fft(a,wl,-1);
    for(int i=0;i<=n+m;i++)printf("%d ",(int)(a[i].r+0.5));//四舍五入
}

这时候我们可以解释前面 \(\mod n\) 的问题了。我们下标 \(\mod n\) 之后实际上求的是循环卷积,于是原来下标为 \(i+j\) 的会算到 \(i+j\mod n\) 上。解决方法很简单,长度开两倍然后不管,相当于高位全是 \(0\) 。这样模数就是 \(2n\) ,没有影响。

然后FFT还有一个广为人知的优化:三次变两次优化。

我们看我们原来的FFT代码,两次DFT,一次IDFT,一共三次。我们可以利用复数的一些性质把它变成两次。

原理大概是把原先要卷积的两个多项式一个放到实部,一个放到虚部。举个例子,设两个多项式分别为\(F(x)=\sum_{i=0}^nf_ix^i,G(x)=\sum_{i=0}^mg_ix^i\)。我们构造一个多项式\(H(x)\),使得\(h_i=f_i+\text ig_i\),然后求它的平方,结果的虚部除以\(2\)就是答案。

因为我们有

\[(a+b\text i)^2=(a^2-b^2)+2ab\text i \]

所以是对的。

上个代码,洛谷神机差不多1.1s。

void fft(cp a[],int n,int tp){
    for(int i=1;i<n;i++)if(i<r[i])swap(a[i],a[r[i]]);
    for(int mid=1;mid<n;mid<<=1){
        cp wn={cos(pi/mid),tp*sin(pi/mid)};
        for(int j=0;j<n;j+=mid<<1){
            cp w={1,0};
            for(int k=0;k<mid;k++,w=w*wn){
                cp x=a[j+k],y=w*a[j+mid+k];
                a[j+k]=x+y;a[j+mid+k]=x-y;
            }
        }
    }
    if(tp^1)for(int i=0;i<n;i++)a[i].i/=2*n;
}
int main(){
    scanf("%d%d",&n,&m);
    for(int i=0;i<=n;i++)scanf("%lf",&a[i].r);
    for(int i=0;i<=m;i++)scanf("%lf",&a[i].i);
    get(n+m);
    fft(a,wl,1);
    for(int i=0;i<wl;i++)a[i]=a[i]*a[i];
    fft(a,wl,-1);
    for(int i=0;i<=n+m;i++)printf("%d ",(int)(a[i].i+0.5));
}

然而三次变两次有没有局限性呢?有的。在两边系数值域相差太大的时候精度严重掉。原因仍然显然。修正方法也不难,把两个多项式数乘一下,值域相同就行了。别忘了除回去。

快速数论变换(Number Theory Transform,NTT)

实际上我们一般不会用FFT,因为缺点很明显:一堆double,不光跑得慢而且会炸精度。那还有什么方法优化呢?我们发现,数论里有个东西和单位根的性质很类似。这个东西叫原根。(忘记原根定义的去oiwiki翻翻)

我们看看它和单位根有什么类似性质。

  1. \(g\)是素数\(p\)的原根,则\(g,g^2,g^3,\cdots,g^{p-1}\)\(\mod p\)意义下两两不同。
  2. \((g^{\frac {p-1}{2n}})^2=g^{\frac {p-1}n}\),而\(\omega_{2n}^2=\omega_n\)
    然后你把原根带进我们需要用的单位根的性质里会发现都成立。就它了。

然而我们观察式子发现我们必须要保证\(n|(p-1)\)。所以这个对\(p\)的取值还有要求。我们一般选\(998244353\),它的原根是\(3\),而且\(998244353=7\times 17\times 2^{23}+1\)。够用了。

所以直接把上面的代码所有的单位根换成原根就行了。

#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cmath>
using namespace std;
const int mod=998244353,g=3,invg=332748118;
int a[2100010],b[2100010];
int n,m,inv,wl=1,r[2100010];
void get(int n){
    while(n>=wl)wl<<=1;
    for(int i=0;i<=wl;i++)r[i]=(r[i>>1]>>1)|((i&1)<<(__lg(wl)-1));
}
int qpow(int a,int b){
    int ans=1;
    while(b){
        if(b&1)ans=1ll*a*ans%mod;
        a=1ll*a*a%mod;
        b>>=1;
    }
    return ans;
}
void ntt(int a[],int n,int tp){
    for(int i=1;i<n;i++)if(i<r[i])swap(a[i],a[r[i]]);
    for(int mid=1;mid<n;mid<<=1){
        int wn=qpow(tp==1?g:invg,(mod-1)/(mid<<1));
        for(int j=0;j<n;j+=mid<<1){
            int w=1;
            for(int k=0;k<mid;k++,w=1ll*w*wn%mod){
                int x=a[j+k],y=1ll*w*a[j+mid+k]%mod;
                a[j+k]=(x+y)%mod;a[j+mid+k]=(x-y+mod)%mod;
            }
        }
    }
    if(tp^1)for(int i=0;i<n;i++)a[i]=1ll*a[i]*inv%mod;
}
int main(){
    scanf("%d%d",&n,&m);
    for(int i=0;i<=n;i++)scanf("%d",&a[i]);
    for(int i=0;i<=m;i++)scanf("%d",&b[i]);
    get(n+m);inv=qpow(wl,mod-2);
    ntt(a,wl,1);ntt(b,wl,1);
    for(int i=0;i<wl;i++)a[i]=1ll*a[i]*b[i]%mod;
    ntt(a,wl,-1);
    for(int i=0;i<=n+m;i++)printf("%d ",a[i]);
}

当然还有一些其他的NTT模数,比如\(1004535809=479\times 2^{21}+1\),原根为\(3\)。还有一个是\(469762049\),原根也是\(3\)

任意模数NTT(MTT)

MTT,也就是任意模数NTT(其实也不用NTT,用FFT)。

FFT可以处理任意模数,但是值域较大的时候不光会爆longlong还会丢精。NTT可以处理大值域,但是要求模数是\(2^n+1\)的形式。现在要求值域较大的任意模数乘法。

首先你当然可以选三个NTT模数然后CRT合并。但是这个九次NTT的做法常数要多大有多大,所以一般没人用。

然后是我们的主题,MTT,也就是拆系数FFT。

具体地讲,我们把两个多项式的系数拆成两部分分别处理(我以\(2^{15}\)为界)。设我们拆的单位是\(M\),我们将系数拆成了\(kM+b\)两部分,得到了\(F_1,F_2,G_1,G_2\)四个多项式。则我们的答案就是:

\[(F_1M+F_2)(G_1M+G_2)=F_1G_1M^2+(F_1G_2+F_2G_1)M+F_2G_2 \]

如果我们暴力算四个多项式乘法就是12次FFT,更慢了。

如果我们稍微动点脑子,分别将四个多项式DFT然后乘法之后IDFT回来,这是8次FFT,还是很慢。

我们将DFT和IDFT的部分分开优化。DFT的部分,我们可以使用三次变两次优化,将两次DFT变成一次。具体的,如果我们要对\(F(x),G(x)\)两个多项式DFT,那么我们设两个多项式:

\[P(x)=F(x)+\text iG(x) \]

\[Q(x)=F(x)-\text iG(x) \]

我们直接对\(P\)做DFT,然后就可以通过共轭求出\(Q\)的DFT,解一个方程组就得到了\(F,G\)的点值表示。

然后是IDFT的部分。我们之前得到了\(F_0,F_1\)的点值表示,我们把它们左右同乘一个\(P(x)=G_0(x)+\text iG_1(x)\),就变成了

\[F_0(x)P(x)=F_0(x)G_0(x)+\text iF_0(x)G_1(x) \]

\[F_1(x)P(x)=F_1(x)G_0(x)+\text iF_1(x)G_1(x) \]

将它转回系数表示之后提出四个实部虚部,就得到了我们想要的四个多项式卷积。加和即可。

上个代码,写写注释。记得开long double不然只有50分。

#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cmath>
using namespace std;
const long double pi=acos(-1);
const int sq=(1<<15)-1;
struct cp{
    long double r,i;
    cp(long double a=0,long double b=0){r=a;i=b;}
    cp operator+(const cp &s)const{return cp{r+s.r,i+s.i};}
    cp operator-(const cp &s)const{return cp{r-s.r,i-s.i};}
    cp operator*(const cp &s)const{return cp{r*s.r-i*s.i,r*s.i+i*s.r};}
}a[300010],b[300010],p[300010],q[300010];
int n,m,mod,wl=1,r[300010],ans[300010];
void get(int n){
    while(n>=wl)wl<<=1;
    for(int i=0;i<=wl;i++)r[i]=(r[i>>1]>>1)|((i&1)<<(__lg(wl)-1));
}
void fft(cp a[],int n,int tp){
    for(int i=1;i<n;i++)if(i<r[i])swap(a[i],a[r[i]]);
    for(int mid=1;mid<n;mid<<=1){
        cp wn={cos(pi/mid),tp*sin(pi/mid)};
        for(int j=0;j<n;j+=mid<<1){
            cp w={1,0};
            for(int k=0;k<mid;k++,w=w*wn){
                cp x=a[j+k],y=w*a[j+mid+k];
                a[j+k]=x+y;a[j+mid+k]=x-y;
            }
        }
    }
    if(tp^1)for(int i=0;i<n;i++)a[i].r/=n,a[i].i/=n;
}
int main(){
    scanf("%d%d%d",&n,&m,&mod);
    for(int i=0;i<=n;i++){
        int x;scanf("%d",&x);
        a[i]=cp (x&sq,x>>15);//拆分系数
    }
    for(int i=0;i<=m;i++){
        int x;scanf("%d",&x);
        b[i]=cp (x&sq,x>>15);
    }
    get(n+m);
    fft(a,wl,1);fft(b,wl,1);
    for(int i=0;i<wl;i++){
        int ret=(wl-i)&(wl-1);/*解释一下这个东西
        首先我们点值表示的每个下标是当前这个单位根的取值
        然后这个相当于0不反转 其他数i翻转成wl-i 即第i个单位根共轭处的取值 所以是对的*/
        p[i]=(cp){0.5*(a[i].r+a[ret].r),0.5*(a[i].i-a[ret].i)}*b[i];//这是解方程之后的结果
        q[i]=(cp){0.5*(a[i].i+a[ret].i),0.5*(a[ret].r-a[i].r)}*b[i];
    }
    fft(p,wl,-1);fft(q,wl,-1);
    for(int i=0;i<wl;i++){
        long long p1=p[i].r+0.5,q1=p[i].i+0.5,x=q[i].r+0.5,y=q[i].i+0.5;
        ans[i]=(p1%mod+((q1+x)%mod<<15)+((y%mod)<<30))%mod;//按照公式代入即可
    }
    for(int i=0;i<=n+m;i++)printf("%d ",ans[i]);
}
posted @ 2022-09-03 19:50  gtm1514  阅读(472)  评论(0编辑  收藏  举报