浅析卷积
0. 写在前面
从粗斜体到分界线中的内容可以跳过。
本文中一些算法的简历:
简写 | 全称 | 中文 | |
---|---|---|---|
FWT | Fast Walsh Transformation | 快速沃尔什变换 | Fast Wonderful TLE |
FFT | Fast Fourier Transformation | 快速傅里叶变换 | Fast Fantastic TLE |
NTT | Number Theory Transformation | 快速数论变换 | Natural Talented TLE |
1. 卷积的定义
\(c\) 称为 \(a,b\) 的卷积,当 \(*\) 指不同的运算符时, \(c\) 有不同的求法,现在分类讨论。
2. 当 \(*\) 指 \(max/min\)
可以通过简单的前/后缀和计算,以下是 \(max\) 卷积的代码:
#include<stdio.h>
int N,a[1000005],b[1000005],c[1000005],A[1000005],B[1000005],C[1000005];
int main(){
scanf("%d",&N);
for(int i=1;i<=N;i++)
scanf("%d",a+i);
for(int i=1;i<=N;i++)
scanf("%d",b+i);
for(int i=1;i<=N;i++){
A[i]=a[i]+A[i-1];
B[i]=b[i]+A[i-1];
}
for(int i=1;i<=N;i++)
C[i]=A[i]*B[i];
for(int i=N;i>=1;i--)
c[i]=C[i]-C[i-1];
for(int i=1;i<=N-1;i++)
printf("%d ",c[i]);
printf("%d\n",c[N]);
return 0;
}
\(min\)卷积与之类似。
为什么它是正确的?
讨论 \(max\) 卷积
即
程序中
所以
易得
与理论答案相符
由此我们可以总结出一点经验,求卷积的流程往往是这样:
- 用某种变换将\(a_i,b_i\)变成\(A_i,B_i\)
- \(C_i=A_i\times B_i\)
- 用其逆变换将\(C_i\)变成\(c_i\)得到答案
3. 当 \(*\) 指 \(\vee/\wedge\)
在这篇文章中,如果涉及到带 \(lg_N\) 复杂度的卷积变换, \(N=2^k(k\in\mathbb{N})\) 。实际实现时在高位补 \(0\)
3.1 用向量表示数
一个\(k\)位二进制数可以表示成一个\(k\)维向量。例如当\(k=3\)时:
数值 | 向量表示 | 数值 | 向量表示 |
---|---|---|---|
0 | \(\left \{ 0,0,0 \right \}\) | 4 | \(\left \{ 1,0,0 \right \}\) |
1 | \(\left \{ 0,0,1 \right \}\) | 5 | \(\left \{ 1,0,1 \right \}\) |
2 | \(\left \{ 0,1,0 \right \}\) | 6 | \(\left \{ 1,1,0 \right \}\) |
3 | \(\left \{ 0,1,1 \right \}\) | 7 | \(\left \{ 1,1,1 \right \}\) |
3.2 \(\vee\)的实质
我们把用向量表示的数字\(\vee\),例如
由此可知\(\vee\)的本质是按位 \(max\) ,所以 \(\vee\) 卷积的变换就是按位前缀和。它的逆变换其实是一个脑筋急转弯,只要把循环倒过来,把\(+=\)改成\(-=\)就可以了。代码如下:
#include<stdio.h>
int n,k,N,a[100005],b[100005],c[100005];
int main(){
scanf("%d",&n);
for(int i=0;i<n;i++)
scanf("%d",a+i);
for(int i=0;i<n;i++)
scanf("%d",b+i);
for(N=1;N<n;N<<=1,k++);
for(int i=0;i<k;i++)
for(int j=0;j<N;j++)
if((j&(1<<i))==0)
a[j+(1<<i)]+=a[j];
for(int i=0;i<k;i++)
for(int j=0;j<N;j++)
if((j&(1<<i))==0)
b[j+(1<<i)]+=b[j];
for(int i=0;i<N;i++)
c[i]=a[i]*b[i];
for(int i=k-1;i>=0;i--)
for(int j=N-1;j>=0;j--)
if((j&(1<<i))==0)
c[j+(1<<i)]-=c[j];
for(int i=0;i<=N-2;i++)
printf("%d ",c[i]);
printf("%d\n",c[N-1]);
return 0;
}
\(\wedge\)卷积与之类似。
4. 当\(*\)指\(\bigoplus\)
4.1 千里之行,始于足下:N=2
有小学知识可知:
我们将\(a_i,b_i\)带入
听说是正确的,于是拓展到高维
4.2 一句不是废话的废话
虽然我们已经发现了正解,我们还是从向量的角度看一下。比如
再写一遍
发现\(\bigoplus\)其实是二进制无进位加法。
4.3 具体实现
\(\bigoplus\)卷积的变换代码与\(\vee/\wedge\)卷积的代码略有不同,但\(\vee/\wedge\)卷积的代码也可以写成这种形式,具体题意见洛谷P4717【模板】快速沃尔什变换
#include<stdio.h>
const int p=998244353;
inline int power(int a,int k){
int ans=1;
for(;k;a=1LL*a*a%p,k>>=1)
if(k&1)
ans=1LL*ans*a%p;
return ans;
}
int len,N,a[4][300005],b[4][300005];
void FWT(int* a,int op,int flag){
int x,y;
if(op==3&&flag==-1){
for(int i=N>>1;i>0;i>>=1)
for(int j=0;j<N;j+=i<<1)
for(int k=0;k<i;k++){
x=a[j+k];
y=a[i+j+k];
a[j+k]=1LL*(x+y)*power(2,p-2)%p;
a[i+j+k]=(1LL*(x-y)*power(2,p-2)%p+p)%p;
}
return;
}
for(int i=1;i<N;i<<=1)
for(int j=0;j<N;j+=i<<1)
for(int k=0;k<i;k++){
x=a[j+k];
y=a[i+j+k];
if(op==1){
if(flag==1){
a[j+k]=x;
a[i+j+k]=(x+y)%p;
}
else{
a[j+k]=x;
a[i+j+k]=((y-x)%p+p)%p;
}
}
if(op==2){
if(flag==1){
a[j+k]=(x+y)%p;
a[i+j+k]=y;
}
else{
a[j+k]=((x-y)%p+p)%p;
a[i+j+k]=y;
}
}
if(op==3){
a[j+k]=(x+y)%p;
a[i+j+k]=((x-y)%p+p)%p;
}
}
}
int main(){
scanf("%d",&len);
N=power(2,len);
for(int i=0;i<N;i++){
scanf("%d",a[1]+i);
a[3][i]=a[2][i]=a[1][i];
}
for(int i=0;i<N;i++){
scanf("%d",b[1]+i);
b[3][i]=b[2][i]=b[1][i];
}
for(int j=1;j<=3;j++){
FWT(a[j],j,1);
FWT(b[j],j,1);
}
for(int i=0;i<N;i++)
for(int j=1;j<=3;j++)
a[j][i]=1LL*a[j][i]*b[j][i]%p;
for(int j=1;j<=3;j++)
FWT(a[j],j,-1);
for(int j=1;j<=3;j++){
for(int i=0;i<N-1;i++)
printf("%d ",a[j][i]);
printf("%d\n",a[j][N-1]);
}
return 0;
}
5. 当\(*\)指\(+\)
公式恐惧症患者请果断按下Ctrl+W
以发起正当防卫
5.1 求多项式乘法的新方法
定义\(N\)次多项式\(g,h\),我们可以选取\(x_{0...2\times N}\)带入\(g,h\)得到\(G_{0...2\times N},H_{0...2\times N}\),将G和H逐位相乘得到\(F\),最后将\(F\)消元得到\(f\),\(f=g\times h\)。
不幸的是,带入多项式需要 \(O(n^2)\) 。What's worse,高斯消元需要 \(O(n^3)\) 。
5.2 选择带入的数
显然,瓶颈在带入的数的选择上。那我们需要带入怎样的数带入呢?
5.2.1 复数
我们知道, \(x^2=-1\) 无实数解,但我们定义 \(i^2=-1\) 。
复数是所有能够写成 \(a+i\times b(a,b\in \mathbb{R})\) 的数的集合,该集合记作 \(\mathbb{C}\)。
5.2.2 复数的性质
- 对于任意整数 \(n\geqslant0,k\geqslant0,d\geqslant0\) 有
- 对于任意整数 \(n\geqslant0,k\geqslant0\) 有
- 对于任意整数 \(n\geqslant0\) 有
- 对于任意整数 \(n\geqslant0,i\geqslant0,j\geqslant0\) 有
- 对于任意整数 \(n\geqslant0,i\geqslant0,j\geqslant0\) 有
5.3 开始带入!
我们定义将\(<\omega_N^0,\omega_N^1, ... ,\omega_N^{N-1}>\)带入\(<a_0,a_1, ... ,a_{N-1}>\)的结果设为\(<A_0,A_1, ... ,A_{N-1}>\)。
根据定义得
由性质5得
我们把\(i\)的定义域从 \([0,N-1]\) 变成 \([0,\frac{N}{2}-1]\)
奇偶分类
于是我们惊喜地发现,这可以分治做。
为什么需要奇偶分类呢?也许这就是傅里叶的伟大之处吧。
5.4 逆变换
那么FFT的逆变换怎么写呢?
令人惊讶的是,恰有
为什么它是正确的?
5.5 蝴蝶变换
我们将\(N=8\)的分治情况手动模拟一下,可以得到:
区间大小 | \(id_0\) | \(id_1\) | \(id_2\) | \(id_3\) | \(id_4\) | \(id_5\) | \(id_6\) | \(id_7\) |
---|---|---|---|---|---|---|---|---|
8 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
4 | 0 | 2 | 4 | 6 | 1 | 3 | 5 | 7 |
2 | 0 | 4 | 2 | 6 | 1 | 5 | 3 | 7 |
1 | 0 | 4 | 2 | 6 | 1 | 5 | 3 | 7 |
把这个表格用二进制描述
区间大小 | \(id_0\) | \(id_1\) | \(id_2\) | \(id_3\) | \(id_4\) | \(id_5\) | \(id_6\) | \(id_7\) |
---|---|---|---|---|---|---|---|---|
8 | 000 | 001 | 010 | 011 | 100 | 101 | 110 | 111 |
4 | 000 | 010 | 100 | 110 | 001 | 011 | 101 | 111 |
2 | 000 | 100 | 010 | 110 | 001 | 101 | 011 | 111 |
1 | 000 | 100 | 010 | 110 | 001 | 101 | 011 | 111 |
我们发现表格的第一行和最后一行二进制是反转的,这样我们就发现了\(FFT\)的非递归写法,代码如下:
#include<math.h>
#include<stdio.h>
#include<algorithm>
using namespace std;
const double pi=acos(-1.0);
int n,m,res=0,N=1,len,revers[2097160];
long long ans[2097160];
int i,j,k,l;
struct node{
double x,y;
node(double x=0,double y=0):x(x),y(y){}
node operator*(const node &b){
return node(x*b.x-y*b.y,x*b.y+y*b.x);
}
node operator+(const node &b){
return node(x+b.x,y+b.y);
}
node operator-(const node &b){
return node(x-b.x,y-b.y);
}
}a[2097160],b[2097160],T,t,x,y;
void FFT(node *a,double flag){
for(i=0;i<N;i++)
if(i<revers[i])
swap(a[i],a[revers[i]]);
for(j=1;j<N;j<<=1){
T=node(cos(pi/j),flag*sin(pi/j));
for(k=0;k<N;k+=(j<<1)){
t=node(1,0);
for(l=0;l<j;l++,t=t*T){
x=a[k+l],y=t*a[k+j+l];
a[k+l]=x+y;
a[k+j+l]=x-y;
}
}
}
}
int main(){
scanf("%d%d",&n,&m);
n++;
m++;
for(i=0;i<n;i++)
scanf("%lf",&a[i].x);
for(i=0;i<m;i++)
scanf("%lf",&b[i].x);
for(;N<max(n,m)<<1;N<<=1,len++);
for(i=0;i<=N;i++)
revers[i]=(revers[i>>1]>>1)|((i&1)<<(len-1));
FFT(a,1);
FFT(b,1);
for(i=0;i<=N;i++)
a[i]=a[i]*b[i];
FFT(a,-1);
for(i=0;i<=N;i++)
ans[i]+=(long long)(a[i].x/N+0.5);
for(;!ans[N]&&N;N--);
N++;
for(i=0;i<n+m-2;i++)
printf("%lld ",ans[i]);
printf("%lld\n",ans[n+m-2]);
return 0;
}
5.6 精度问题
把上面的代码加入模操作后提交到P4245里去发现光荣\(\color{red}\text{WA}\)。
然后就发现FFT有精度问题,那么如何避免呢?
5.6.1 原根
如果\(g\)的\(0\) ~ \(\phi(p)-1\)在模\(p\)意义下正好遍历了\(1\) ~ \(p-1\)中与\(p\)互质的\(\phi(p)\)个数,那么称\(g\)为\(p\)的原根。
当p为质数时,我们发现如果用\(g^\frac{p-1}{N}\)代替单位复根(记为\(g_N\)),它满足单位复根的所有性质:
- 对于任意整数 \(n\geqslant0,k\geqslant0,d\geqslant0\) 有
- 对于任意整数 \(n\geqslant0,k\geqslant0\) 有
- 对于任意整数 \(n\geqslant0\) 有
- 对于任意整数 \(n\geqslant0,i\geqslant0,j\geqslant0\) 有
- 对于任意整数 \(n\geqslant0,i\geqslant0,j\geqslant0\) 有
5.6.2 能选的质数
一般情况下有三个质数可选:
- \(469762049=7\times 2^{26}+1\)
- \(998244353=119\times 2^{23}+1\)
- \(1004535809=749\times 2^{21}+1\)
当\(p\)取上面几个质数时,\(g=3\),\(p-1\)中有很多\(2\)的因子,FFT中\(N\)又都是\(2\)的次幂,所以上面三个质数一定要记下来。
代码如下:
#include<stdio.h>
#include<algorithm>
using namespace std;
const long long p=998244353,g=3,invg=332748118;
int n,m,res=0,N=1,len,revers[2097160];
long long ans[2097160],a[2097160],b[2097160],T,t,x,y;
int i,j,k,l;
inline long long power(long long a,long long k,long long p){
long long ans=1,t=a;
for(;k;k>>=1,t=t*t%p)
if(k&1)
ans=ans*t%p;
return ans;
}
void NTT(long long *a,long long flag){
for(i=0;i<N;i++)
if(i<revers[i])
swap(a[i],a[revers[i]]);
for(j=1;j<N;j<<=1){
T=power(flag==1?g:invg,(p-1)/j/2,p);
for(k=0;k<N;k+=(j<<1)){
t=1;
for(l=0;l<j;l++,t=t*T%p){
x=a[k+l],y=t*a[k+j+l]%p;
a[k+l]=(x+y)%p;
a[k+j+l]=((x-y)%p+p)%p;
}
}
}
}
int main(){
scanf("%d%d",&n,&m);
n++;
m++;
for(i=0;i<n;i++){
scanf("%lld",a+i);
a[i]=a[i]%p;
}
for(i=0;i<m;i++){
scanf("%lld",b+i);
b[i]=b[i]%p;
}
for(;N<max(n,m)<<1;N<<=1,len++);
for(i=0;i<=N;i++)
revers[i]=(revers[i>>1]>>1)|((i&1)<<(len-1));
NTT(a,1);
NTT(b,1);
for(i=0;i<=N;i++)
a[i]=a[i]*b[i]%p;
NTT(a,-1);
for(i=0;i<=N;i++)
ans[i]=a[i]*power(N,p-2,p)%p;
for(i=0;i<n+m-2;i++)
printf("%lld ",ans[i]);
printf("%lld\n",ans[n+m-2]);
return 0;
}
5.7 换个角度看\(\bigoplus\)卷积
我们再回忆一下4.2节
的内容,我们在做\(\bigoplus\)卷积时,其实可以做\(lg_N\)遍\(FFT\),然后又因为\(\omega_2^0=1\),\(\omega_2^1=-1\),就可以得到4.1节
的结论了。