多项式乘法(FFT,NTT,MTT)
首先从多项式的概念说起。
多项式,就是形如\(\sum_{i=0}^na_ix^i\)的式子,\(a_i\)是系数,\(x\)是变量,\(n\)为多项式的阶/次数。
然后是重要的多项式卷积:定义多项式\(f\)和\(g\)的卷积\(h\)为:
其实就是把两个多项式乘起来然后起了个很高级的名字。然后我们拿定义模拟乘法加法运算来计算卷积显然是\(O(n^2)\)的(以下均假设\(n,m\)同阶)。太慢了,不能接受。
于是我们有方法来在\(O(n\log n)\)的复杂度内计算多项式的卷积。不过这个一会再说,请先食用一下前置芝士。
- 复数(请翻阅数学必修二课本)
- 多项式的系数表示和点值表示(这个就是字面意思)
- 单位根(这个得说说)
考虑\(n\)次方程\(x^n=1\)的解数。显然在复数域内它有\(n\)个解。我们把这\(n\)个解放到复平面内,可以发现它正好\(n\)等分单位圆。于是我们定义这\(n\)个解\(\omega_n^0,\omega_n^1,\cdots ,\omega_n^{n-1}\)就是\(n\)次单位根。
显然这个东西有通项公式\(w_n^i=\cos(\frac {2\pi i}{n})+\text i\sin(\frac {2\pi i}{n})\)(以后复数单位\(\text i=\sqrt {-1}\)就这么写了,和普通的\(i\)区分一下)。
然后是单位根的一些重要性质。
- \(\omega_{dn}^{dx}=\omega_n^x\)。显然。或者你把单位根的通项搞成\(e^{\frac {2\pi i}{n}}\)随便搞搞也行。
- \(\omega_n^x=\text {conj}(\omega_n^{-x})\)。\(\text{conj}\)是共轭。这个也显然,复平面上关于\(x\)轴对称一下。
- \(\omega_n^{x+\frac n2}=-\omega_n^x\)。复平面上逆时针转半圈就是。
其实都是废话。接下来进入正题。
快速傅里叶变换(Fast Fourier Transform,FFT)
Fast Fast TLE
我们考虑到系数表示的极限也就是\(O(n^2)\)了,所以考虑一个更快的方法:点值表示。这个把要卷起来的两个多项式对应点的点值乘一下就可以了,是\(O(n)\)的。于是我们的问题就来到了如何在系数表示和点值表示之间快速转换。
先人们为我们造好了轮子:\(\text {DFT}\)和它的逆变换\(\text {IDFT}\)。前一个是把系数转化成点值,后一个是转化回来。它是什么原理?我们来推式子。(抄的学长PDF,不想看的可以直接跳过)
我们还是看原来式子\(h_x=\sum_{i+j=x}f_ig_j\)。这里换一下表示方法,好写。然后说一句废话变成:
然后我们单位根的用处就来了,有这样一个柿子,叫单位根反演:
证明一下:就是分类讨论。\(v\mod n=0\)时显然,所有的\(w_n^{vx}\)都是\(1\)。然后剩下的情况直接套个等比数列求和公式变成
显然是个\(0\)。于是我们可以把它带会上式:
\(n\)是序列长度,取模先不要管。我们继续化:
然后我们发现后面两个东西不就可以\(O(n)\)搞吗。所以定义\(\text {DFT}\)为
然后定义\(\text{IDFT}\)为
就是这个过程。当然它是个线性变换,就是把原来多项式的系数当成一个行向量,乘一个单位根组成的范德蒙德矩阵就行了。然后这个模数继续不管,先说明FFT是怎么做到\(O(n\log n)\)变换的。
我们首先设要变换的多项式\(F(x)=\sum_{i=1}^na_ix^i\)。然后对于奇数项和偶数项,我们分开考虑。
设\(F_0(x)\)是原来偶次项的系数组成的多项式,\(F_1(x)\)是原来奇次项组成的多项式,也就是
那么原先的多项式可以这样表示:
然后我们分治递归向下处理这个式子就变成了\(O(n\log n)\)。然后\(\text {IDFT}\)根据我们上面推导的式子,就是把所有\(\omega_n^i\)变成\(\omega_n^{-i}\),然后把结果除以\(n\)就行了。所以函数可以写成一个。因为我们是分治,而两边不一样长没法合并,所以要求长度是\(2^n\)的。次数不够怎么办?高位补\(0\)。
递归版的我随便扒了一份,没自己写。所以没有注释。想看随便看看,不想看可以跳过。
void fft(int n, complex<double>* buffer, int offset, int step, complex<double>* epsilon)
{
if(n == 1) return;
int m = n >> 1;
fft(m, buffer, offset, step << 1, epsilon);
fft(m, buffer, offset + step, step << 1, epsilon);
for(int k = 0; k != m; ++k)
{
int pos = 2 * step * k;
temp[k] = buffer[pos + offset] + epsilon[k * step] * buffer[pos + offset + step];
temp[k + m] = buffer[pos + offset] - epsilon[k * step] * buffer[pos + offset + step];
}
for(int i = 0; i != n; ++i)
buffer[i * step + offset] = temp[i];
}
然后我们发现递归太慢了而且容易爆栈,所以我们需要一个非递归的写法。
我们观察一下递归的时候每个系数的位置情况(以\(8\)次多项式为例):
第一次:\(\{x_0,x_1,x_2,x_3,x_4,x_5,x_6,x_7\}\)
第二次:\(\{x_0,x_2,x_4,x_6\},\{x_1,x_3,x_5,x_7\}\)
第三次:\(\{x_0,x_4\},\{x_2,x_6\},\{x_1,x_5\},\{x_3,x_7\}\)
第四次:\(\{x_0\},\{x_4\},\{x_2\},\{x_6\},\{x_1\},\{x_5\},\{x_3\},\{x_7\}\)
我们列个表看一下它们的二进制位。
发现最终会变成反转之后的数。所以我们可以预处理每个数二进制反转之后的结果,FFT开始前交换一下。
考虑如何预处理这个东西。设这个东西叫\(r(x)\),总长度\(len=2^k\)。假设我们处理\(r(x)\)的时候之前的已经都处理好了。我们发现如果把\(x\)右移一位,然后反转这个数,那么此时最低位是\(0\),剩下的位是\(x\)除个位以外的反转结果。于是我们得到了递推公式:
然后是迭代版FFT的另一个重要操作:蝴蝶操作。
我们每次合并结果时,为了避免数组覆盖原值导致错误,我们用临时变量存储原值来进行操作(其实就两句代码,代码一眼出)。
当然还有单位根的问题。每次现算单位根太慢了,有没有什么快的方法?
当然你可以预处理。但是我们有更优秀的方法:每次只算一遍单位根,然后迭代出想要的单位根。具体的仍然看代码。
接下来这份代码在洛谷的板子里跑了1.44s。比较优秀了。
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cmath>
using namespace std;
const double pi=acos(-1);
struct cp{
double r,i;
cp operator+(const cp &s)const{return cp{r+s.r,i+s.i};}
cp operator-(const cp &s)const{return cp{r-s.r,i-s.i};}
cp operator*(const cp &s)const{return cp{r*s.r-i*s.i,r*s.i+i*s.r};}
cp conj(cp s){return cp{s.r,-s.i};}
}a[2100010],b[2100010];//加减乘 共轭的复数类 够用了 还有数组要开两倍
int n,m,wl=1,r[2100010];
void get(int n){
while(n>=wl)wl<<=1;
for(int i=0;i<=wl;i++)r[i]=(r[i>>1]>>1)|((i&1)<<(__lg(wl)-1));//预处理反转操作结果
}
void fft(cp a[],int n,int tp){
for(int i=1;i<n;i++)if(i<r[i])swap(a[i],a[r[i]]);
for(int mid=1;mid<n;mid<<=1){//枚举区间中点
cp wn={cos(pi/mid),tp*sin(pi/mid)};//一个单位根
for(int j=0;j<n;j+=mid<<1){//当前到哪个位置
cp w={1,0};
for(int k=0;k<mid;k++,w=w*wn){//左半部分 每次迭代出单位根
cp x=a[j+k],y=w*a[j+mid+k];
a[j+k]=x+y;a[j+mid+k]=x-y;
}
}
}
if(tp^1)for(int i=0;i<n;i++)a[i].r/=n;//idft最后除以n
}
int main(){
scanf("%d%d",&n,&m);
for(int i=0;i<=n;i++)scanf("%lf",&a[i].r);
for(int i=0;i<=m;i++)scanf("%lf",&b[i].r);
get(n+m);
fft(a,wl,1);fft(b,wl,1);
for(int i=0;i<wl;i++)a[i]=a[i]*b[i];
fft(a,wl,-1);
for(int i=0;i<=n+m;i++)printf("%d ",(int)(a[i].r+0.5));//四舍五入
}
这时候我们可以解释前面 \(\mod n\) 的问题了。我们下标 \(\mod n\) 之后实际上求的是循环卷积,于是原来下标为 \(i+j\) 的会算到 \(i+j\mod n\) 上。解决方法很简单,长度开两倍然后不管,相当于高位全是 \(0\) 。这样模数就是 \(2n\) ,没有影响。
然后FFT还有一个广为人知的优化:三次变两次优化。
我们看我们原来的FFT代码,两次DFT,一次IDFT,一共三次。我们可以利用复数的一些性质把它变成两次。
原理大概是把原先要卷积的两个多项式一个放到实部,一个放到虚部。举个例子,设两个多项式分别为\(F(x)=\sum_{i=0}^nf_ix^i,G(x)=\sum_{i=0}^mg_ix^i\)。我们构造一个多项式\(H(x)\),使得\(h_i=f_i+\text ig_i\),然后求它的平方,结果的虚部除以\(2\)就是答案。
因为我们有
所以是对的。
上个代码,洛谷神机差不多1.1s。
void fft(cp a[],int n,int tp){
for(int i=1;i<n;i++)if(i<r[i])swap(a[i],a[r[i]]);
for(int mid=1;mid<n;mid<<=1){
cp wn={cos(pi/mid),tp*sin(pi/mid)};
for(int j=0;j<n;j+=mid<<1){
cp w={1,0};
for(int k=0;k<mid;k++,w=w*wn){
cp x=a[j+k],y=w*a[j+mid+k];
a[j+k]=x+y;a[j+mid+k]=x-y;
}
}
}
if(tp^1)for(int i=0;i<n;i++)a[i].i/=2*n;
}
int main(){
scanf("%d%d",&n,&m);
for(int i=0;i<=n;i++)scanf("%lf",&a[i].r);
for(int i=0;i<=m;i++)scanf("%lf",&a[i].i);
get(n+m);
fft(a,wl,1);
for(int i=0;i<wl;i++)a[i]=a[i]*a[i];
fft(a,wl,-1);
for(int i=0;i<=n+m;i++)printf("%d ",(int)(a[i].i+0.5));
}
然而三次变两次有没有局限性呢?有的。在两边系数值域相差太大的时候精度严重掉。原因仍然显然。修正方法也不难,把两个多项式数乘一下,值域相同就行了。别忘了除回去。
快速数论变换(Number Theory Transform,NTT)
实际上我们一般不会用FFT,因为缺点很明显:一堆double,不光跑得慢而且会炸精度。那还有什么方法优化呢?我们发现,数论里有个东西和单位根的性质很类似。这个东西叫原根。(忘记原根定义的去oiwiki翻翻)
我们看看它和单位根有什么类似性质。
- 设\(g\)是素数\(p\)的原根,则\(g,g^2,g^3,\cdots,g^{p-1}\)在\(\mod p\)意义下两两不同。
- \((g^{\frac {p-1}{2n}})^2=g^{\frac {p-1}n}\),而\(\omega_{2n}^2=\omega_n\)。
然后你把原根带进我们需要用的单位根的性质里会发现都成立。就它了。
然而我们观察式子发现我们必须要保证\(n|(p-1)\)。所以这个对\(p\)的取值还有要求。我们一般选\(998244353\),它的原根是\(3\),而且\(998244353=7\times 17\times 2^{23}+1\)。够用了。
所以直接把上面的代码所有的单位根换成原根就行了。
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cmath>
using namespace std;
const int mod=998244353,g=3,invg=332748118;
int a[2100010],b[2100010];
int n,m,inv,wl=1,r[2100010];
void get(int n){
while(n>=wl)wl<<=1;
for(int i=0;i<=wl;i++)r[i]=(r[i>>1]>>1)|((i&1)<<(__lg(wl)-1));
}
int qpow(int a,int b){
int ans=1;
while(b){
if(b&1)ans=1ll*a*ans%mod;
a=1ll*a*a%mod;
b>>=1;
}
return ans;
}
void ntt(int a[],int n,int tp){
for(int i=1;i<n;i++)if(i<r[i])swap(a[i],a[r[i]]);
for(int mid=1;mid<n;mid<<=1){
int wn=qpow(tp==1?g:invg,(mod-1)/(mid<<1));
for(int j=0;j<n;j+=mid<<1){
int w=1;
for(int k=0;k<mid;k++,w=1ll*w*wn%mod){
int x=a[j+k],y=1ll*w*a[j+mid+k]%mod;
a[j+k]=(x+y)%mod;a[j+mid+k]=(x-y+mod)%mod;
}
}
}
if(tp^1)for(int i=0;i<n;i++)a[i]=1ll*a[i]*inv%mod;
}
int main(){
scanf("%d%d",&n,&m);
for(int i=0;i<=n;i++)scanf("%d",&a[i]);
for(int i=0;i<=m;i++)scanf("%d",&b[i]);
get(n+m);inv=qpow(wl,mod-2);
ntt(a,wl,1);ntt(b,wl,1);
for(int i=0;i<wl;i++)a[i]=1ll*a[i]*b[i]%mod;
ntt(a,wl,-1);
for(int i=0;i<=n+m;i++)printf("%d ",a[i]);
}
当然还有一些其他的NTT模数,比如\(1004535809=479\times 2^{21}+1\),原根为\(3\)。还有一个是\(469762049\),原根也是\(3\)。
任意模数NTT(MTT)
MTT,也就是任意模数NTT(其实也不用NTT,用FFT)。
FFT可以处理任意模数,但是值域较大的时候不光会爆longlong还会丢精。NTT可以处理大值域,但是要求模数是\(2^n+1\)的形式。现在要求值域较大的任意模数乘法。
首先你当然可以选三个NTT模数然后CRT合并。但是这个九次NTT的做法常数要多大有多大,所以一般没人用。
然后是我们的主题,MTT,也就是拆系数FFT。
具体地讲,我们把两个多项式的系数拆成两部分分别处理(我以\(2^{15}\)为界)。设我们拆的单位是\(M\),我们将系数拆成了\(kM+b\)两部分,得到了\(F_1,F_2,G_1,G_2\)四个多项式。则我们的答案就是:
如果我们暴力算四个多项式乘法就是12次FFT,更慢了。
如果我们稍微动点脑子,分别将四个多项式DFT然后乘法之后IDFT回来,这是8次FFT,还是很慢。
我们将DFT和IDFT的部分分开优化。DFT的部分,我们可以使用三次变两次优化,将两次DFT变成一次。具体的,如果我们要对\(F(x),G(x)\)两个多项式DFT,那么我们设两个多项式:
我们直接对\(P\)做DFT,然后就可以通过共轭求出\(Q\)的DFT,解一个方程组就得到了\(F,G\)的点值表示。
然后是IDFT的部分。我们之前得到了\(F_0,F_1\)的点值表示,我们把它们左右同乘一个\(P(x)=G_0(x)+\text iG_1(x)\),就变成了
将它转回系数表示之后提出四个实部虚部,就得到了我们想要的四个多项式卷积。加和即可。
上个代码,写写注释。记得开long double不然只有50分。
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cmath>
using namespace std;
const long double pi=acos(-1);
const int sq=(1<<15)-1;
struct cp{
long double r,i;
cp(long double a=0,long double b=0){r=a;i=b;}
cp operator+(const cp &s)const{return cp{r+s.r,i+s.i};}
cp operator-(const cp &s)const{return cp{r-s.r,i-s.i};}
cp operator*(const cp &s)const{return cp{r*s.r-i*s.i,r*s.i+i*s.r};}
}a[300010],b[300010],p[300010],q[300010];
int n,m,mod,wl=1,r[300010],ans[300010];
void get(int n){
while(n>=wl)wl<<=1;
for(int i=0;i<=wl;i++)r[i]=(r[i>>1]>>1)|((i&1)<<(__lg(wl)-1));
}
void fft(cp a[],int n,int tp){
for(int i=1;i<n;i++)if(i<r[i])swap(a[i],a[r[i]]);
for(int mid=1;mid<n;mid<<=1){
cp wn={cos(pi/mid),tp*sin(pi/mid)};
for(int j=0;j<n;j+=mid<<1){
cp w={1,0};
for(int k=0;k<mid;k++,w=w*wn){
cp x=a[j+k],y=w*a[j+mid+k];
a[j+k]=x+y;a[j+mid+k]=x-y;
}
}
}
if(tp^1)for(int i=0;i<n;i++)a[i].r/=n,a[i].i/=n;
}
int main(){
scanf("%d%d%d",&n,&m,&mod);
for(int i=0;i<=n;i++){
int x;scanf("%d",&x);
a[i]=cp (x&sq,x>>15);//拆分系数
}
for(int i=0;i<=m;i++){
int x;scanf("%d",&x);
b[i]=cp (x&sq,x>>15);
}
get(n+m);
fft(a,wl,1);fft(b,wl,1);
for(int i=0;i<wl;i++){
int ret=(wl-i)&(wl-1);/*解释一下这个东西
首先我们点值表示的每个下标是当前这个单位根的取值
然后这个相当于0不反转 其他数i翻转成wl-i 即第i个单位根共轭处的取值 所以是对的*/
p[i]=(cp){0.5*(a[i].r+a[ret].r),0.5*(a[i].i-a[ret].i)}*b[i];//这是解方程之后的结果
q[i]=(cp){0.5*(a[i].i+a[ret].i),0.5*(a[ret].r-a[i].r)}*b[i];
}
fft(p,wl,-1);fft(q,wl,-1);
for(int i=0;i<wl;i++){
long long p1=p[i].r+0.5,q1=p[i].i+0.5,x=q[i].r+0.5,y=q[i].i+0.5;
ans[i]=(p1%mod+((q1+x)%mod<<15)+((y%mod)<<30))%mod;//按照公式代入即可
}
for(int i=0;i<=n+m;i++)printf("%d ",ans[i]);
}