多项式乘法

我们知道,多项式可以表示成:

A=i=0naixi
的形式。
对于两个多项式A(x)B(x),我们可以计算乘积AB
AB=i=0sizeAj=0sizeBaibjxi+j

但是,这样算是O(sizeAsizeB)的,太慢了,怎么办?
我们需要换一条思路。

首先,我们得知道一个东西:多项式的点值表示法
我们把上面的称为多项式的系数表示法,而点值表示法就是:
A多项式的次数为n,则任取n个不相同的x0,x1,,xn,求出A多项式的A(x0),A(x1),,A(xn)。记为:

<(x0,A(x0)),(x1,A(x1)),,(xn,A(xn))>
显然,一个有n+1个点的点对唯一表示一个n次多项式。

对于一个点值表示法下多项式

<(x0,A(x0)),(x1,A(x1)),,(xn,A(xn))>
<(x0,B(x0)),(x1,B(x1)),,(xn,B(xn))>
它们的乘积是
<(x0,A(x0)B(x0)),(x1,A(x1)B(x1)),,(xn,A(xn)B(xn))>
可以看出点值表示法的多项式相乘是O(n)的。

等等,我们好像找到了一个突破口!
为啥不把原来的系数表示法下的多项式转化成点值表示法呢?
仔细想一想:系数表示法与点值表示法互相转换,这个步骤好像是O(n2)的。
FFT(快速傅里叶变换)就是为了优化这个O(n2)

PS:对于O(n2)的点值表示法转化成系数表示法可以看百度百科中关于插值法的介绍

FFT(快速傅里叶变换)

如果未特别说明,那么下面的多项式次数将是2k1的形式。
如果不是关键部分的公式或定理,不提供证明,自己出门右转百度。

首先介绍两个概念:
DFT(离散傅里叶变换)是将多项式由系数表示法转化成点值表示法;
IDFT(离散傅里叶逆变换)是将多项式由点值表示法转化成系数表示法;
而FFT就是上述两种变换的优化。

DFT部分

前置技能:
下面的内容将会提到复数,不会的可以参考百度百科中关于复数的介绍

为了介绍FFT中的DFT部分,首先要介绍的是一个概念:单位根
单位根:若有

zn=1
此时将z称为n次单位根。
若有zR,显然,z可以等于1,如果n是偶数,则z还可以等于1
我们把范围扩大到zC,那么,我们可以得到n个复数,它们将复平面上的单位圆等分成n份。

为了表示n次单位根,我们引入一个公式。
欧拉公式:

ein=cosn+isinn

如果我们令:

ωn=e2πi/n
那么,n次单位根就可以表示成ωn0,ωn1,,ωnn1,它们的n次方显然都是1

下面是关于ωn的两条性质:(都是在n为偶数的情况下)

(1)ωnn/2=e(2πi/n)(n/2)=eπi=cosπ+isinπ=1

(2)ωn2=e22πi/n=ωn/2

下面,我们进入正题:DFT的求法
在这里,我们令多项式次数为n1,那么我们可以用点值表示成

<(ωn0,A(ωn0),(ωn1,A(ωn1)),,(ωnn1,A(ωnn1))>

额……这时间复杂度好像并没有减少……
别急,我们来看A(ωnk)能够表示成什么。

(3)A(ωnk)=i=0n1aiωnki(4)=i=0n/21a2iωn2ki+i=0n/21a2i+1ωn2ki+k(5)=i=0n/21a2iωn/2ki+i=0n/21a2i+1ωn/2kiωnk(6)=i=0n/21a2iωn/2ki+ωni=0n/21a2i+1ωn/2ki

我们来分别看一看这神奇的步骤。
(3)步骤就是将ωnk带入原来的A多项式。
(4)步骤就是将原多项式拆成两个部分,按奇偶分类。
(5)步骤用到了上面提到的性质(2)
(6)步骤就是上面式子的后半部分提出公因数。

有了这个等式,我们就可以分治+递归解决DFT了。
算法步骤:

  • 对当前的多项式(一个数组)系数进行奇偶分类;
  • 递归算出偶数部分的数组的anse和奇数部分的数组的anso
  • 这个多项式的ans=anse+ωnanso

但是这个的常数好像很大啊?能不能减少一点呢?
上面的性质(1)给了我们提示:

ωnn/2+k=ωnn/2ωnk=ωnk

在算k<n2时,可以顺便把kn2的情况也算出来。

常数减小了一半!但是还是很大啊!
递归版的程序一般比非递归版慢,为啥不用非递归版呢?

算法核心就是奇偶分类,分来分去最后分到了哪里?我们来研究研究。
显然,一个序列原来是0,1,2,3,4,5,6,7,最终变成0,4,2,6,1,5,3,7
把它们的二进制列出来:

000,001,010,011,100,101,110,111000,100,010,110,001,101,011,111

其中,上面是位置,下面是这个位置对应的数。

把上面的数翻转,好像就是下面的数!
没错,只需要计算一下每个数的二进制翻转后的结果,就能得到一个数最终对应的位置,也就能实现非递归版了。

代码:

int dft_fast(complex* ar,int len)
{
  for(register int i=0; i<len; ++i)
    {
      if(rev[i]<i)
        {
          std::swap(ar[rev[i]],ar[i]);//交换一个位置和它的翻转后位置
        }
    }
  for(register int i=2; i<=len; i<<=1)//i代表当前序列的长度
    {
      complex wn(cos(2*pi/i),sin(2*pi/i));//omega_n
      for(register int j=0; j<len; j+=i)//j代表序列的起始位置
        {
          complex w(1,0);//下面代表omega_n^k
          for(register int k=0; k<(i>>1); ++k)//枚举i次单位根的每一种取值
            {
              complex x=ar[j+k],y=w*ar[j+k+(i>>1)];
              ar[j+k]=x+y;//合并操作,将两边合并成一个点值表示法下的多项式
              ar[j+k+(i>>1)]=x-y;
              w=w*wn;
            }
        }
    }
  return 0;
}

IDFT部分

回顾上面的DFT部分,仔细思考一下,它本质就是在求:

{a0(ωn0)0+a1(ωn0)1++an1(ωn0)n1=A(ωn0)a0(ωn1)0+a1(ωn1)1++an1(ωn1)n1=A(ωn1)a0(ωnn1)0+a1(ωnn1)1++an1(ωnn1)n1=A(ωnn1)

其中,给定了a0,a1,,an1以及ωn0,ωn1,,ωnn1,求A(ωn0),A(ωn1),,A(ωnn1)的值。

用矩阵表示如下:

(7)[(ωn0)0(ωn0)1(ωn0)n1(ωn1)0(ωn1)1(ωn1)n1(ωnn1)0(ωnn1)1(ωnn1)n1][a0a1an1]=[A(ωn0)A(ωn1)A(ωnn1)]

我们令:

V=[(ωn0)0(ωn0)1(ωn0)n1(ωn1)0(ωn1)1(ωn1)n1(ωnn1)0(ωnn1)1(ωnn1)n1]

那么IDFT的本质就是求V矩阵的逆矩阵。

考虑下面这个矩阵:

D=[(ωn0)0(ωn0)1(ωn0)n1(ωn1)0(ωn1)1(ωn1)n1(ωn(n1))0(ωn(n1))1(ωn(n1))n1]

那么我们令E=DV,则:

Ei,j=k=0n1Di,kVk,j=k=0n1(ωni)k(ωnk)j=k=0n1ωnk(ji)

显然:

Ei,j={0(ij)n(i=j)

因此,

1nDV=1nE=In

(7)式两边同时左乘1nD,可得
[a0a1an1]=1n[(ωn0)0(ωn0)1(ωn0)n1(ωn1)0(ωn1)1(ωn1)n1(ωn(n1))0(ωn(n1))1(ωn(n1))n1][A(ωn0)A(ωn1)A(ωnn1)]

这就相当于把DFT中ωnk都换成ωnk

FFT总代码

int fft(complex* ar,int len,int op)
{
  for(register int i=0; i<len; ++i)
    {
      if(rev[i]<i)
        {
          std::swap(ar[rev[i]],ar[i]);
        }
    }
  for(register int i=2; i<=len; i<<=1)
    {
      complex wn(cos(2*pi/i),sin(2*pi*op/i));//只有这里较DFT代码有变动
      for(register int j=0; j<len; j+=i)
        {
          complex w(1,0);
          for(register int k=0; k<(i>>1); ++k)
            {
              complex x=ar[j+k],y=w*ar[j+k+(i>>1)];
              ar[j+k]=x+y;
              ar[j+k+(i>>1)]=x-y;
              w=w*wn;
            }
        }
    }
  if(op==-1)
    {
      for(register int i=0; i<len; ++i)
        {
          ar[i].r/=len;
        }
    }
  return 0;
}

多项式乘法模板

#include <cstdio>
#include <cmath>
#include <algorithm>

const int maxn=100000;
const double pi=acos(-1);

struct complex
{
  double r,i;

  complex(double r_=0,double i_=0)
  {
    r=r_;
    i=i_;
  }

  complex operator +(const complex &other)
  {
    return complex(r+other.r,i+other.i);
  }

  complex operator -(const complex &other)
  {
    return complex(r-other.r,i-other.i);
  }

  complex operator *(const complex &other)
  {
    return complex(r*other.r-i*other.i,r*other.i+i*other.r);
  }
};

complex a[maxn<<2],b[maxn<<2],c[maxn<<2];
int rev[maxn<<2],n,m;

int fft(complex* ar,int len,int op)
{
  for(register int i=0; i<len; ++i)
    {
      if(rev[i]<i)
        {
          std::swap(ar[rev[i]],ar[i]);
        }
    }
  for(register int i=2; i<=len; i<<=1)
    {
      complex wn(cos(2*pi/i),sin(2*pi*op/i));
      for(register int j=0; j<len; j+=i)
        {
          complex w(1,0);
          for(register int k=0; k<(i>>1); ++k)
            {
              complex x=ar[j+k],y=w*ar[j+k+(i>>1)];
              ar[j+k]=x+y;
              ar[j+k+(i>>1)]=x-y;
              w=w*wn;
            }
        }
    }
  if(op==-1)
    {
      for(register int i=0; i<len; ++i)
        {
          ar[i].r/=len;
        }
    }
  return 0;
}

int main()
{
  scanf("%d%d",&n,&m);
  for(register int i=0; i<=n; ++i)
    {
      scanf("%lf",&a[i].r);
    }
  for(register int i=0; i<=m; ++i)
    {
      scanf("%lf",&b[i].r);
    }
  n=n+m;
  int l=0;
  m=1;
  while(m<=n)
    {
      ++l;
      m<<=1;
    }
  for(register int i=0; i<m; ++i)
    {
      rev[i]=(rev[i>>1]>>1)|((i&1)<<(l-1));
    }
  fft(a,m,1);
  fft(b,m,1);
  for(register int i=0; i<m; ++i)
    {
      c[i]=a[i]*b[i];
    }
  fft(c,m,-1);
  for(register int i=0; i<n; ++i)
    {
      printf("%d ",(int)(c[i].r+0.5));
    }
  printf("%d\n",(int)(c[n].r+0.5));
  return 0;
}