多项式乘法入门 FFT/NTT
首先要知道多项式是什么东西。多项式是形如 \(A(x)=a_0+a_1x+a_2x^2+a_3x^3+...+a_{N}x^N\) (即 \(A(x)=\sum_{i=0}^nx_i\))的整式。
两个多项式相加:\(C(x)=A(x)+B(x)=\sum\limits_{i=1}^N(a_i+b_i)x^i\)
两个多项式相减:\(C(x)=A(x)-B(x)=\sum\limits_{i=1}^N(a_i-b_i)x^i\)
两个次数分别为 \(n,m\) 的多项式相乘:\(C(x)=A(x)\times B(x)=\sum\limits_{i=0}^N\sum\limits_{j=0}^Ma_ib_jx^{i+j}\)。那么多项式 \(C(x)\) 的系数 \(c_i=\sum\limits_{j=0}^i a_jb_{i-j}\) 。
卷积:对于数列 \(A,B,C\),形如 \(c_i=\sum\limits_{j\oplus k=i}a_jb_k\) 的方式称为卷积(\(\oplus\) 为某一运算符)。如多项式乘法属于加卷积。
\(\text{FFT}\)
点值表示法
多项式 \(A(x)=\sum\limits_{i=1}^na_ix^i\) 的点值表示:带入任意一个 \(x\),即可得到对应的值。
如 \(A(x)=3x^2+5x+1\),带入 \(x=2\),可得 \(A(2)=3\times2^2+5\times 2+1=23\)。
我们计算两个次数分别为 \(n,m\) 多项式相乘 \(A(x)B(x)\),可以先挑 \((n+m+1)\) 个 \(x\) 计算点值,把对应位置相乘。如 \(A(x)=5x+3,B(x)=x+6\),选择 \(x=1,2,3\) 计算,\(A(1)=8,A(2)=13,A(3)=18,B(1)=7,B(2)=8,B(3)=9\),对应位置相乘可得 \(8\times 7=56,13\times 8=104,18\times 9=162\)。
然后把这 \((n+m+1)\) 插值计算 \((n+m)\) 次多项式。按上面的例子,\(x=1,2,3\) 时值为 \(56,104,162\),可插值得到多项式 \(C(x)=5x^2+33x+18\)。
接下来讲如何快速利用特殊的 \(x\) 来优化计算点值的过程。
复平面与单位根
在复平面上,横轴表示实部,纵轴表示虚部,以原点为起点的向量即可表示一个复数。
单位根:定义 \(\omega_n\) 为一个复数,使得 \((\omega_n)^n=1\) 且 \(\omega_n\neq1\),比如 \(\omega_2=-1,\omega_4=-i\)。在复平面上,以原点为圆心作一个半径为 \(1\) 的圆,然后以数字 \(1\) 对应的向量为起点,逆时针旋转,把圆分成 \(n\) 等分,每等分划分的向量即为 \(1,\omega_n,\omega_n^2,\omega_n^3,...,\omega_n^{n-1}\)(\(\omega_n^k=(\omega_n)^k\))。
计算:\(\omega_{n}^k=\cos k\frac{2\pi} n+i\sin k\frac{2\pi} n\)
几个性质:
-
\(\omega_n^{n+k}=\omega_n^k\)
-
\(\omega_n^{\frac n2}=-1\)
-
\(\omega_n^{-k}=\omega_n^{n-k}\)
(自己画图就能明白)
快速傅里叶变换
我们希望带入 \(x=1,\omega_n,\omega_n^2,\omega_n^3,...,\omega_n^{n-1}\):。
对于多项式 \(A(x)\),我们把每一项系数按次数分成奇偶两类:
显然,\(A(x)=A_0(x^2)+xA_1(x^2)\)。
带入 \(\omega_n^k(k<\frac n2)\) 得:
带入 \(\omega_n^{k+\frac n2}(k<\frac n2)\) 得:
然后我们发现带入两个值后,多项式 \(A_0\) 和 \(A_1\) 里面的值都是 \(\omega_n^{2k}\),而只是右边的项的符号不同!!!!!!
注意 \(\omega_n^{2k}=\omega_{\frac n2}^k\),于是只需要处理 \(A_0,A_1\) 在 \(x=1,\omega_{\frac n2},\omega_{\frac n2}^2,...,\omega_{\frac n2}^{\frac n2-1}\) 的点值即可!
这样可以直接分治,时间复杂度 \(O(n\log n)\)。当然,我们需要预处理出一个 \(>N+M\)(\(N,M\) 为两个多项式次数)的 \(2\) 的幂 的数字 \(n\)。
快速傅里叶逆变换
不要问我为什么,我们只需要把得到的点值对应位置相乘后,重新 \(\text{FFT}\) 一遍(过程一样),然后把 \(1\) 次项到 \(n-1\) 次项全部翻转一下即可(注意不是 \(0\) 次)。
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const ll maxn=4e6+10;
ll n,m;
const double pi=acos(-1.0);
struct Complex
{
double x,y;
Complex (double xx=0,double yy=0) {x=xx; y=yy;}
const Complex operator+(const Complex tmp) const
{
return (Complex){x+tmp.x,y+tmp.y};
}
const Complex operator-(const Complex tmp) const
{
return (Complex){x-tmp.x,y-tmp.y};
}
const Complex operator*(const Complex tmp) const
{
return (Complex){x*tmp.x-y*tmp.y,x*tmp.y+y*tmp.x};
}
}a[maxn],b[maxn],c[maxn];
void fft(ll n,Complex *a)
{
if(n==1) return;
Complex a1[n>>1], a2[n>>1];
ll m=n>>1;
for(ll i=0;i<m;i++) a1[i]=a[2*i];
for(ll i=0;i<m;i++) a2[i]=a[2*i+1];
fft(m,a1);
fft(m,a2);
Complex W(cos(1.0*pi/m),sin(1.0*pi/m)), w(1,0);
for(ll i=0;i<m;i++,w=w*W)
{
a[i]=a1[i]+w*a2[i];
a[i+m]=a1[i]-w*a2[i];
}
}
int main(){
scanf("%lld%lld",&n,&m);
for(ll i=0;i<=n;i++)
{
scanf("%lf",&a[i].x);
}
for(ll i=0;i<=m;i++)
{
scanf("%lf",&b[i].x);
}
ll s=n+m;
ll p=1;
while(p<=s) p<<=1;
fft(p,a);
fft(p,b);
for(ll i=0;i<p;i++) c[i]=a[i]*b[i];
fft(p,c);
reverse(c+1,c+p);
for(ll i=0;i<=s;i++) printf("%lld ",(ll)(c[i].x/p+0.5));
return 0;
}
递归转递推(蝴蝶变换)
函数内开数组、递归…… 还有许多地方可以优化。
思考如何转递推。
我们尝试逐步模拟分类过程:
次数 | \(0\) | \(1\) | \(2\) | \(3\) | \(4\) | \(5\) | \(6\) | \(7\) |
---|---|---|---|---|---|---|---|---|
次数 | \(0\) | \(2\) | \(4\) | \(6\) | \(1\) | \(3\) | \(5\) | \(7\) |
次数 | \(0\) | \(4\) | \(2\) | \(6\) | \(1\) | \(5\) | \(3\) | \(7\) |
最后一行二进制分别为 \(000,100,010,110,001,101,011,111\)
反过来为 \(000,001,010,011,100,101,110,111\)
发现反过来是顺序的!
预处理 \(i\) 二进制反过来为 \(r_i\),按 \(r_i\) 排序。
然后直接递推即可。
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const ll maxn=4e6+10;
ll n,m,r[maxn];
const double pi=acos(-1);
struct Complex
{
double x,y;
Complex(double xx=0,double yy=0)
{
x=xx; y=yy;
}
const Complex operator+(const Complex tmp) const
{
return (Complex){x+tmp.x,y+tmp.y};
}
const Complex operator-(const Complex tmp) const
{
return (Complex){x-tmp.x,y-tmp.y};
}
const Complex operator*(const Complex tmp) const
{
return (Complex){x*tmp.x-y*tmp.y,x*tmp.y+y*tmp.x};
}
}a[maxn],b[maxn],c[maxn],tmp[maxn];
void fft(ll n,Complex *a)
{
for(ll i=0;i<(1<<n);i++)
if(i<r[i]) swap(a[i],a[r[i]]);
for(ll i=0;i<n;i++)
{
Complex W(cos(pi/(1<<i)),sin(pi/(1<<i)));
tmp[0]=(Complex){1,0};
for(ll j=1;j<(1<<i);j++) tmp[j]=tmp[j-1]*W;
for(ll j=0;j<(1<<n);j++)
if(!(j&(1<<i)))
{
Complex a1=a[j], a2=a[j+(1<<i)], t=a2*tmp[j&((1<<i)-1)];
a[j]=a1+t;
a[j+(1<<i)]=a1-t;
}
}
}
int main()
{
scanf("%lld%lld",&n,&m);
for(ll i=0;i<=n;i++)
{
scanf("%lf",&a[i].x);
}
for(ll i=0;i<=m;i++)
{
scanf("%lf",&b[i].x);
}
ll k=n+m, p=1;
while((1<<p)<=k) ++p;
for(ll i=1;i<(1<<p);i++)
r[i]=(r[i>>1]>>1)|((i&1)<<p-1);
fft(p,a);
fft(p,b);
for(ll i=0;i<(1<<p);i++) c[i]=a[i]*b[i];
fft(p,c);
reverse(c+1,c+(1<<p));
for(ll i=0;i<=k;i++) printf("%lld ",(ll)(c[i].x/(1<<p)+0.5));
return 0;
}
\(\text{NTT}\)
\(\text{FFT}\) 能够在 \(O(n\log n)\) 的时间内完成对两个多项式相乘,但局限性很大。首先大量的单位根会影响精度,其次不能取模。我们需要一种更高效的方法。
原根
根据欧拉定理,如果 \(a,p\) 互质,那么 \(a^{\varphi(p)}\equiv 1\pmod p\)。
原根:对于互质的 \(a,p\),如果不存在 \(m\),满足 \(m<\varphi(p)\),且 \(a^m\equiv 1\pmod p\),那么称 \(a\) 为 \(p\) 的一个原根。
我们把 \(p\) 写成 \(2^xb+1\) 的形式,若 \(n\) 是一个以 \(2\) 为底的幂(\(n\le2^x\)),记 \(g_n=a^{\frac{p-1}n}\)(\(a\) 是原根),那么满足以下性质:
A: \(g_n^n=(a^{\frac{p-1}n})^n=a^{p-1}\equiv 1\pmod p\)
B: \(g_n^{\frac n 2}=a^{\frac{p-1}2}\equiv -1\pmod p\)
我们容易得到类似于单位根的几个性质:
-
\(g_n^{n+k}=g_n^n\times g_n^k\equiv1\times g_n^k\equiv g_n^k\pmod p\)
-
\(g_n^{\frac n 2}\equiv -1\pmod p\)
-
\(g_n^{-k}\equiv 1\times g_n^{-k}\equiv g_n^n\times g_n^{-k}\equiv g_n^{n-k}\pmod p\)
这和单位根不一模一样吗!!!只不过在模意义下而已。
模数限制
通常情况下,模数为 \(998244353=119\times 2^{23}+1\),此时原根可取 \(3\)。\(10^9+7\) 基本做不了,因为 \(10^9+7=500000003\times 2^1+1\),而 \(1\) 太小。
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const ll maxn=4e6+10, mod=998244353, g=3;
ll n,m,a[maxn],b[maxn],c[maxn],r[maxn];
ll power(ll a,ll b)
{
ll ans=1;
while(b)
{
if(b&1) ans=ans*a%mod;
a=a*a%mod;
b>>=1;
}
return ans;
}
void ntt(ll n,ll *a)
{
for(ll i=0;i<(1<<n);i++)
if(i<r[i]) swap(a[i],a[r[i]]);
for(ll i=0;i<n;i++)
{
ll G=power(g,(mod-1)/(1<<i+1));
for(ll j=0;j<(1<<n);j+=(1<<i+1))
{
for(ll k=0,g=1;k<(1<<i);k++,g=g*G%mod)
{
ll a1=a[j+k], a2=a[(1<<i)+j+k], t=a2*g%mod;
a[j+k]=(a1+t)%mod;
a[(1<<i)+j+k]=(a1-t+mod)%mod;
}
}
}
}
int main()
{
scanf("%lld%lld",&n,&m);
for(ll i=0;i<=n;i++) scanf("%lld",a+i);
for(ll i=0;i<=m;i++) scanf("%lld",b+i);
ll s=n+m, p=0;
while((1<<p)<=s) ++p;
for(ll i=1;i<(1<<p);i++)
r[i]=(r[i>>1]>>1)|((i&1)<<p-1);
ntt(p,a);
ntt(p,b);
for(ll i=0;i<(1<<p);i++) c[i]=a[i]*b[i]%mod;
ntt(p,c);
reverse(c+1,c+(1<<p));
ll inv=power(1<<p,mod-2);
for(ll i=0;i<=s;i++) printf("%lld ",c[i]*inv%mod);
return 0;
}
分治 \(\text{FFT}\)
问题形式
已知 \(f_0=1\),而 \(f_i=\sum\limits_{j=0}^{i-1}f_j\times g_{i-j}\),\(g\) 是给定的,求 \(f_{0...n-1}\)。
分治求解
考虑 cdq 分治,把 \([0,n-1]\) 分成 \([0,mid],[mid+1,n-1]\) 两部分。计算左边对右边的贡献,我们其实可以直接把 \([0,mid]\) 的多项式和 \(g\) 相乘,加到右边即可。
时间复杂度 \(O(n\log^2n)\)。
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const ll maxn=4e5+10, mod=998244353, g=3;
ll n,a[maxn],f[maxn],rev[maxn],w1[maxn],w2[maxn],w[maxn];
ll power(ll a,ll b)
{
ll ans=1;
while(b)
{
if(b&1) ans=ans*a%mod;
a=a*a%mod;
b>>=1;
}
return ans;
}
void ntt(ll n,ll *a)
{
for(ll i=0;i<(1<<n);i++)
{
rev[i]=(rev[i>>1]>>1)|((i&1)<<n-1);
if(i<rev[i]) swap(a[i],a[rev[i]]);
}
for(ll i=0;i<n;i++)
{
ll G=power(g,(mod-1)/(1<<i+1));
for(ll j=0;j<(1<<n);j+=(1<<i+1))
{
for(ll k=0,g0=1;k<(1<<i);k++,g0=g0*G%mod)
{
ll a1=a[j+k], a2=a[(1<<i)+j+k], t=a2*g0%mod;
a[j+k]=(a1+t)%mod;
a[(1<<i)+j+k]=(a1-t+mod)%mod;
}
}
}
}
void cdq(ll l,ll r)
{
if(l==r) return;
ll mid=l+r>>1;
cdq(l,mid);
for(ll i=0;i<=r-l;i++) w1[i]=a[i];
for(ll i=0;i<=mid-l;i++) w2[i]=f[l+i];
ll s=r-l+mid-l, p=0;
while((1<<p)<=s) ++p;
ntt(p,w1);
ntt(p,w2);
for(ll i=0;i<(1<<p);i++) w[i]=w1[i]*w2[i]%mod;
ntt(p,w);
reverse(w+1,w+(1<<p));
ll inv=power(1<<p,mod-2);
for(ll i=mid+1;i<=r;i++)
{
f[i]=(f[i]+w[i-1-l]*inv)%mod;
}
for(ll i=0;i<(1<<p);i++) w1[i]=w2[i]=w[i]=0;
cdq(mid+1,r);
}
int main()
{
scanf("%lld",&n);
for(ll i=0;i<n-1;i++) scanf("%lld",a+i);
f[0]=1;
cdq(0,n-1);
for(ll i=0;i<n;i++) printf("%lld ",f[i]);
return 0;
}