多项式乘法(FFT,NTT,MTT)

首先从多项式的概念说起。

多项式,就是形如i=0naixi的式子,ai是系数,x是变量,n为多项式的阶/次数。

然后是重要的多项式卷积:定义多项式fg的卷积h为:

hi=j=0ifjgij

其实就是把两个多项式乘起来然后起了个很高级的名字。然后我们拿定义模拟乘法加法运算来计算卷积显然是O(n2)的(以下均假设n,m同阶)。太慢了,不能接受。

于是我们有方法来在O(nlogn)的复杂度内计算多项式的卷积。不过这个一会再说,请先食用一下前置芝士。

  1. 复数(请翻阅数学必修二课本)
  2. 多项式的系数表示和点值表示(这个就是字面意思)
  3. 单位根(这个得说说)

考虑n次方程xn=1的解数。显然在复数域内它有n个解。我们把这n个解放到复平面内,可以发现它正好n等分单位圆。于是我们定义这n个解ωn0,ωn1,,ωnn1就是n次单位根。

显然这个东西有通项公式wni=cos(2πin)+isin(2πin)(以后复数单位i=1就这么写了,和普通的i区分一下)。

然后是单位根的一些重要性质。

  1. ωdndx=ωnx。显然。或者你把单位根的通项搞成e2πin随便搞搞也行。
  2. ωnx=conj(ωnx)conj是共轭。这个也显然,复平面上关于x轴对称一下。
  3. ωnx+n2=ωnx。复平面上逆时针转半圈就是。

其实都是废话。接下来进入正题。

快速傅里叶变换(Fast Fourier Transform,FFT)

Fast Fast TLE

我们考虑到系数表示的极限也就是O(n2)了,所以考虑一个更快的方法:点值表示。这个把要卷起来的两个多项式对应点的点值乘一下就可以了,是O(n)的。于是我们的问题就来到了如何在系数表示和点值表示之间快速转换。

先人们为我们造好了轮子:DFT和它的逆变换IDFT。前一个是把系数转化成点值,后一个是转化回来。它是什么原理?我们来推式子。(抄的学长PDF,不想看的可以直接跳过)

我们还是看原来式子hx=i+j=xfigj。这里换一下表示方法,好写。然后说一句废话变成:

hx=[i+j=x]figj

然后我们单位根的用处就来了,有这样一个柿子,叫单位根反演:

1nx=0n1ωnvx=[n|v]

证明一下:就是分类讨论。vmodn=0时显然,所有的wnvx都是1。然后剩下的情况直接套个等比数列求和公式变成

1n1ωnnv1ωnv

显然是个0。于是我们可以把它带会上式:

hx=[i+jxmodn=0]figj

n是序列长度,取模先不要管。我们继续化:

=i,j1nk=0n1ωnxkωnikωnjkfigj=1nk=0n1ωnxkωnikfiωnjkgj

然后我们发现后面两个东西不就可以O(n)搞吗。所以定义DFT

Fi=j=1n1ωnijfj

然后定义IDFT

fi=1nj=1n1ωnijfj

就是这个过程。当然它是个线性变换,就是把原来多项式的系数当成一个行向量,乘一个单位根组成的范德蒙德矩阵就行了。然后这个模数继续不管,先说明FFT是怎么做到O(nlogn)变换的。

我们首先设要变换的多项式F(x)=i=1naixi。然后对于奇数项和偶数项,我们分开考虑。

F0(x)是原来偶次项的系数组成的多项式,F1(x)是原来奇次项组成的多项式,也就是

F0(x)=i=0n21a2ixi  F1(x)=i=0n21a2i+1xi

那么原先的多项式可以这样表示:

F(x)=F0(x2)+xF1(x2)

然后我们分治递归向下处理这个式子就变成了O(nlogn)。然后IDFT根据我们上面推导的式子,就是把所有ωni变成ωni,然后把结果除以n就行了。所以函数可以写成一个。因为我们是分治,而两边不一样长没法合并,所以要求长度是2n的。次数不够怎么办?高位补0

递归版的我随便扒了一份,没自己写。所以没有注释。想看随便看看,不想看可以跳过。

void fft(int n, complex<double>* buffer, int offset, int step, complex<double>* epsilon)
{
    if(n == 1) return;
    int m = n >> 1;
    fft(m, buffer, offset, step << 1, epsilon);
    fft(m, buffer, offset + step, step << 1, epsilon);
    for(int k = 0; k != m; ++k)
    {
        int pos = 2 * step * k;
        temp[k] = buffer[pos + offset] + epsilon[k * step] * buffer[pos + offset + step];
        temp[k + m] = buffer[pos + offset] - epsilon[k * step] * buffer[pos + offset + step];
    }
 
    for(int i = 0; i != n; ++i)
        buffer[i * step + offset] = temp[i];
}

然后我们发现递归太慢了而且容易爆栈,所以我们需要一个非递归的写法。

我们观察一下递归的时候每个系数的位置情况(以8次多项式为例):

第一次:{x0,x1,x2,x3,x4,x5,x6,x7}

第二次:{x0,x2,x4,x6},{x1,x3,x5,x7}

第三次:{x0,x4},{x2,x6},{x1,x5},{x3,x7}

第四次:{x0},{x4},{x2},{x6},{x1},{x5},{x3},{x7}

我们列个表看一下它们的二进制位。

|000001010011100101110111012345670246|135704|26|15|370|4|2|6|1|5|3|7000100010110001101011111|

发现最终会变成反转之后的数。所以我们可以预处理每个数二进制反转之后的结果,FFT开始前交换一下。

考虑如何预处理这个东西。设这个东西叫r(x),总长度len=2k。假设我们处理r(x)的时候之前的已经都处理好了。我们发现如果把x右移一位,然后反转这个数,那么此时最低位是0,剩下的位是x除个位以外的反转结果。于是我们得到了递推公式:

r(x)=r(x2)2+(xmod2)×len2

然后是迭代版FFT的另一个重要操作:蝴蝶操作。

我们每次合并结果时,为了避免数组覆盖原值导致错误,我们用临时变量存储原值来进行操作(其实就两句代码,代码一眼出)。

当然还有单位根的问题。每次现算单位根太慢了,有没有什么快的方法?

当然你可以预处理。但是我们有更优秀的方法:每次只算一遍单位根,然后迭代出想要的单位根。具体的仍然看代码。

接下来这份代码在洛谷的板子里跑了1.44s。比较优秀了。

#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cmath>
using namespace std;
const double pi=acos(-1);
struct cp{
    double r,i;
    cp operator+(const cp &s)const{return cp{r+s.r,i+s.i};}
    cp operator-(const cp &s)const{return cp{r-s.r,i-s.i};}
    cp operator*(const cp &s)const{return cp{r*s.r-i*s.i,r*s.i+i*s.r};}
    cp conj(cp s){return cp{s.r,-s.i};}
}a[2100010],b[2100010];//加减乘 共轭的复数类 够用了 还有数组要开两倍
int n,m,wl=1,r[2100010];
void get(int n){
    while(n>=wl)wl<<=1;
    for(int i=0;i<=wl;i++)r[i]=(r[i>>1]>>1)|((i&1)<<(__lg(wl)-1));//预处理反转操作结果
}
void fft(cp a[],int n,int tp){
    for(int i=1;i<n;i++)if(i<r[i])swap(a[i],a[r[i]]);
    for(int mid=1;mid<n;mid<<=1){//枚举区间中点
        cp wn={cos(pi/mid),tp*sin(pi/mid)};//一个单位根
        for(int j=0;j<n;j+=mid<<1){//当前到哪个位置
            cp w={1,0};
            for(int k=0;k<mid;k++,w=w*wn){//左半部分 每次迭代出单位根
                cp x=a[j+k],y=w*a[j+mid+k];
                a[j+k]=x+y;a[j+mid+k]=x-y;
            }
        }
    }
    if(tp^1)for(int i=0;i<n;i++)a[i].r/=n;//idft最后除以n
}
int main(){
    scanf("%d%d",&n,&m);
    for(int i=0;i<=n;i++)scanf("%lf",&a[i].r);
    for(int i=0;i<=m;i++)scanf("%lf",&b[i].r);
    get(n+m);
    fft(a,wl,1);fft(b,wl,1);
    for(int i=0;i<wl;i++)a[i]=a[i]*b[i];
    fft(a,wl,-1);
    for(int i=0;i<=n+m;i++)printf("%d ",(int)(a[i].r+0.5));//四舍五入
}

这时候我们可以解释前面 modn 的问题了。我们下标 modn 之后实际上求的是循环卷积,于是原来下标为 i+j 的会算到 i+jmodn 上。解决方法很简单,长度开两倍然后不管,相当于高位全是 0 。这样模数就是 2n ,没有影响。

然后FFT还有一个广为人知的优化:三次变两次优化。

我们看我们原来的FFT代码,两次DFT,一次IDFT,一共三次。我们可以利用复数的一些性质把它变成两次。

原理大概是把原先要卷积的两个多项式一个放到实部,一个放到虚部。举个例子,设两个多项式分别为F(x)=i=0nfixi,G(x)=i=0mgixi。我们构造一个多项式H(x),使得hi=fi+igi,然后求它的平方,结果的虚部除以2就是答案。

因为我们有

(a+bi)2=(a2b2)+2abi

所以是对的。

上个代码,洛谷神机差不多1.1s。

void fft(cp a[],int n,int tp){
    for(int i=1;i<n;i++)if(i<r[i])swap(a[i],a[r[i]]);
    for(int mid=1;mid<n;mid<<=1){
        cp wn={cos(pi/mid),tp*sin(pi/mid)};
        for(int j=0;j<n;j+=mid<<1){
            cp w={1,0};
            for(int k=0;k<mid;k++,w=w*wn){
                cp x=a[j+k],y=w*a[j+mid+k];
                a[j+k]=x+y;a[j+mid+k]=x-y;
            }
        }
    }
    if(tp^1)for(int i=0;i<n;i++)a[i].i/=2*n;
}
int main(){
    scanf("%d%d",&n,&m);
    for(int i=0;i<=n;i++)scanf("%lf",&a[i].r);
    for(int i=0;i<=m;i++)scanf("%lf",&a[i].i);
    get(n+m);
    fft(a,wl,1);
    for(int i=0;i<wl;i++)a[i]=a[i]*a[i];
    fft(a,wl,-1);
    for(int i=0;i<=n+m;i++)printf("%d ",(int)(a[i].i+0.5));
}

然而三次变两次有没有局限性呢?有的。在两边系数值域相差太大的时候精度严重掉。原因仍然显然。修正方法也不难,把两个多项式数乘一下,值域相同就行了。别忘了除回去。

快速数论变换(Number Theory Transform,NTT)

实际上我们一般不会用FFT,因为缺点很明显:一堆double,不光跑得慢而且会炸精度。那还有什么方法优化呢?我们发现,数论里有个东西和单位根的性质很类似。这个东西叫原根。(忘记原根定义的去oiwiki翻翻)

我们看看它和单位根有什么类似性质。

  1. g是素数p的原根,则g,g2,g3,,gp1modp意义下两两不同。
  2. (gp12n)2=gp1n,而ω2n2=ωn
    然后你把原根带进我们需要用的单位根的性质里会发现都成立。就它了。

然而我们观察式子发现我们必须要保证n|(p1)。所以这个对p的取值还有要求。我们一般选998244353,它的原根是3,而且998244353=7×17×223+1。够用了。

所以直接把上面的代码所有的单位根换成原根就行了。

#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cmath>
using namespace std;
const int mod=998244353,g=3,invg=332748118;
int a[2100010],b[2100010];
int n,m,inv,wl=1,r[2100010];
void get(int n){
    while(n>=wl)wl<<=1;
    for(int i=0;i<=wl;i++)r[i]=(r[i>>1]>>1)|((i&1)<<(__lg(wl)-1));
}
int qpow(int a,int b){
    int ans=1;
    while(b){
        if(b&1)ans=1ll*a*ans%mod;
        a=1ll*a*a%mod;
        b>>=1;
    }
    return ans;
}
void ntt(int a[],int n,int tp){
    for(int i=1;i<n;i++)if(i<r[i])swap(a[i],a[r[i]]);
    for(int mid=1;mid<n;mid<<=1){
        int wn=qpow(tp==1?g:invg,(mod-1)/(mid<<1));
        for(int j=0;j<n;j+=mid<<1){
            int w=1;
            for(int k=0;k<mid;k++,w=1ll*w*wn%mod){
                int x=a[j+k],y=1ll*w*a[j+mid+k]%mod;
                a[j+k]=(x+y)%mod;a[j+mid+k]=(x-y+mod)%mod;
            }
        }
    }
    if(tp^1)for(int i=0;i<n;i++)a[i]=1ll*a[i]*inv%mod;
}
int main(){
    scanf("%d%d",&n,&m);
    for(int i=0;i<=n;i++)scanf("%d",&a[i]);
    for(int i=0;i<=m;i++)scanf("%d",&b[i]);
    get(n+m);inv=qpow(wl,mod-2);
    ntt(a,wl,1);ntt(b,wl,1);
    for(int i=0;i<wl;i++)a[i]=1ll*a[i]*b[i]%mod;
    ntt(a,wl,-1);
    for(int i=0;i<=n+m;i++)printf("%d ",a[i]);
}

当然还有一些其他的NTT模数,比如1004535809=479×221+1,原根为3。还有一个是469762049,原根也是3

任意模数NTT(MTT)

MTT,也就是任意模数NTT(其实也不用NTT,用FFT)。

FFT可以处理任意模数,但是值域较大的时候不光会爆longlong还会丢精。NTT可以处理大值域,但是要求模数是2n+1的形式。现在要求值域较大的任意模数乘法。

首先你当然可以选三个NTT模数然后CRT合并。但是这个九次NTT的做法常数要多大有多大,所以一般没人用。

然后是我们的主题,MTT,也就是拆系数FFT。

具体地讲,我们把两个多项式的系数拆成两部分分别处理(我以215为界)。设我们拆的单位是M,我们将系数拆成了kM+b两部分,得到了F1,F2,G1,G2四个多项式。则我们的答案就是:

(F1M+F2)(G1M+G2)=F1G1M2+(F1G2+F2G1)M+F2G2

如果我们暴力算四个多项式乘法就是12次FFT,更慢了。

如果我们稍微动点脑子,分别将四个多项式DFT然后乘法之后IDFT回来,这是8次FFT,还是很慢。

我们将DFT和IDFT的部分分开优化。DFT的部分,我们可以使用三次变两次优化,将两次DFT变成一次。具体的,如果我们要对F(x),G(x)两个多项式DFT,那么我们设两个多项式:

P(x)=F(x)+iG(x)

Q(x)=F(x)iG(x)

我们直接对P做DFT,然后就可以通过共轭求出Q的DFT,解一个方程组就得到了F,G的点值表示。

然后是IDFT的部分。我们之前得到了F0,F1的点值表示,我们把它们左右同乘一个P(x)=G0(x)+iG1(x),就变成了

F0(x)P(x)=F0(x)G0(x)+iF0(x)G1(x)

F1(x)P(x)=F1(x)G0(x)+iF1(x)G1(x)

将它转回系数表示之后提出四个实部虚部,就得到了我们想要的四个多项式卷积。加和即可。

上个代码,写写注释。记得开long double不然只有50分。

#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cmath>
using namespace std;
const long double pi=acos(-1);
const int sq=(1<<15)-1;
struct cp{
    long double r,i;
    cp(long double a=0,long double b=0){r=a;i=b;}
    cp operator+(const cp &s)const{return cp{r+s.r,i+s.i};}
    cp operator-(const cp &s)const{return cp{r-s.r,i-s.i};}
    cp operator*(const cp &s)const{return cp{r*s.r-i*s.i,r*s.i+i*s.r};}
}a[300010],b[300010],p[300010],q[300010];
int n,m,mod,wl=1,r[300010],ans[300010];
void get(int n){
    while(n>=wl)wl<<=1;
    for(int i=0;i<=wl;i++)r[i]=(r[i>>1]>>1)|((i&1)<<(__lg(wl)-1));
}
void fft(cp a[],int n,int tp){
    for(int i=1;i<n;i++)if(i<r[i])swap(a[i],a[r[i]]);
    for(int mid=1;mid<n;mid<<=1){
        cp wn={cos(pi/mid),tp*sin(pi/mid)};
        for(int j=0;j<n;j+=mid<<1){
            cp w={1,0};
            for(int k=0;k<mid;k++,w=w*wn){
                cp x=a[j+k],y=w*a[j+mid+k];
                a[j+k]=x+y;a[j+mid+k]=x-y;
            }
        }
    }
    if(tp^1)for(int i=0;i<n;i++)a[i].r/=n,a[i].i/=n;
}
int main(){
    scanf("%d%d%d",&n,&m,&mod);
    for(int i=0;i<=n;i++){
        int x;scanf("%d",&x);
        a[i]=cp (x&sq,x>>15);//拆分系数
    }
    for(int i=0;i<=m;i++){
        int x;scanf("%d",&x);
        b[i]=cp (x&sq,x>>15);
    }
    get(n+m);
    fft(a,wl,1);fft(b,wl,1);
    for(int i=0;i<wl;i++){
        int ret=(wl-i)&(wl-1);/*解释一下这个东西
        首先我们点值表示的每个下标是当前这个单位根的取值
        然后这个相当于0不反转 其他数i翻转成wl-i 即第i个单位根共轭处的取值 所以是对的*/
        p[i]=(cp){0.5*(a[i].r+a[ret].r),0.5*(a[i].i-a[ret].i)}*b[i];//这是解方程之后的结果
        q[i]=(cp){0.5*(a[i].i+a[ret].i),0.5*(a[ret].r-a[i].r)}*b[i];
    }
    fft(p,wl,-1);fft(q,wl,-1);
    for(int i=0;i<wl;i++){
        long long p1=p[i].r+0.5,q1=p[i].i+0.5,x=q[i].r+0.5,y=q[i].i+0.5;
        ans[i]=(p1%mod+((q1+x)%mod<<15)+((y%mod)<<30))%mod;//按照公式代入即可
    }
    for(int i=0;i<=n+m;i++)printf("%d ",ans[i]);
}
posted @   gtm1514  阅读(587)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 无需6万激活码!GitHub神秘组织3小时极速复刻Manus,手把手教你使用OpenManus搭建本
· C#/.NET/.NET Core优秀项目和框架2025年2月简报
· Manus爆火,是硬核还是营销?
· 一文读懂知识蒸馏
· 终于写完轮子一部分:tcp代理 了,记录一下
点击右上角即可分享
微信分享提示