浅析卷积

0. 写在前面

从粗斜体到分界线中的内容可以跳过。

本文中一些算法的简历:

简写 全称 中文 搞笑
FWT Fast Walsh Transformation 快速沃尔什变换 Fast Wonderful TLE
FFT Fast Fourier Transformation 快速傅里叶变换 Fast Fantastic TLE
NTT Number Theory Transformation 快速数论变换 Natural Talented TLE

1. 卷积的定义

\[c_i=\sum_{j*k=i}{a_j\times b_k} \]

\(c\) 称为 \(a,b\) 的卷积,当 \(*\) 指不同的运算符时, \(c\) 有不同的求法,现在分类讨论。

2. 当 \(*\)\(max/min\)

可以通过简单的前/后缀和计算,以下是 \(max\) 卷积的代码:

#include<stdio.h>
int N,a[1000005],b[1000005],c[1000005],A[1000005],B[1000005],C[1000005];
int main(){
  scanf("%d",&N);
  for(int i=1;i<=N;i++)
    scanf("%d",a+i);
  for(int i=1;i<=N;i++)
    scanf("%d",b+i);
  for(int i=1;i<=N;i++){
    A[i]=a[i]+A[i-1];
    B[i]=b[i]+A[i-1];
  }
  for(int i=1;i<=N;i++)
    C[i]=A[i]*B[i];
  for(int i=N;i>=1;i--)
    c[i]=C[i]-C[i-1];
  for(int i=1;i<=N-1;i++)
    printf("%d ",c[i]);
  printf("%d\n",c[N]);
  return 0;
}

\(min\)卷积与之类似。
为什么它是正确的?
讨论 \(max\) 卷积

\[c_i=\sum_{max(j,k)=i}{a_j\times b_k} \]

\[c_i=a_i\times b_i+\sum_{j=1}^{i-1}{(a_j\times b_i+b_j\times a_i)} \]

程序中

\[A_i=\sum_{j=1}^{i}{a_j}\text{ }\text{ }\text{ }\text{ }B_i=\sum_{j=1}^{i}{b_j} \]

所以

\[C_i=\sum_{j=1}^{i}{\sum_{k=1}^{i}{a_j\times b_k}} \]

易得

\[c_i=C_i-C_{i-1}=\sum_{j=1}^{i}{(a_j\times b_i)}+\sum_{j=1}^{i-1}{(a_i\times b_j)} \]

与理论答案相符


由此我们可以总结出一点经验,求卷积的流程往往是这样:

  1. 用某种变换将\(a_i,b_i\)变成\(A_i,B_i\)
  2. \(C_i=A_i\times B_i\)
  3. 用其逆变换将\(C_i\)变成\(c_i\)得到答案

3. 当 \(*\)\(\vee/\wedge\)

在这篇文章中,如果涉及到带 \(lg_N\) 复杂度的卷积变换, \(N=2^k(k\in\mathbb{N})\) 。实际实现时在高位补 \(0\)

3.1 用向量表示数

一个\(k\)位二进制数可以表示成一个\(k\)维向量。例如当\(k=3\)时:

数值 向量表示 数值 向量表示
0 \(\left \{ 0,0,0 \right \}\) 4 \(\left \{ 1,0,0 \right \}\)
1 \(\left \{ 0,0,1 \right \}\) 5 \(\left \{ 1,0,1 \right \}\)
2 \(\left \{ 0,1,0 \right \}\) 6 \(\left \{ 1,1,0 \right \}\)
3 \(\left \{ 0,1,1 \right \}\) 7 \(\left \{ 1,1,1 \right \}\)

3.2 \(\vee\)的实质

我们把用向量表示的数字\(\vee\),例如

\[3\vee 6=\left \{ 0,1,1 \right \}\vee\left \{ 1,1,0 \right \}=\left \{ 1,1,1 \right \}=7 \]

由此可知\(\vee\)的本质是按位 \(max\) ,所以 \(\vee\) 卷积的变换就是按位前缀和。它的逆变换其实是一个脑筋急转弯,只要把循环倒过来,把\(+=\)改成\(-=\)就可以了。代码如下:

#include<stdio.h>
int n,k,N,a[100005],b[100005],c[100005];
int main(){
  scanf("%d",&n);
  for(int i=0;i<n;i++)
    scanf("%d",a+i);
  for(int i=0;i<n;i++)
    scanf("%d",b+i);
  for(N=1;N<n;N<<=1,k++);
  for(int i=0;i<k;i++)
    for(int j=0;j<N;j++)
      if((j&(1<<i))==0)
        a[j+(1<<i)]+=a[j];
  for(int i=0;i<k;i++)
    for(int j=0;j<N;j++)
      if((j&(1<<i))==0)
        b[j+(1<<i)]+=b[j];
  for(int i=0;i<N;i++)
    c[i]=a[i]*b[i];
  for(int i=k-1;i>=0;i--)
    for(int j=N-1;j>=0;j--)
      if((j&(1<<i))==0)
        c[j+(1<<i)]-=c[j];
  for(int i=0;i<=N-2;i++)
    printf("%d ",c[i]);
  printf("%d\n",c[N-1]);
  return 0;
}

\(\wedge\)卷积与之类似。

4. 当\(*\)\(\bigoplus\)

4.1 千里之行,始于足下:N=2

有小学知识可知:

\[<a_0,a_1>\Rightarrow <a_0+a_1,a_0-a_1> \]

\[<A_0,A_1>\Rightarrow <\frac{A_0+A_1}{2},\frac{A_0-A_1}{2}> \]

我们将\(a_i,b_i\)带入

\[C_0=A_0\times B_0=(a_0+a_1)\times (b_0+b_1)=a_0\times b_0+a_0\times b_1+a_1\times b_0+a_1\times b_1 \]

\[C_1=A_1\times B_1=(a_0-a_1)\times (b_0-b_1)=a_0\times b_0-a_0\times b_1-a_1\times b_0+a_1\times b_1 \]

\[c_0=\frac{C_0+C_1}{2}=a_0\times b_0+a_1\times b_1 \]

\[c_1=\frac{C_0-C_1}{2}=a_0\times b_1+a_1\times b_0 \]

听说是正确的,于是拓展到高维

\[A_i=\sum_{j=0}^{N}{(-1)^{bitcount(i\wedge j)\times a_j}} \]

4.2 一句不是废话的废话

虽然我们已经发现了正解,我们还是从向量的角度看一下。比如

\[3\bigoplus 6=\left \{ 0,1,1 \right \}\bigoplus\left \{ 1,1,0 \right \}=\left \{ 1,0,1 \right \}=5 \]

再写一遍

\[3\bigoplus 6=\left \{ 0,1,1 \right \}\bigoplus\left \{ 1,1,0 \right \}=\left \{ (0+1)\%2,(1+1)\%2,(1+0)\%2 \right \}=\left \{ 1,0,1 \right \}=5 \]

发现\(\bigoplus\)其实是二进制无进位加法。

4.3 具体实现

\(\bigoplus\)卷积的变换代码与\(\vee/\wedge\)卷积的代码略有不同,但\(\vee/\wedge\)卷积的代码也可以写成这种形式,具体题意见洛谷P4717【模板】快速沃尔什变换

#include<stdio.h>
const int p=998244353;
inline int power(int a,int k){
  int ans=1;
  for(;k;a=1LL*a*a%p,k>>=1)
    if(k&1)
      ans=1LL*ans*a%p;
  return ans;
}
int len,N,a[4][300005],b[4][300005];
void FWT(int* a,int op,int flag){
  int x,y;
  if(op==3&&flag==-1){
    for(int i=N>>1;i>0;i>>=1)
      for(int j=0;j<N;j+=i<<1)
        for(int k=0;k<i;k++){
          x=a[j+k];
          y=a[i+j+k];
          a[j+k]=1LL*(x+y)*power(2,p-2)%p;
          a[i+j+k]=(1LL*(x-y)*power(2,p-2)%p+p)%p;
        }
    return;
  }
  for(int i=1;i<N;i<<=1)
    for(int j=0;j<N;j+=i<<1)
      for(int k=0;k<i;k++){
        x=a[j+k];
        y=a[i+j+k];
        if(op==1){
          if(flag==1){
            a[j+k]=x;
            a[i+j+k]=(x+y)%p;
          }
          else{
            a[j+k]=x;
            a[i+j+k]=((y-x)%p+p)%p;
          }
        }
        if(op==2){
          if(flag==1){
            a[j+k]=(x+y)%p;
            a[i+j+k]=y;
          }
          else{
            a[j+k]=((x-y)%p+p)%p;
            a[i+j+k]=y;
          }
        }
        if(op==3){
          a[j+k]=(x+y)%p;
          a[i+j+k]=((x-y)%p+p)%p;
        }
      }
}
int main(){
  scanf("%d",&len);
  N=power(2,len);
  for(int i=0;i<N;i++){
    scanf("%d",a[1]+i);
    a[3][i]=a[2][i]=a[1][i];
  }
  for(int i=0;i<N;i++){
    scanf("%d",b[1]+i);
    b[3][i]=b[2][i]=b[1][i];
  }
  for(int j=1;j<=3;j++){
    FWT(a[j],j,1);
    FWT(b[j],j,1);
  }
  for(int i=0;i<N;i++)
    for(int j=1;j<=3;j++)
      a[j][i]=1LL*a[j][i]*b[j][i]%p;
  for(int j=1;j<=3;j++)
    FWT(a[j],j,-1);
  for(int j=1;j<=3;j++){
    for(int i=0;i<N-1;i++)
      printf("%d ",a[j][i]);
    printf("%d\n",a[j][N-1]);
  }
  return 0;
}

5. 当\(*\)\(+\)

公式恐惧症患者请果断按下Ctrl+W以发起正当防卫

5.1 求多项式乘法的新方法

定义\(N\)次多项式\(g,h\),我们可以选取\(x_{0...2\times N}\)带入\(g,h\)得到\(G_{0...2\times N},H_{0...2\times N}\),将G和H逐位相乘得到\(F\),最后将\(F\)消元得到\(f\),\(f=g\times h\)
不幸的是,带入多项式需要 \(O(n^2)\) 。What's worse,高斯消元需要 \(O(n^3)\)

5.2 选择带入的数

显然,瓶颈在带入的数的选择上。那我们需要带入怎样的数带入呢?

5.2.1 复数

我们知道, \(x^2=-1\) 无实数解,但我们定义 \(i^2=-1\)
复数是所有能够写成 \(a+i\times b(a,b\in \mathbb{R})\) 的数的集合,该集合记作 \(\mathbb{C}\)

5.2.2 复数的性质

  1. 对于任意整数 \(n\geqslant0,k\geqslant0,d\geqslant0\)

\[\omega_{N}^{k}=\omega_{Nd}^{kd} \]

  1. 对于任意整数 \(n\geqslant0,k\geqslant0\)

\[\omega_{N}^{k}=\omega_{2N}^{2k} \]

  1. 对于任意整数 \(n\geqslant0\)

\[\omega_{N}^{\frac{N}{2}}=-1 \]

  1. 对于任意整数 \(n\geqslant0,i\geqslant0,j\geqslant0\)

\[\omega_{N}^{i+j}=\omega_{N}^{i}\times\omega_{N}^{j} \]

  1. 对于任意整数 \(n\geqslant0,i\geqslant0,j\geqslant0\)

\[\omega_{N}^{ij}=(\omega_{N}^{i})^j \]

5.3 开始带入!

我们定义将\(<\omega_N^0,\omega_N^1, ... ,\omega_N^{N-1}>\)带入\(<a_0,a_1, ... ,a_{N-1}>\)的结果设为\(<A_0,A_1, ... ,A_{N-1}>\)
根据定义得

\[A_i=\sum_{j=0}^{N-1}a_j\times(\omega_N^i)^j \]

由性质5得

\[A_i=\sum_{j=0}^{N-1}a_j\times\omega_N^{ij} \]

我们把\(i\)的定义域从 \([0,N-1]\) 变成 \([0,\frac{N}{2}-1]\)

\[A_i=\sum_{j=0}^{N-1}a_j\times\omega_N^{ij} \]

\[A_{i+\frac{N}{2}}=\sum_{j=0}^{N-1}a_j\times\omega_N^{(i+\frac{N}{2})\times j} \]

奇偶分类

\[A_i=\sum_{j=0}^{\frac{N}{2}-1}a_{2j}\times\omega_N^{2ij}+\sum_{j=0}^{\frac{N}{2}-1}a_{2j+1}\times\omega_N^{i(2j+1)} \]

\[=\sum_{j=0}^{\frac{N}{2}-1}a_{2j}\times\omega_{\frac{N}{2}}^{ij}+\omega_N^i\sum_{j=0}^{\frac{N}{2}-1}a_{2j+1}\omega_{\frac{N}{2}}^{ij} \]

\[A_{i+\frac{N}{2}}=\sum_{j=0}^{\frac{N}{2}-1}a_{2j}\times\omega_N^{2(i+\frac{N}{2})j}+\sum_{j=0}^{\frac{N}{2}-1}a_{2j+1}\times\omega_N^{(i+\frac{N}{2})(2j+1)} \]

\[=\sum_{j=0}^{\frac{N}{2}-1}a_{2j}\times\omega_{\frac{N}{2}}^{ij}+\omega_N^{i+\frac{N}{2}}\sum_{j=0}^{\frac{N}{2}-1}a_{2j+1}\omega_{\frac{N}{2}}^{ij} \]

\[=\sum_{j=0}^{\frac{N}{2}-1}a_{2j}\times\omega_{\frac{N}{2}}^{ij}-\omega_N^i\sum_{j=0}^{\frac{N}{2}-1}a_{2j+1}\omega_{\frac{N}{2}}^{ij} \]

于是我们惊喜地发现,这可以分治做。
为什么需要奇偶分类呢?也许这就是傅里叶的伟大之处吧。

5.4 逆变换

那么FFT的逆变换怎么写呢?
令人惊讶的是,恰有

\[a_i=\frac{\sum_{j=0}^{N-1}A_j\times\omega_N^{-ij}}{N} \]

为什么它是正确的?

\[a_i=\frac{\sum_{j=0}^{N-1}A_j\times\omega_N^{-ij}}{N} \]

\[=\frac{\sum_{j=0}^{N-1}\sum_{k=0}^{N-1}a_k\times\omega_N^{jk}\times\omega_N^{-ij}}{N} \]

\[=\frac{\sum_{j=0}^{N-1}\sum_{k=0}^{N-1}a_k\times\omega_N^{j(k-i)}}{N} \]

\[=\frac{\sum_{j=0}^{N-1}\sum_{k=0}^{N-1}a_k[i==k]}{N} \]

\[=\frac{N\times a_i}{N} \]

\[=a_i \]


5.5 蝴蝶变换

我们将\(N=8\)的分治情况手动模拟一下,可以得到:

区间大小 \(id_0\) \(id_1\) \(id_2\) \(id_3\) \(id_4\) \(id_5\) \(id_6\) \(id_7\)
8 0 1 2 3 4 5 6 7
4 0 2 4 6 1 3 5 7
2 0 4 2 6 1 5 3 7
1 0 4 2 6 1 5 3 7

把这个表格用二进制描述

区间大小 \(id_0\) \(id_1\) \(id_2\) \(id_3\) \(id_4\) \(id_5\) \(id_6\) \(id_7\)
8 000 001 010 011 100 101 110 111
4 000 010 100 110 001 011 101 111
2 000 100 010 110 001 101 011 111
1 000 100 010 110 001 101 011 111

我们发现表格的第一行和最后一行二进制是反转的,这样我们就发现了\(FFT\)的非递归写法,代码如下:

#include<math.h>
#include<stdio.h>
#include<algorithm>
using namespace std;
const double pi=acos(-1.0);
int n,m,res=0,N=1,len,revers[2097160];
long long ans[2097160];
int i,j,k,l;
struct node{
  double x,y;
  node(double x=0,double y=0):x(x),y(y){}
  node operator*(const node &b){
    return node(x*b.x-y*b.y,x*b.y+y*b.x);
  }
  node operator+(const node &b){
    return node(x+b.x,y+b.y);
  }
  node operator-(const node &b){
    return node(x-b.x,y-b.y);
  }
}a[2097160],b[2097160],T,t,x,y;
void FFT(node *a,double flag){
  for(i=0;i<N;i++)
    if(i<revers[i])
      swap(a[i],a[revers[i]]);
  for(j=1;j<N;j<<=1){
    T=node(cos(pi/j),flag*sin(pi/j));
    for(k=0;k<N;k+=(j<<1)){
      t=node(1,0);
      for(l=0;l<j;l++,t=t*T){
        x=a[k+l],y=t*a[k+j+l];
        a[k+l]=x+y;
        a[k+j+l]=x-y;
      }
    }
  }
}
int main(){
  scanf("%d%d",&n,&m);
  n++;
  m++;
  for(i=0;i<n;i++)	
    scanf("%lf",&a[i].x);
  for(i=0;i<m;i++)
    scanf("%lf",&b[i].x);
  for(;N<max(n,m)<<1;N<<=1,len++);
  for(i=0;i<=N;i++)
    revers[i]=(revers[i>>1]>>1)|((i&1)<<(len-1));
  FFT(a,1);
  FFT(b,1);
  for(i=0;i<=N;i++)
    a[i]=a[i]*b[i];
  FFT(a,-1);
  for(i=0;i<=N;i++)
    ans[i]+=(long long)(a[i].x/N+0.5);
  for(;!ans[N]&&N;N--);
  N++;
  for(i=0;i<n+m-2;i++)
    printf("%lld ",ans[i]);
  printf("%lld\n",ans[n+m-2]);
  return 0;
}

5.6 精度问题

把上面的代码加入模操作后提交到P4245里去发现光荣\(\color{red}\text{WA}\)
然后就发现FFT有精度问题,那么如何避免呢?

5.6.1 原根

如果\(g\)\(0\) ~ \(\phi(p)-1\)在模\(p\)意义下正好遍历了\(1\) ~ \(p-1\)中与\(p\)互质的\(\phi(p)\)个数,那么称\(g\)\(p\)的原根。
当p为质数时,我们发现如果用\(g^\frac{p-1}{N}\)代替单位复根(记为\(g_N\)),它满足单位复根的所有性质:

  1. 对于任意整数 \(n\geqslant0,k\geqslant0,d\geqslant0\)

\[g_{N}^{k}=g_{Nd}^{kd} \]

  1. 对于任意整数 \(n\geqslant0,k\geqslant0\)

\[g_{N}^{k}=g_{2N}^{2k} \]

  1. 对于任意整数 \(n\geqslant0\)

\[g_{N}^{\frac{N}{2}}=-1 \]

  1. 对于任意整数 \(n\geqslant0,i\geqslant0,j\geqslant0\)

\[g_{N}^{i+j}=g_{N}^{i}\times g_{N}^{j} \]

  1. 对于任意整数 \(n\geqslant0,i\geqslant0,j\geqslant0\)

\[g_{N}^{ij}=(g_{N}^{i})^j \]

5.6.2 能选的质数

一般情况下有三个质数可选:

  1. \(469762049=7\times 2^{26}+1\)
  2. \(998244353=119\times 2^{23}+1\)
  3. \(1004535809=749\times 2^{21}+1\)

\(p\)取上面几个质数时,\(g=3\)\(p-1\)中有很多\(2\)的因子,FFT中\(N\)又都是\(2\)的次幂,所以上面三个质数一定要记下来。

代码如下:

#include<stdio.h>
#include<algorithm>
using namespace std;
const long long p=998244353,g=3,invg=332748118;
int n,m,res=0,N=1,len,revers[2097160];
long long ans[2097160],a[2097160],b[2097160],T,t,x,y;
int i,j,k,l;
inline long long power(long long a,long long k,long long p){
  long long ans=1,t=a;
  for(;k;k>>=1,t=t*t%p)
    if(k&1)
      ans=ans*t%p;
  return ans;
}
void NTT(long long *a,long long flag){
  for(i=0;i<N;i++)
    if(i<revers[i])
      swap(a[i],a[revers[i]]);
  for(j=1;j<N;j<<=1){
    T=power(flag==1?g:invg,(p-1)/j/2,p);
    for(k=0;k<N;k+=(j<<1)){
      t=1;
      for(l=0;l<j;l++,t=t*T%p){
        x=a[k+l],y=t*a[k+j+l]%p;
        a[k+l]=(x+y)%p;
        a[k+j+l]=((x-y)%p+p)%p;
      }
    }
  }
}
int main(){
  scanf("%d%d",&n,&m);
  n++;
  m++;
  for(i=0;i<n;i++){
    scanf("%lld",a+i);
    a[i]=a[i]%p;
  }
  for(i=0;i<m;i++){
    scanf("%lld",b+i);
    b[i]=b[i]%p;
  }
  for(;N<max(n,m)<<1;N<<=1,len++);
  for(i=0;i<=N;i++)
    revers[i]=(revers[i>>1]>>1)|((i&1)<<(len-1));
  NTT(a,1);
  NTT(b,1);
  for(i=0;i<=N;i++)
    a[i]=a[i]*b[i]%p;
  NTT(a,-1);
  for(i=0;i<=N;i++)
    ans[i]=a[i]*power(N,p-2,p)%p;
  for(i=0;i<n+m-2;i++)
    printf("%lld ",ans[i]);
  printf("%lld\n",ans[n+m-2]);
  return 0;
}

5.7 换个角度看\(\bigoplus\)卷积

我们再回忆一下4.2节的内容,我们在做\(\bigoplus\)卷积时,其实可以做\(lg_N\)\(FFT\),然后又因为\(\omega_2^0=1\)\(\omega_2^1=-1\),就可以得到4.1节的结论了。

posted @ 2019-02-18 15:03  ττ  阅读(258)  评论(0编辑  收藏  举报