再探快速傅里叶变换(FFT)学习笔记(其三)(循环卷积的Bluestein算法+分治FFT+FFT的优化+任意模数NTT)
再探快速傅里叶变换(FFT)学习笔记(其三)(循环卷积的Bluestein算法+分治FFT+FFT的优化+任意模数NTT)
写在前面
为了不使篇幅过长,预计将把基于论文的学习笔记分为三部分:
- DFT,IDFT,FFT的定义,实现与证明:快速傅里叶变换(FFT)学习笔记(其一)
- NTT的实现与证明:快速傅里叶变换(FFT)学习笔记(其二)
- 任意模数NTT与FFT的优化技巧
一些约定
- \([p(x)]=\begin{cases}1,p(x)为真 \\ 0,p(x)为假 \end{cases}\)
- 本文中序列的下标从0开始
- 若\(s\)是一个序列,\(|s|\)表示\(s\)的长度
- 若大写字母如\(F(x)\)表示一个多项式,那么对应的小写字母如\(f\)表示多项式的每一项系数,即\(F(x)=\sum_{i=0}^{n-1} f_ix^i\)
循环卷积
DFT卷积的本质
考虑在(其一)中提到的卷积的定义式。
我们一般做FFT时忽略了式子中的\(\bmod\),其实它是在\(\bmod 2^q\)的意义下的循环卷积,只是因为\(|a|,|b|,|c|<2^q\),所以取不取模都没什么影响。
如果序列长度\(n\)是2的整数次幂,那么直接做就可以了。
如果序列长度\(n\)不是2的整数次幂考虑暴力的做法:先做一次普通FFT,再把\(c_{k+n}\)加到\(c_k\)上。但是这样在做多次FFT时就必须一次一次做,比如多项式快速幂。下面给出了一种在\(O(n \log n)\)的时间内实现任意长度循环卷积的算法:Bluestein’s Algorithm
Bluestein’s Algorithm
注:原论文的推导可能有误
考虑DFT的式子
不妨设
\(x_j=a_j \omega_n^{\frac{j^2}{2}}=a_j(\cos\frac{j^2\pi}{n}+ \text{i}\sin{\frac{j^2\pi}{n}})\)
\(y_j=\omega_n^{-\frac{j^2}{2}}= \cos \frac{\pi j^2}{n}-\text{i}\sin \frac{\pi j^2}{n}\)
那么\(a_i'=\omega_n^{\frac{j^2}{2}}\sum_{j=0}^{n-1} x_j y_{i-j}\)
这已经很类似卷积的形式了,但是注意到\(j\)的上界是\(n-1\)而不是\(i\),\(j-i\)可能为负数。那么我们把\(y\)数组的长度扩大到\(2n\),定义:
\(y_j=\omega_n^{-\frac{(j-n)^2}{2}}= \cos \frac{\pi (j-n)^2}{n}-\text{i}\sin \frac{\pi (j-n)^2}{n}\).
这样\(j<n\)的时候就对应了\(j-i\)为负数的情形,\(j\geq n\)就对应了\(j-i\)为正的情形。然后对\(x\)和\(y\)用一般的FFT,最后的答案存储在\(i+n\)的位置上,也就是说真正的\(a'_i\)实际上对应了乘积结果的\((x \cdot y)_{i+n}\)
这样,我们就只做了3次FFT就求出了任意长度循环DFT。逆变换同理,只是换成共轭复数。注意到在上述的推导中我们没有用到单位根\(\omega\)的任何性质,因此这里的\(\omega\)可以换成任意复数\(z\),这样的变换称为Chirp Z-Transform,CZT.可见,CZT实际上是DFT的广义形式。
代码实现:
//com是手写复数类,省略
void fft(com *x,int *rev,int n,int type){
//为节约篇幅,fft部分省略,x为系数序列,rev为反转数组,n为长度,type=1表示DFT,type=-1表示IDFT
}
void bluestein(com *a,int n,int type){
//a为系数序列,n为长度,type=1表示DFT,type=-1表示IDFT
static com x[maxn*4+5],y[maxn*4+5];
static int rev[maxn*4+5];
memset(x,0,sizeof(x));
memset(y,0,sizeof(y));
//FFT前的预处理
int N=1,L=0;
while(N<n*4){
L++;
N*=2;
}
for(int i=0;i<N;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
//x[i],y[i]的定义见上式
for(int i=0;i<n;i++) x[i]=com(cos(pi*i*i/n),type*sin(pi*i*i/n))*a[i];
for(int i=0;i<n*2;i++) y[i]=com(cos(pi*(i-n)*(i-n)/n),-type*sin(pi*(i-n)*(i-n)/n));
fft(x,rev,N,1);
fft(y,rev,N,1);
for(int i=0;i<N;i++) x[i]*=y[i];
fft(x,rev,N,-1);
for(int i=0;i<n;i++){
a[i]=x[i+n]*com(cos(pi*i*i/n),type*sin(pi*i*i/n));//记得乘上常数
if(type==-1) a[i]/=n;//一定记得除以n,因为做一次Bluestein相当于一次FFT,IFFT最后要除n,这里也要除n
}
}
例题
[POJ 2821]TN's Kindom III(任意长度循环卷积的Bluestein算法)
分治FFT
一般我们用FFT的时候,序列的所有元素都已知。但是,如果序列本身是根据卷积定义的,就无法直接套FFT
举一个最简单的例子\(f_i =\sum_{j=1}^i f_{i-j}g_j\).其中\(g\)给定,求\(f\). 由于我们卷积的时后后面的数基于前面的数,无法快速计算,时间复杂度退化到\(O(n^2)\). (虽然这个式子可以用(其四)中将会提到的多项式求逆解决,但是分治FFT更通用,可以处理很复杂的式子)
考虑分治: 设当前分治区间为\([l,r]\),假设我们求出了\([l,mid]\)的答案,那么可以求出这些点对\([mid+1,r]\)的影响。那么右半边的点\(x \in [mid+1,r]\)得到的贡献是\(\Delta_x=\sum_{i=l}^{mid} f_i g_{x-i}\).只需要把下标偏移一下(如\([l,mid]\)偏移成\([0,mid-l]\),就是一个卷积的形式,可以运用FFT或NTT计算,计算完之后,把答案累加到数组上.
伪代码如下:
poly f,g;//上述的f,g
procedure calc(L,mid,R){
for i in [L,mid] : a[i-L] <- f[i]//下标偏移
for i in [1,R-L] : b[i-1] <- g[i]
a <- mul(a,b);//fft或ntt做多项式乘法
for i in [mid+1,R] f[i] <- f[i]+a[i-l-1]//累加贡献
}
procedure solve(l,mid){
if(l==r) return;
mid <- (l+r)/2
solve(l,mid);
calc(l,mid,r);
solve(mid+1,r)
}
时间复杂度分析:
\(T(n)=2T(\frac{n}{2})+n \log_2n\), 总复杂度\(\Theta(n \log^2n)\)
下面是基于NTT的模板代码(Luogu 4721)
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#define maxn 300000
#define G 3
#define invG 332748118
#define inv2 499122177
#define mod 998244353
using namespace std;
typedef long long ll;
inline ll fast_pow(ll x,ll k){
ll ans=1;
while(k){
if(k&1) ans=ans*x%mod;
x=x*x%mod;
k>>=1;
}
return ans;
}
inline ll inv(ll x){
return fast_pow(x,mod-2);
}
void NTT(ll *x,int n,int type){
static int rev[maxn+5];
int tn=1;
int k=0;
while(tn<n){
tn*=2;
k++;
}
for(int i=0;i<tn;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1));
for(int i=0;i<n;i++){
if(i<rev[i]) swap(x[i],x[rev[i]]);
}
for(int len=1;len<n;len*=2){
int sz=len*2;
ll gn1=fast_pow((type==1?G:invG),(mod-1)/sz);
for(int l=0;l<n;l+=sz){
int r=l+len-1;
ll gnk=1;
for(int i=l;i<=r;i++){
ll tmp=x[i+len];
x[i+len]=(x[i]-gnk*tmp%mod+mod)%mod;
x[i]=(x[i]+gnk*tmp%mod)%mod;
gnk=gnk*gn1%mod;
}
}
}
if(type==-1){
int invsz=inv(n);
for(int i=0;i<n;i++) x[i]=x[i]*invsz%mod;
}
}
void mul(ll *a,ll *b,ll *ans,int sz){
NTT(a,sz,1);
NTT(b,sz,1);
for(int i=0;i<sz;i++) ans[i]=a[i]*b[i]%mod;
NTT(ans,sz,-1);
}
void cdq_divide(ll *f,ll *g,int l,int r){
static ll tmpa[maxn+5],tmpb[maxn+5];
if(l==r) return;
int mid=(l+r)>>1;
cdq_divide(f,g,l,mid);
int tn=1,k=0;
while(tn<r-l){
k++;
tn*=2;
}
for(int i=0;i<tn;i++) tmpa[i]=tmpb[i]=0;
for(int i=l;i<=mid;i++) tmpa[i-l]=f[i];
for(int i=1;i<=r-l;i++) tmpb[i-1]=g[i];
mul(tmpa,tmpb,tmpa,tn);
for(int i=mid+1;i<=r;i++) f[i]=(f[i]+tmpa[i-l-1])%mod;
cdq_divide(f,g,mid+1,r);
}
int n;
ll f[maxn+5],g[maxn+5];
int main(){
scanf("%d",&n);
for(int i=1;i<n;i++) scanf("%lld",&g[i]);
f[0]=1;
cdq_divide(f,g,0,n-1);
for(int i=0;i<n;i++) printf("%lld ",f[i]);
}
容易发现,许多dp方程都有分治FFT的形式。对于此类dp方程,我们可以用分治FFT将转移复杂度由\(O(n^2)\)降到\(O(n \log^2 n)\)
例题
[Codeforces 553E]Kyoya and Train(期望DP+Floyd+分治FFT)
FFT的弱常数优化
下面介绍一些优化FFT的常数的技巧。虽然这些技巧都只是对FFT的一些小优化,但是在某些题目中优化效果极其明显。
复杂算式中减少FFT次数
如果我们要计算一个复杂的多项式,如\(A(x)=B(x)C(x)+D(x)E(x)\)
最简单的方法是分别计算\(B(x)C(x)\)和\(D(x)E(x)\),这样需要做6次FFT. 但是如果先对\(B,C,D,E\)做DFT,然后直接用点值表达式计算\(a_i=b_ic_i+d_ie_i\),再把\(a\)IDFT回去。这样只需要做5次FFT,且多项式越复杂,这样的常数就越优秀。
例题
[BZOJ 3771] Triple(FFT+容斥原理+生成函数)
利用循环卷积
考虑对于两个长度为\(n\)的序列\(a,b\),计算它们的卷积\(c\)的第\(0.5n\)项到第\(1.5n\)项。传统的方法是补0扩充到\(2n\)的序列。但是因为FFT求得实际上是我们已经提到过的循环卷积,所以如果只补0到\(1.5n\)(上取整),对第\(0.5n\)项到第\(1.5n\)项无影响
在基于牛顿迭代的算法中,能起到较明显的优化作用。会在(其四)中详细介绍这些算法。
小范围暴力
由于FFT的常数较大。在数据范围较小的时候甚至不如\(O(n^2)\)的暴力卷积的优秀。因此在做多次FFT和分治FFT的时候,如果当前的序列长度较小,可以采用暴力算法。
例题
[BZOJ 3509] [CodeChef] COUNTARI (FFT+分块)
快速幂乘法次数的优化
这个东西实际上比较鸡肋。因为多项式快速幂可以通过多项式\(\ln\)和\(\exp\)优化到\(O(n \log n)\).但是为了应对考场上时间不够的情况,我们来考虑如何通过简单的实现来减少\(O(n \log^2n)\)的倍增快速幂的复杂度。
倍增法的思路是根据前面算过的乘积快速算出当前的乘积,如\(1 \to 2 \to 4 \to 8\).最坏情况下需要\(2 \log_2n+C\)次乘法。但这并不是下界。我们定义additional chain为一条链,最开始是1,后一个数减前一个数的差是链上这个是前面的某一个数。例如\(1 \to 2 \to 4 \to 6\).\(6-4=2\)在前面出现过,\(4-2=2\)在前面出现过。那么根据这条additional chain计算6次幂的时候,可以从1次幂出发,用1次幂乘1次幂得到2次幂,再乘2次幂得到4次幂,再乘2次幂得到6次幂。
很可惜,对于数\(k\)求出得到\(k\)的最短additional chain是NP-hard的。但是有很好的近似算法。近似算法基于BFS。每次我们对于队头的数\(x\),枚举它对应的additional chain中的数\(y\),如果\(x+y\)还没有访问过那么将其入队,并将\(x\)对应的链后面接上\(x+y\). 这个预处理是\(O(k)\)的,且对快速幂的常数优化很显著。
如果\(k\)很大,比如\(10^{10000}\),可以采用十进制快速幂。但是用Method of Four Russians(俗称四毛子算法),可以将乘法次数减少到\(\log_2n+O(\frac{\log n}{\log \log n})\).具体方法见2017年国家集训队论文《非常规大小分块算法初探》
FFT的强常数优化
FFT的强常数优化一般是通过减少FFT次数来实现的
在这一节中,我们记\(DFT(A(x))\)表示多项式\(A(x)\)(或序列)做DFT之后的结果,\(IDFT(A(x))\)同理
我们现在考虑最常见的一个模型:给出两个长度为\(n+1\)和\(m+1\)的多项式\(A(x),B(x)\),我们要计算他们的线性卷积。假设长度已经补齐为第一个大于\(n+m+1\)的2的整数幂\(L\)。
显然直接搞需要3次长度为\(L\)的FFT。毒瘤的Vladimir Smykalov在cf上最先给出了这个问题的优化算法。
DFT的合并
DFT的合并是指,对于两个序列\(a\),\(b\),我们只通过一次FFT就求出\(DFT(a),DFT(b)\)
不妨设:
接下来我们开始推导公式。注意为了简洁,我们记\(X=\frac{2 \pi jk}{2L}\),\(\text{conj}(z)\)表示\(z\)的共轭复数
也就是说,只要一次DFT算出\(DFT(p)\),就可以把序列反转再取共轭复数得到\(DFT(q)\).
由于DFT是线性变换,
其中\(j\)为\(k\)翻转后的数,即\(j=\begin{cases}0,k=0 \\ L-k ,k>0 \end{cases}\)
又由\((4.1),(4.2)\)式
这样我们就可以从\(q'\)推出\(a',b'\),也就是说一次DFT就能得到\(a'\)和\(b'\)了.
我们一共做了2次长度为\(L\)的FFT.
代码(UOJ#34):
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#define maxn 1000000
const double pi=acos(-1.0);
using namespace std;
typedef long long ll;
struct com{
double real;
double imag;
com(){
}
com(double _real,double _imag){
real=_real;
imag=_imag;
}
com(double x){
real=x;
imag=0;
}
void operator = (const com x){
this->real=x.real;
this->imag=x.imag;
}
void operator = (const double x){
this->real=x;
this->imag=0;
}
friend com operator + (com p,com q){
return com(p.real+q.real,p.imag+q.imag);
}
friend com operator + (com p,double q){
return com(p.real+q,p.imag);
}
void operator += (com q){
*this=*this+q;
}
void operator += (double q){
*this=*this+q;
}
friend com operator - (com p,com q){
return com(p.real-q.real,p.imag-q.imag);
}
friend com operator - (com p,double q){
return com(p.real-q,p.imag);
}
void operator -= (com q){
*this=*this-q;
}
void operator -= (double q){
*this=*this-q;
}
friend com operator * (com p,com q){
return com(p.real*q.real-p.imag*q.imag,p.real*q.imag+p.imag*q.real);
}
friend com operator * (com p,double q){
return com(p.real*q,p.imag*q);
}
void operator *= (com q){
*this=(*this)*q;
}
void operator *= (double q){
*this=(*this)*q;
}
friend com operator / (com p,double q){
return com(p.real/q,p.imag/q);
}
void operator /= (double q){
*this=(*this)/q;
}
com conj(){
return com(real,-imag);
}
void print(){
printf("%lf + %lf i ",real,imag);
}
};
int rev[maxn+5];
com w[maxn+5];
void fft(com *x,int n){
for(int i=0;i<n;i++) if(i<rev[i]) swap(x[i],x[rev[i]]);
for(int len=1;len<n;len*=2){
int sz=len*2;
for(int l=0;l<n;l+=sz){
int r=l+len-1;
for(int i=l;i<=r;i++){
com tmp=x[i+len];
x[i+len]=x[i]-tmp*w[n/sz*(i-l)];//w(sz,k)=w(n,n/sz*k)
x[i]=x[i]+tmp*w[n/sz*(i-l)];
}
}
}
}
void mul(ll *a,ll *b,ll *c,int n){
static com p[maxn+5],r[maxn+5];
for(int i=0;i<n;i++) w[i]=com(cos(2*pi*i/n),sin(2*pi*i/n));//预处理单位根
for(int i=0;i<n;i++) p[i]=com(a[i],b[i]);//p[i]=a[i]+ib[i]
fft(p,n);
for(int i=0;i<n;i++){
int j=(i>0?(n-i):0);//0的位置需要特判一下
com q=p[j];
r[j]=(p[i]*p[i]-q.conj()*q.conj())*com(0,-0.25);//按照上面的式子
}
fft(r,n);//这里是用了第一篇中提到的反转技巧
for(int i=0;i<n;i++) c[i]=r[i].real/n+0.5;
}
int n,m;
ll a[maxn+5],b[maxn+5],c[maxn+5];
int main(){
scanf("%d %d",&n,&m);
for(int i=0;i<=n;i++) scanf("%lld",&a[i]);
for(int i=0;i<=m;i++) scanf("%lld",&b[i]);
int N=1,L=0;
while(N<n+m+1){
L++;
N*=2;
}
for(int i=0;i<N;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
mul(a,b,c,N);
for(int i=0;i<n+m+1;i++) printf("%lld\n",c[i]);
}
IDFT的合并
IDFT的合并是指,对于两个序列\(a\),\(b\),我们只通过一次FFT就求出\(IDFT(a),IDFT(b)\)
IDFT的合并非常简单。
设\(r(x)=a(x)+\text{i}b(x)\)
由于IDFT是线性变换
\(IDFT(r(x))=IDFT(a(x))+\text{i}IDFT(b(x))\)
又因为\(a(x)\)和\(b(x)\)都是实数序列,那么\(IDFT(r(x))\)的实部就是\(IDFT(a(x))\),虚部就是\(IDFT(b(x))\)
形如\((A+B)(C+D)\)的卷积的优化
在这一节中我们讨论\((A(x)+B(x))(C(x)+D(x))\)形式的卷积的优化.
一般的做法是对\(A,B,C,D\)都做一次DFT,然后按照这个式子直接计算,最后再IDFT回来。需要5次FFT.
而根据上面的合并技巧,先把\(A(x),B(x)\)合并DFT,\(C(x),D(x)\)合并DFT得到点值表达式.
由于\((A(x)+B(x))(C(x)+D(x))=A(x)C(x)+A(x)D(x)+B(x)C(x)+B(x)D(x)\)
我们可以直接把点值表达式相乘得到这4个多项式。对于这4个多项式,分成2组合并做IDFT即可。
总共需要4次FFT.
大致代码如下:
void mul(ll *a,ll *b,ll *c,ll *d,ll *ans,int n){
static com p[maxn+5],q[maxn+5];
static com r[maxn+5],s[maxn+5];
for(int i=0;i<n;i++) w[i]=com(cos(2*pi*i/n),sin(2*pi*i/n));
for(int i=0;i<n;i++){
p[i]=com(a[i],b[i]);//打包A,B
q[i]=com(c[i],d[i]);//打包C,D
}
fft(p,n);
fft(q,n);
for(int i=0;i<n;i++){
int j=(i==0?0:n-i);
//得到DFT(A),DFT(B),DFT(C),DFT(D)
com da=(p[i]+p[j].conj())*0.5;
com db=(p[i]-p[j].conj())*com(0,-0.5);
com dc=(q[i]+q[j].conj())*0.5;
com dd=(q[i]-q[j].conj())*com(0,-0.5);
r[j]=da*dc+da*dd*com(0,1);//打包AC,AD
s[j]=db*dc+db*dd*com(0,1); //打包BC,BD
}
fft(r,n);
fft(s,n);
for(int i=0;i<n;i++){
ll ac,ad,bc,bd;
ac=(ll)(r[i].real/n+0.5);
ad=(ll)(r[i].imag/n+0.5);
bc=(ll)(s[i].real/n+0.5);
bd=(ll)(s[i].imag/n+0.5);
ans[i]=ac+ad+bc+bd;
}
}
卷积的终极优化
上述优化中我们只用到了DFT的思想。现在我们利用FFT的思想继续优化
同样拆分奇偶项,\(A(x)=A_0(x^2)+xA_1(x^2)\)
我们只需要知道上式中\(x^0,x^1,x^2\)的系数
发现\(A_0(x^2)B_1(x^2)+A_1(x^2)B_0(x^2)\)是奇数项的系数,\(A_0(x^2)B_0(x^2)\)和\(A_1(x^2)B_1(x^2)\)是偶数项的系数,而偶数项的两个东西都可以看成一个关于\(x^2\)的多项式。
我们先优化DFT的过程,观察\((4.6)\)式的乘积形式\((A_0(x^2)+xA_1(x^2))(B_0(x^2)+xB_1(x^2))\).
我们发现,这个形式和上一节的\((A+B)(C+D)\)很像,可以类似地优化。
令\(p_k={a_0}_k+\text{i}{a_1}_k,q_k={b_0}_k+\text{i}{b_1}_k\)
然后合并IDFT,再设两个辅助多项式
(注意我们把\(x^2\)换元成\(x\),做DFT的时候要乘上单位根)
那么我们只需要计算出\(IDFT(G(x))\)和\(IDFT(F(x))\)
设\(R(x)=G(x)+\mathrm{i} F(x)\)
那么因为IDFT是线性变换,\(IDFT(R(x))=IDFT(G(x))+\mathrm{i} IDFT(F(x))\)
(IDFT的线性性这里不做证明,容易发现两个点值表达式相加再IDFT回来,显然系数也会相加)
显然这两个多项式IDFT的结果是实数。故我们只要求出\(IDFT(R(x))\),每一项系数的实部就是偶数项系数\(G(x)\),虚部就是奇数项系数\(F(x)\)
我们再考虑把合并DFT弄进去,即式\((4.3)(4.4)(4.5)\)
接下来我们尝试用\(DFT(p_k),DFT(q_k)\)来表示\(R(x)=G(x)+\text{i}F(x)\),为了推导简洁,我们省略\(DFT\)不写
那么
和上一节的\((A+B)(C+D)\)不同,我们只用了3次长度为\(L/2\)的FFT,就求出了答案,这是由于FFT本身的性质。因为长度缩减了一半,我们不妨称它为\(1.5\)次FFT.
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#define maxn 1000000
const double pi=acos(-1.0);
using namespace std;
typedef long long ll;
struct com{
double real;
double imag;
com(){
}
com(double _real,double _imag){
real=_real;
imag=_imag;
}
com(double x){
real=x;
imag=0;
}
void operator = (const com x){
this->real=x.real;
this->imag=x.imag;
}
void operator = (const double x){
this->real=x;
this->imag=0;
}
friend com operator + (com p,com q){
return com(p.real+q.real,p.imag+q.imag);
}
friend com operator + (com p,double q){
return com(p.real+q,p.imag);
}
void operator += (com q){
*this=*this+q;
}
void operator += (double q){
*this=*this+q;
}
friend com operator - (com p,com q){
return com(p.real-q.real,p.imag-q.imag);
}
friend com operator - (com p,double q){
return com(p.real-q,p.imag);
}
void operator -= (com q){
*this=*this-q;
}
void operator -= (double q){
*this=*this-q;
}
friend com operator * (com p,com q){
return com(p.real*q.real-p.imag*q.imag,p.real*q.imag+p.imag*q.real);
}
friend com operator * (com p,double q){
return com(p.real*q,p.imag*q);
}
void operator *= (com q){
*this=(*this)*q;
}
void operator *= (double q){
*this=(*this)*q;
}
friend com operator / (com p,double q){
return com(p.real/q,p.imag/q);
}
void operator /= (double q){
*this=(*this)/q;
}
com conj(){
return com(real,-imag);
}
void print(){
printf("%lf + %lf i ",real,imag);
}
};
int rev[maxn+5];
com w[maxn+5];
void fft(com *x,int n){
for(int i=0;i<n;i++) if(i<rev[i]) swap(x[i],x[rev[i]]);
for(int len=1;len<n;len*=2){
int sz=len*2;
for(int l=0;l<n;l+=sz){
int r=l+len-1;
for(int i=l;i<=r;i++){
com tmp=x[i+len];
x[i+len]=x[i]-tmp*w[n/sz*(i-l)];//w(sz,k)=w(n,n/sz*k)
x[i]=x[i]+tmp*w[n/sz*(i-l)];
}
}
}
}
void mul(ll *a,ll *b,ll *c,int n){
static com p[maxn+5],q[maxn+5],r[maxn+5];
for(int i=0;i<n;i++){//合并做DFT
if(i%2==1){
p[i/2].imag=a[i];
q[i/2].imag=b[i];
}else{
p[i/2].real=a[i];
q[i/2].real=b[i];
}
}
n/=2;
for(int i=0;i<n;i++) w[i]=com(cos(2*pi*i/n),sin(2*pi*i/n));
fft(q,n);
fft(p,n);
for(int i=0;i<n;i++){
int j=(i>0?(n-i):0);
r[j]=p[i]*q[i]-(w[i]+1)*(p[i]-p[j].conj())*(q[i]-q[j].conj())*0.25;
}
fft(r,n);
for(int i=0;i<n;i++){
c[i*2]=r[i].real/n+0.5;
c[i*2+1]=r[i].imag/n+0.5;
}
}
int n,m;
ll a[maxn+5],b[maxn+5],c[maxn+5];
int main(){
scanf("%d %d",&n,&m);
for(int i=0;i<=n;i++) scanf("%lld",&a[i]);
for(int i=0;i<=m;i++) scanf("%lld",&b[i]);
int N=1,L=0;
while(N<=n+m+1){
L++;
N*=2;
}
for(int i=0;i<N/2;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-2));//注意这里的rev数组是对N/2做的,L要-1
mul(a,b,c,N);
for(int i=0;i<n+m+1;i++) printf("%lld\n",c[i]);
}
任意模数NTT
三模数NTT
这是任意模数NTT的算法中最好理解的一种,它基于中国剩余定理。
定理5.1 若\(m_1,m_2 ,\dots m_n\)两两互质,则对于\(\forall a_1,a_2 \dots a_n\)同余方程组
\[\begin{cases} x \equiv a_1 (\bmod m_1) \\ x \equiv a_2 (\bmod m_2) \\ \dots \\ x \equiv a_n (\bmod m_n)\end{cases} \]有整数解解,且可以用如下方式构造解
- 设\(M=\prod_{i=1}^n m_i,M_i=\frac{M}{m_i}\)
- 设\(M_i^{-1}\)为模\(m_i\)意义下\(M_i\)的逆元
- 则该方程组在模\(M\)意义下的唯一解为\(x=\sum_{i=1}^n a_iM_iM_i^{-1}\) ,方程组的通解可以表示为\(x+kM(k \in \mathbb{Z})\)
这就是著名的中国剩余定理(Chinese Reminder Theorem,CRT)
证明:
对于\(k \neq i\),\(a_iM_iM_i^{-1} \bmod m_k=0\), 而根据逆元的定义,\(a_iM_iM_i^{-1} \bmod m_i =a_i\). 再代入到\(\sum_{i=1}^n a_iM_iM_i^{-1}\),原方程组成立。
回到任意模数NTT问题
模\(M\)意义下长度为\(n\)的序列做卷积,最大值可以到\(n^2M\).一般的题目中\(n \leq 10^5,M\leq 10^{9}\),那么结果会到\(10^{23}\)级别。用long double
等存储会丢失精度。那么我们可以选三个乘起来大于\(10^{23}\)的NTT模数998244353,1004535809,469762049(选这三个模数的好处是他们的原根都是3,所以NTT部分写起来比较简洁)。然后分别在这三个模数的意义下做卷积。最后考虑把答案合并,我们只考虑某一位上的值\(ans\),容易写出:
显然\(m_1,m_2,m_3\)互质,那么我们可以利用中国剩余定理直接合并。但是,直接合并把三个模数乘起来的时候会超出long long
的范围。注意到两个模数相乘还是在long long
范围内的,可以两两合并,具体方法如下,
记\(inv(a,m)\)表示\(a\)在模\(m\)下的逆元.根据CRT合并\((5.2)(5.3)\)有:
不妨设\(ans=km_1m_2+r\),根据\(5.4\)有
\(ans=km_1 m_2+r=q m_3+a_3 \tag{5.6}\),
在模 \(m_3\) 意义下有
\(km_1 m_2+r \equiv a_3 (\bmod m_3) \tag{5.7}\)
因此\(k=(a_3-r_2)inv(m_1m_2,m_3) (\bmod m_3)\),不妨设\(k=dm_3+e\),代入\(5.6\)得
由于\(m_1m_2m_3>ans\),所以\(d=0\),也就是说,\(ans=em_1m_2+r\),其中\(r=a_1m_2inv(m_1,m_1m_2)+a_2m_1inv(m_2,m_1m_2),e=(a_3-r_2)inv(m_1m_2,m_3)\)
const ll mm=m1*m2;
inline ll inv(ll a,ll m);
ll mul(ll a,ll b,ll m);//要用按位乘防止溢出
ll CRT(ll a1,ll a2,ll a3){
ll r=(mul(a1*m2%mm,inv(m2,m1),mm)+mul(a2*m1%mm,inv(m1,m2),mm))%mm;
ll e=((a3-r)%m3+m3)%m3*inv(mm,m3)%m3;
return ((e%C)*(mm%C)%C+r%C)%C;
}
完整代码(LuoguP4245 【模板】任意模数NTT)
#include<iostream>
#include<cstdio>
#include<cstring>
#define m1 998244353ll
#define m2 1004535809ll
#define m3 469762049ll
#define G 3
#define maxn 1048576
using namespace std;
typedef long long ll;
const ll mm=m1*m2;
ll C;
ll fast_pow(ll x,ll k,ll m){
ll ans=1;
while(k){
if(k&1) ans=ans*x%m;
x=x*x%m;
k>>=1;
}
return ans;
}
inline ll inv(ll a,ll m){
return fast_pow(a%m,m-2,m); //一定要取模m
}
ll mul(ll a,ll b,ll m){
ll ans=0;
while(b){
if(b&1) ans=(ans+a)%m;
a=(a+a)%m;
b>>=1;
}
return ans;
}
ll CRT(ll a1,ll a2,ll a3){
//[Warning]You are not expected to understand this.
ll r=(mul(a1*m2%mm,inv(m2,m1),mm)+mul(a2*m1%mm,inv(m1,m2),mm))%mm;
ll e=((a3-r)%m3+m3)%m3*inv(mm,m3)%m3;
return ((e%C)*(mm%C)%C+r%C)%C;
}
int n,m,N,L;
int rev[maxn+5];
void NTT(ll *x,int n,int type,ll mod){
ll invG=inv(G,mod);
for(int i=0;i<n;i++) if(i<rev[i]) swap(x[i],x[rev[i]]);
for(int len=1;len<n;len*=2){
int sz=len*2;
ll gn1=fast_pow((type==1?G:invG),(mod-1)/sz,mod);
for(int l=0;l<n;l+=sz){
int r=l+len-1;
ll gnk=1;
for(int i=l;i<=r;i++){
ll tmp=x[i+len];
x[i+len]=(x[i]-gnk*tmp%mod+mod)%mod;
x[i]=(x[i]+gnk*tmp%mod)%mod;
gnk=gnk*gn1%mod;
}
}
}
if(type==-1){
ll invn=inv(n,mod);
for(int i=0;i<n;i++) x[i]=x[i]*invn%mod;
}
}
void fmul(ll *a,ll *b,ll *ans,int n,ll mod){
static ll ta[maxn+5],tb[maxn+5];
for(int i=0;i<n;i++) ta[i]=a[i];
for(int i=0;i<n;i++) tb[i]=b[i];
NTT(ta,n,1,mod);
if(a!=b) NTT(tb,n,1,mod);
for(int i=0;i<n;i++) ans[i]=ta[i]*tb[i]%mod;
NTT(ans,n,-1,mod);
}
ll a[maxn+5],b[maxn+5],c[3][maxn+5];
int main(){
scanf("%d %d %lld",&n,&m,&C);
for(int i=0;i<=n;i++){
scanf("%lld",&a[i]);
a[i]%=C;
}
for(int i=0;i<=m;i++){
scanf("%lld",&b[i]);
b[i]%=C;
}
N=1,L=0;
while(N<n+m+1){
N*=2;
L++;
}
for(int i=0;i<N;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
fmul(a,b,c[0],N,m1);
fmul(a,b,c[1],N,m2);
fmul(a,b,c[2],N,m3);
for(int i=0;i<n+m+1;i++){
printf("%lld ",CRT(c[0][i],c[1][i],c[2][i]));
}
}
容易发现,三模数NTT需要9次FFT,不是很优秀
拆系数FFT
我们之前讨论的优化都是针对FFT的,那不妨尝试用FFT解决任意模数NTT
最简单的想法是不取模,FFT完再取模。但是上文提到数值过大,long double
会丢失精度。
int128
是一个方法,但在OI比赛中不一定能使用。所以需要拆系数。
设\(M_0=[\sqrt{M}]\)
相当于把模数换成\(M_0\),降低大小。
代入对应的多项式
这不就是我们提到的\((A+B)(C+D)\)形的卷积吗?
由于\(k,b\)都不超过\(2^{15}\),于是就不容易被卡精度了。实际操作中我们不必取\(M_0=\sqrt{M}\),直接取\(M_0=2^{15}\)即可。这样取模运算可以换成位运算,进一步减小常数。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#define maxn 1000000
const double pi=acos(-1.0);
using namespace std;
typedef long long ll;
struct com{
double real;
double imag;
com(){
}
com(double _real,double _imag){
real=_real;
imag=_imag;
}
com(double x){
real=x;
imag=0;
}
void operator = (const com x){
this->real=x.real;
this->imag=x.imag;
}
void operator = (const double x){
this->real=x;
this->imag=0;
}
friend com operator + (com p,com q){
return com(p.real+q.real,p.imag+q.imag);
}
friend com operator + (com p,double q){
return com(p.real+q,p.imag);
}
void operator += (com q){
*this=*this+q;
}
void operator += (double q){
*this=*this+q;
}
friend com operator - (com p,com q){
return com(p.real-q.real,p.imag-q.imag);
}
friend com operator - (com p,double q){
return com(p.real-q,p.imag);
}
void operator -= (com q){
*this=*this-q;
}
void operator -= (double q){
*this=*this-q;
}
friend com operator * (com p,com q){
return com(p.real*q.real-p.imag*q.imag,p.real*q.imag+p.imag*q.real);
}
friend com operator * (com p,double q){
return com(p.real*q,p.imag*q);
}
void operator *= (com q){
*this=(*this)*q;
}
void operator *= (double q){
*this=(*this)*q;
}
friend com operator / (com p,double q){
return com(p.real/q,p.imag/q);
}
void operator /= (double q){
*this=(*this)/q;
}
com conj(){
return com(real,-imag);
}
void print(){
printf("(%lf,%lf)\n",real,imag);
}
};
int rev[maxn+5];
com w[maxn+5];
void fft(com *x,int n){
for(int i=0;i<n;i++) if(i<rev[i]) swap(x[i],x[rev[i]]);
for(int len=1;len<n;len*=2){
int sz=len*2;
for(int l=0;l<n;l+=sz){
int r=l+len-1;
for(int i=l;i<=r;i++){
com tmp=x[i+len];
x[i+len]=x[i]-tmp*w[n/sz*(i-l)];
x[i]=x[i]+tmp*w[n/sz*(i-l)];
}
}
}
}
ll mod;
void mul(ll *ina,ll *inb,ll *inc,int n){
static ll a[maxn+5],b[maxn+5],c[maxn+5],d[maxn+5];
static com p[maxn+5],q[maxn+5];
static com r[maxn+5],s[maxn+5];
for(int i=0;i<n;i++){
ina[i]=(ina[i]+mod)%mod;
inb[i]=(inb[i]+mod)%mod;
a[i]=ina[i]>>15;
b[i]=ina[i]&((1<<15)-1);
c[i]=inb[i]>>15;
d[i]=inb[i]&((1<<15)-1);
}
for(int i=0;i<n;i++) w[i]=com(cos(2*pi*i/n),sin(2*pi*i/n));
for(int i=0;i<n;i++){
p[i]=com(a[i],b[i]);//打包A,B
q[i]=com(c[i],d[i]);//打包C,D
}
fft(p,n);
fft(q,n);
for(int i=0;i<n;i++){
// p[i].print();
int j=(i==0?0:n-i);
//得到DFT(A),DFT(B),DFT(C),DFT(D)
com da=(p[i]+p[j].conj())*0.5;
com db=(p[i]-p[j].conj())*com(0,-0.5);
com dc=(q[i]+q[j].conj())*0.5;
com dd=(q[i]-q[j].conj())*com(0,-0.5);
r[j]=da*dc+da*dd*com(0,1);//打包AC,AD
s[j]=db*dc+db*dd*com(0,1); //打包BC,BD
}
fft(r,n);
fft(s,n);
for(int i=0;i<n;i++){
ll ac,ad,bc,bd;
ac=(ll)(r[i].real/n+0.5)%mod;
ad=(ll)(r[i].imag/n+0.5)%mod;
bc=(ll)(s[i].real/n+0.5)%mod;
bd=(ll)(s[i].imag/n+0.5)%mod;
inc[i]=((ac<<30)+((ad+bc)<<15)+bd)%mod;
}
}
int n,m;
ll a[maxn+5],b[maxn+5],c[maxn+5];
int main(){
scanf("%d %d %lld",&n,&m,&mod);
for(int i=0;i<=n;i++) scanf("%lld",&a[i]);
for(int i=0;i<=m;i++) scanf("%lld",&b[i]);
int N=1,L=0;
while(N<=n+m+1){
L++;
N*=2;
}
for(int i=0;i<N;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
mul(a,b,c,N);
for(int i=0;i<n+m+1;i++) printf("%lld ",c[i]);
}