多项式乘法(FFT,NTT,MTT)
首先从多项式的概念说起。
多项式,就是形如
然后是重要的多项式卷积:定义多项式
其实就是把两个多项式乘起来然后起了个很高级的名字。然后我们拿定义模拟乘法加法运算来计算卷积显然是
于是我们有方法来在
- 复数(请翻阅数学必修二课本)
- 多项式的系数表示和点值表示(这个就是字面意思)
- 单位根(这个得说说)
考虑
显然这个东西有通项公式
然后是单位根的一些重要性质。
。显然。或者你把单位根的通项搞成 随便搞搞也行。 。 是共轭。这个也显然,复平面上关于 轴对称一下。 。复平面上逆时针转半圈就是。
其实都是废话。接下来进入正题。
快速傅里叶变换(Fast Fourier Transform,FFT)
Fast Fast TLE
我们考虑到系数表示的极限也就是
先人们为我们造好了轮子:
我们还是看原来式子
然后我们单位根的用处就来了,有这样一个柿子,叫单位根反演:
证明一下:就是分类讨论。
显然是个
然后我们发现后面两个东西不就可以
然后定义
就是这个过程。当然它是个线性变换,就是把原来多项式的系数当成一个行向量,乘一个单位根组成的范德蒙德矩阵就行了。然后这个模数继续不管,先说明FFT是怎么做到
我们首先设要变换的多项式
设
那么原先的多项式可以这样表示:
然后我们分治递归向下处理这个式子就变成了
递归版的我随便扒了一份,没自己写。所以没有注释。想看随便看看,不想看可以跳过。
void fft(int n, complex<double>* buffer, int offset, int step, complex<double>* epsilon)
{
if(n == 1) return;
int m = n >> 1;
fft(m, buffer, offset, step << 1, epsilon);
fft(m, buffer, offset + step, step << 1, epsilon);
for(int k = 0; k != m; ++k)
{
int pos = 2 * step * k;
temp[k] = buffer[pos + offset] + epsilon[k * step] * buffer[pos + offset + step];
temp[k + m] = buffer[pos + offset] - epsilon[k * step] * buffer[pos + offset + step];
}
for(int i = 0; i != n; ++i)
buffer[i * step + offset] = temp[i];
}
然后我们发现递归太慢了而且容易爆栈,所以我们需要一个非递归的写法。
我们观察一下递归的时候每个系数的位置情况(以
第一次:
第二次:
第三次:
第四次:
我们列个表看一下它们的二进制位。
发现最终会变成反转之后的数。所以我们可以预处理每个数二进制反转之后的结果,FFT开始前交换一下。
考虑如何预处理这个东西。设这个东西叫
然后是迭代版FFT的另一个重要操作:蝴蝶操作。
我们每次合并结果时,为了避免数组覆盖原值导致错误,我们用临时变量存储原值来进行操作(其实就两句代码,代码一眼出)。
当然还有单位根的问题。每次现算单位根太慢了,有没有什么快的方法?
当然你可以预处理。但是我们有更优秀的方法:每次只算一遍单位根,然后迭代出想要的单位根。具体的仍然看代码。
接下来这份代码在洛谷的板子里跑了1.44s。比较优秀了。
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cmath>
using namespace std;
const double pi=acos(-1);
struct cp{
double r,i;
cp operator+(const cp &s)const{return cp{r+s.r,i+s.i};}
cp operator-(const cp &s)const{return cp{r-s.r,i-s.i};}
cp operator*(const cp &s)const{return cp{r*s.r-i*s.i,r*s.i+i*s.r};}
cp conj(cp s){return cp{s.r,-s.i};}
}a[2100010],b[2100010];//加减乘 共轭的复数类 够用了 还有数组要开两倍
int n,m,wl=1,r[2100010];
void get(int n){
while(n>=wl)wl<<=1;
for(int i=0;i<=wl;i++)r[i]=(r[i>>1]>>1)|((i&1)<<(__lg(wl)-1));//预处理反转操作结果
}
void fft(cp a[],int n,int tp){
for(int i=1;i<n;i++)if(i<r[i])swap(a[i],a[r[i]]);
for(int mid=1;mid<n;mid<<=1){//枚举区间中点
cp wn={cos(pi/mid),tp*sin(pi/mid)};//一个单位根
for(int j=0;j<n;j+=mid<<1){//当前到哪个位置
cp w={1,0};
for(int k=0;k<mid;k++,w=w*wn){//左半部分 每次迭代出单位根
cp x=a[j+k],y=w*a[j+mid+k];
a[j+k]=x+y;a[j+mid+k]=x-y;
}
}
}
if(tp^1)for(int i=0;i<n;i++)a[i].r/=n;//idft最后除以n
}
int main(){
scanf("%d%d",&n,&m);
for(int i=0;i<=n;i++)scanf("%lf",&a[i].r);
for(int i=0;i<=m;i++)scanf("%lf",&b[i].r);
get(n+m);
fft(a,wl,1);fft(b,wl,1);
for(int i=0;i<wl;i++)a[i]=a[i]*b[i];
fft(a,wl,-1);
for(int i=0;i<=n+m;i++)printf("%d ",(int)(a[i].r+0.5));//四舍五入
}
这时候我们可以解释前面
然后FFT还有一个广为人知的优化:三次变两次优化。
我们看我们原来的FFT代码,两次DFT,一次IDFT,一共三次。我们可以利用复数的一些性质把它变成两次。
原理大概是把原先要卷积的两个多项式一个放到实部,一个放到虚部。举个例子,设两个多项式分别为
因为我们有
所以是对的。
上个代码,洛谷神机差不多1.1s。
void fft(cp a[],int n,int tp){
for(int i=1;i<n;i++)if(i<r[i])swap(a[i],a[r[i]]);
for(int mid=1;mid<n;mid<<=1){
cp wn={cos(pi/mid),tp*sin(pi/mid)};
for(int j=0;j<n;j+=mid<<1){
cp w={1,0};
for(int k=0;k<mid;k++,w=w*wn){
cp x=a[j+k],y=w*a[j+mid+k];
a[j+k]=x+y;a[j+mid+k]=x-y;
}
}
}
if(tp^1)for(int i=0;i<n;i++)a[i].i/=2*n;
}
int main(){
scanf("%d%d",&n,&m);
for(int i=0;i<=n;i++)scanf("%lf",&a[i].r);
for(int i=0;i<=m;i++)scanf("%lf",&a[i].i);
get(n+m);
fft(a,wl,1);
for(int i=0;i<wl;i++)a[i]=a[i]*a[i];
fft(a,wl,-1);
for(int i=0;i<=n+m;i++)printf("%d ",(int)(a[i].i+0.5));
}
然而三次变两次有没有局限性呢?有的。在两边系数值域相差太大的时候精度严重掉。原因仍然显然。修正方法也不难,把两个多项式数乘一下,值域相同就行了。别忘了除回去。
快速数论变换(Number Theory Transform,NTT)
实际上我们一般不会用FFT,因为缺点很明显:一堆double,不光跑得慢而且会炸精度。那还有什么方法优化呢?我们发现,数论里有个东西和单位根的性质很类似。这个东西叫原根。(忘记原根定义的去oiwiki翻翻)
我们看看它和单位根有什么类似性质。
- 设
是素数 的原根,则 在 意义下两两不同。 ,而 。
然后你把原根带进我们需要用的单位根的性质里会发现都成立。就它了。
然而我们观察式子发现我们必须要保证
所以直接把上面的代码所有的单位根换成原根就行了。
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cmath>
using namespace std;
const int mod=998244353,g=3,invg=332748118;
int a[2100010],b[2100010];
int n,m,inv,wl=1,r[2100010];
void get(int n){
while(n>=wl)wl<<=1;
for(int i=0;i<=wl;i++)r[i]=(r[i>>1]>>1)|((i&1)<<(__lg(wl)-1));
}
int qpow(int a,int b){
int ans=1;
while(b){
if(b&1)ans=1ll*a*ans%mod;
a=1ll*a*a%mod;
b>>=1;
}
return ans;
}
void ntt(int a[],int n,int tp){
for(int i=1;i<n;i++)if(i<r[i])swap(a[i],a[r[i]]);
for(int mid=1;mid<n;mid<<=1){
int wn=qpow(tp==1?g:invg,(mod-1)/(mid<<1));
for(int j=0;j<n;j+=mid<<1){
int w=1;
for(int k=0;k<mid;k++,w=1ll*w*wn%mod){
int x=a[j+k],y=1ll*w*a[j+mid+k]%mod;
a[j+k]=(x+y)%mod;a[j+mid+k]=(x-y+mod)%mod;
}
}
}
if(tp^1)for(int i=0;i<n;i++)a[i]=1ll*a[i]*inv%mod;
}
int main(){
scanf("%d%d",&n,&m);
for(int i=0;i<=n;i++)scanf("%d",&a[i]);
for(int i=0;i<=m;i++)scanf("%d",&b[i]);
get(n+m);inv=qpow(wl,mod-2);
ntt(a,wl,1);ntt(b,wl,1);
for(int i=0;i<wl;i++)a[i]=1ll*a[i]*b[i]%mod;
ntt(a,wl,-1);
for(int i=0;i<=n+m;i++)printf("%d ",a[i]);
}
当然还有一些其他的NTT模数,比如
任意模数NTT(MTT)
MTT,也就是任意模数NTT(其实也不用NTT,用FFT)。
FFT可以处理任意模数,但是值域较大的时候不光会爆longlong还会丢精。NTT可以处理大值域,但是要求模数是
首先你当然可以选三个NTT模数然后CRT合并。但是这个九次NTT的做法常数要多大有多大,所以一般没人用。
然后是我们的主题,MTT,也就是拆系数FFT。
具体地讲,我们把两个多项式的系数拆成两部分分别处理(我以
如果我们暴力算四个多项式乘法就是12次FFT,更慢了。
如果我们稍微动点脑子,分别将四个多项式DFT然后乘法之后IDFT回来,这是8次FFT,还是很慢。
我们将DFT和IDFT的部分分开优化。DFT的部分,我们可以使用三次变两次优化,将两次DFT变成一次。具体的,如果我们要对
我们直接对
然后是IDFT的部分。我们之前得到了
将它转回系数表示之后提出四个实部虚部,就得到了我们想要的四个多项式卷积。加和即可。
上个代码,写写注释。记得开long double不然只有50分。
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cmath>
using namespace std;
const long double pi=acos(-1);
const int sq=(1<<15)-1;
struct cp{
long double r,i;
cp(long double a=0,long double b=0){r=a;i=b;}
cp operator+(const cp &s)const{return cp{r+s.r,i+s.i};}
cp operator-(const cp &s)const{return cp{r-s.r,i-s.i};}
cp operator*(const cp &s)const{return cp{r*s.r-i*s.i,r*s.i+i*s.r};}
}a[300010],b[300010],p[300010],q[300010];
int n,m,mod,wl=1,r[300010],ans[300010];
void get(int n){
while(n>=wl)wl<<=1;
for(int i=0;i<=wl;i++)r[i]=(r[i>>1]>>1)|((i&1)<<(__lg(wl)-1));
}
void fft(cp a[],int n,int tp){
for(int i=1;i<n;i++)if(i<r[i])swap(a[i],a[r[i]]);
for(int mid=1;mid<n;mid<<=1){
cp wn={cos(pi/mid),tp*sin(pi/mid)};
for(int j=0;j<n;j+=mid<<1){
cp w={1,0};
for(int k=0;k<mid;k++,w=w*wn){
cp x=a[j+k],y=w*a[j+mid+k];
a[j+k]=x+y;a[j+mid+k]=x-y;
}
}
}
if(tp^1)for(int i=0;i<n;i++)a[i].r/=n,a[i].i/=n;
}
int main(){
scanf("%d%d%d",&n,&m,&mod);
for(int i=0;i<=n;i++){
int x;scanf("%d",&x);
a[i]=cp (x&sq,x>>15);//拆分系数
}
for(int i=0;i<=m;i++){
int x;scanf("%d",&x);
b[i]=cp (x&sq,x>>15);
}
get(n+m);
fft(a,wl,1);fft(b,wl,1);
for(int i=0;i<wl;i++){
int ret=(wl-i)&(wl-1);/*解释一下这个东西
首先我们点值表示的每个下标是当前这个单位根的取值
然后这个相当于0不反转 其他数i翻转成wl-i 即第i个单位根共轭处的取值 所以是对的*/
p[i]=(cp){0.5*(a[i].r+a[ret].r),0.5*(a[i].i-a[ret].i)}*b[i];//这是解方程之后的结果
q[i]=(cp){0.5*(a[i].i+a[ret].i),0.5*(a[ret].r-a[i].r)}*b[i];
}
fft(p,wl,-1);fft(q,wl,-1);
for(int i=0;i<wl;i++){
long long p1=p[i].r+0.5,q1=p[i].i+0.5,x=q[i].r+0.5,y=q[i].i+0.5;
ans[i]=(p1%mod+((q1+x)%mod<<15)+((y%mod)<<30))%mod;//按照公式代入即可
}
for(int i=0;i<=n+m;i++)printf("%d ",ans[i]);
}
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 无需6万激活码!GitHub神秘组织3小时极速复刻Manus,手把手教你使用OpenManus搭建本
· C#/.NET/.NET Core优秀项目和框架2025年2月简报
· Manus爆火,是硬核还是营销?
· 一文读懂知识蒸馏
· 终于写完轮子一部分:tcp代理 了,记录一下