FFT笔记

FFT笔记

前言:
这个算法对于我来讲比较抽象、高深,因为里面涉及了一些复数等一些对我而言很难很难的知识。
终于,花了几节文化课的时间冥思苦想,终于算是搞懂一点了。所以我决定趁脑子清醒的时候记录下来。
与其他文章不同的是,本文可能没有太多的公式证明,主要是以通俗易懂的方式去讲解,也是为了方便大家(包括以后可能忘记的我)更好地理解。

建议:在电脑前放好演算纸与笔。

注:本篇文章是我这个小蒟弱写的,真正的dalao请看个玩笑便好,不必争论对错(但是欢迎指出文章存在的小错误)。

FFT 有什么用

快速傅里叶变换(FFT)可以在\(O(n\log n)\)的时间解决两个多项式的相乘问题。

为什么要用 FFT

我们先假设一个要计算多项式乘多项式的情景:

\(A(x)=a_0+a_1x+a_2x^2+a_3x^3+\dots\)
\(B(x)=b_0+b_1x+b_2x^2+b_3x^3+\dots\)

假设已经得到结果多项式为\(C(x)=c_0+c_1x+\dots\)

根据我们在数学中的做法可以得到:

  • \(\;x\)的指数是由\(A(x),B(x)\)中任意两项的指数相加得到的
  • \(\;c_n\)的组成是由能得到\(x^n\)的项的系数相乘得到的

如果不能理解,建议推一下下面这个式子:

\[(2x^2+x)(x+1)=2x^3+3x^2+x \]

那么多项式的乘法用计算机实现,第一个想到的就是\(O(n^2)\)的朴素算法(两层循环模拟)。针对于这种普通的多项式乘法,我们把它称作系数表示法。因为一旦能确定所有系数,多项式也随之确定。

如果我们把多项式和函数图像融合在一起,会发现一个多项式\(A(x)=a_0+a_1x+a_2x^2+a_3x^3+\dots+a_nx^n\)可以用\(y=a_0+a_1x+a_2x^2+a_3x^3+\dots+a_nx^n\)的函数图像所表示。
那根据确定函数的性质,如果我们可以知道\(n+1\)\(y\)值,那么我们可以唯一地确定这个多项式。像这种方法,我们称之为点值表示法

有趣的是,点值表示法时间复杂度是\(O(n)\)的!实现:因为\(C(x)=A(x)\times B(x)\),所以\(O(n)\)暴力枚举\(x_i\)即可。

相信此时的你蠢蠢欲动:如果我们将系数表示法转换为点值表示法,那不就可以大大降低我们多项式相乘的时间复杂度了吗?!

……遗憾的是,很难转换。无论是朴素算法\(O(n^2)\)转换,或是——可能你还不知道的拉格朗日插值法转换,时间复杂度没有降下来,仍然是\(O(n^2)\)

那现在思路很清晰,我们要想将多项式相乘速度提快,就必须将系数表示法转换为点值表示法这一大关打下来!(也要完成逆转换)

在世界第一台计算机(1946年)横空出世的139年前(1807年),一位伟大的数学家信誓旦旦——

傅里叶:这个我会!

DFT(离散FFT/朴素FFT)

这个办法就是:用\(n\)个模长为\(1\)复数代替点值表示法中\(n\)\(x\)(下文默认\(n\)\(2\)的幂)。

你可能觉得复数特别厉害,高中知识,溜了溜了。但是我们理不理解复数并没有直接关系

复数:
“如果你学过了,可以跳过。
如果你不会复数,可以当做是向量。
如果你不会向量,可以看成是平面直角坐标系上的一个点。
如果你不会平面直角坐标系,可以看成是c++中的pair容器。
如果你还是什么都不会,…………(出门右拐先学习一下平面直角坐标系吧……”

复数有一个实部还有一个虚部,类似于一个向量(或点)的横纵坐标。例如复数\(3+2i\)\(3\)是实部,\(2\)是虚部,\(i=\sqrt {-1}\)。可以想象成向量\((3,2)\)或点\((3,2)\)

复数的运算规则是:模长相乘,幅角相加。如果你只是想学FFT,记住幅角相加就好了。

那为什么要用复数呢?有趣在于——它是一种数,可以带入多项式\(A(x)\)中去,显然你不能把一个向量(或点)带入一个多项式。

更有趣的是,c++提供了复数的模板!
头文件:#include<complex>
定义:complex<double> a[1000],b,c;
运算:直接加减乘除

上面说了要找\(n\)个模长为\(1\)的复数,可是这可不是乱找的。傅里叶精心地挑选了\(n\)个点,而这\(n\)个点,实在平面直角坐标系中,将一个单位圆平均分成\(n\)等分,这\(n\)个点的横坐标为实部、纵坐标为虚部,便可以构成\(n\)个虚数。

详见图:

\((1,0)\)开始,然后逆时针从\(0\)开始编号,第\(k\)个点记作虚数\(\omega_n^k\)。还记得模长相乘,幅角相加吗?所以\(\omega_n^k\)\(\omega_n^1\)\(k\)次方,因此\(\omega_n^1\)就是单位根。

根据每个复数的幅角,可以确定这个向量或点的位置。\(\omega_n^k\)对应的向量或点坐标为\((\cos{\frac{k}{n}2\pi},\sin{\frac{k}{n}2\pi})\),那么这个复数也就为\(\cos{\frac{k}{n}2\pi}+i\times \sin{\frac{k}{n}2\pi}\)

那么把\(n\)\(\omega_n^0,\omega_n^1,\dots,\omega_n^{n-1}\)带入多项式,就可以得到特殊的点值表示法。傅里叶开心地将这种点值表示称为离散傅里叶变换

为什么要使用单位根代入

这肯定是有讲究的,因为这里有一些有趣的性质。
相关证明就不写了,直接上结论:

将多项式\(A(x)\)的傅里叶变换结果作为多项式\(B(x)\)的系数,再将单位根的倒数作为\(x\)代入\(B(x)\),得到的每个数再\(\div n\),得到的正是\(A(x)\)的各项系数!

总而言之,这也顺带完成了将点值表示法逆转换成系数表示法的任务,这也是离散傅里叶变换神奇的特殊性质。

FFT

傅里叶发明了如此高深的DFT,完成了我们的主要任务——完成转换与逆转换。但是DFT仍是朴素的\(O(n^2)\)……

傅里叶:我都没见过计算机,为什么要优化时间复杂度……

但是,运用信息的思维,可以立马想到分治!因此快速傅里叶变换由此应运而生!

显然也是可以证明的,但是因为太懒、掌握不牢固,所以决定待以后再证。

分治可以用dfs实现,边界条件为\(n=1\)时,直接return。

递归实现FFT

code:

#define cp complex<double>
cp omega(int n, int k)
{
    return cp(cos(2 * PI * k / n), sin(2 * PI * k / n));
}
void fft(cp  *a, int n, bool inv)
{
    if(n == 1) return;
    static cp buf[N];
    int m = n / 2;
    for(int i = 0; i < m; i++) //将每一项按照奇偶分为两组
    { 
	buf[i] = a[2 * i];
	buf[i + m] = a[2 * i + 1];
    }
    for(int i = 0; i < n; i++)
	a[i] = buf[i];
    fft(a, m, inv); //递归处理两个子问题
    fft(a + m, m, inv);
    for(int i = 0; i < m; i++) //枚举x,计算A(x)
    { 
	cp x = omega(n, i); 
	if(inv) x = conj(x); 
	//conj是一个自带的求共轭复数的函数,精度较高。当复数模为1时,共轭复数等于倒数
	buf[i] = a[i] + x * a[i + m]; //根据之前推出的结论计算
	buf[i + m] = a[i] - x * a[i + m];
    }
    for(int i = 0; i < n; i++)
	a[i] = buf[i];
}

inv表示单位根\(\omega_n^1\)取不取倒数。
然鹅这个只是1.0版本,让我们学一些优化吧!

优化实现FFT

迭代版本非递归FFT

在递归时我们不断分组分成两侧,观察一下有没有什么规律:

\[初始位置:0\;1\;2\;3\;4\;5\;6\;7\\ 第一轮后:0\;2\;4\;6|1\;3\;5\;7\\ 第二轮后:0\;4|2\;6|1\;5|3\;7\\ 第三轮后:0|4|2|6|1|5|3|7 \]

这些“\(|\)”表示分割线。

明显地(或者说非常不明显地),将它们转换成二进制后有镜面效果

\[原先:000,001,010,011,100,101,110,111\\ 现在:000,100,010,110,001,101,011,111 \]

这里的镜面反转指的是同一个位置进行了变化。

那么我们可以把每位数放在最后一个位置上,然后不断向上还原,同时求出点值表示。

代码部分建议跟着代码手推一下,因为全都是位运算,有点难理解。

镜面操作单独code1:

for(int i = 0; i < n; i++)
{
    int t = 0;
    for(int j = 0; j < lim; j++)
        if((i >> j) & 1) t |= (1 << (lim - j - 1));
    if(i < t) swap(a[i], a[t]); 
    // i < t 的限制使得每对点只被交换一次(否则交换两次相当于没交换)
}

镜面操作单独code2(这个代码可以不用到lim变量):

for(int i=0,j=0;i<n;++i)
{
    if(i>j) swap(a[i],a[j]);
    for(int l=n>>1;(j^=l)<l;l>>=1);
}

融合进FFT的镜面操作code(调用的是操作1):

cp a[N], b[N], omg[N], inv[N];
int lim;
void init()
{
    while((1 << lim) < n) lim++;
    for(int i = 0; i < n; i++)
    {
    	omg[i] = cp(cos(2 * PI * i / n), sin(2 * PI * i / n));
    	inv[i] = conj(omg[i]); 
        // conj这一行等价于 inv[i]=cp(cos(2*PI*i/n),-sin(2*PI*i/n));
    }
}
void fft(cp *a, cp *omg)
{
    for(int i = 0; i < n; i++) // 这一个for就是用位运算进行二进制的镜面操作
    {
    	int t = 0;
    	for(int j = 0; j < lim; j++)
    	    if((i >> j) & 1) t |= (1 << (lim - j - 1));
    	if(i < t) swap(a[i], a[t]); 
        // i < t 的限制使得每对点只被交换一次(否则交换两次相当于没交换)
    }
    static cp buf[N];
    for(int l = 2; l <= n; l *= 2)
    {
    	int m = l / 2;
    	for(int j = 0; j < n; j += l)
    	    for(int i = 0; i < m; i++)
            {
                buf[j + i] = a[j + i] + omg[n / l * i] * a[j + i + m];
        	buf[j + i + m] = a[j + i] - omg[n / l * i] * a[j + i + m];
                // 为什么是omg[n / l * i]?因为这是在用递推模拟递归,只有这样写才符合上面递归代码中n的变化
	    }
    	for(int j = 0; j < n; j++)
    	    a[j] = buf[j];
    }
}

如果我们预先处理好\(\omega_n^k\)\(\omega_n^{-k}\)并存入omg,inv数组里,那我们就可以根据需求在fft函数中传入不同的数组。

蝴蝶操作

这个优化听起来贼nb,但我们可以从本质上思考一下:buf[]有何用?
是为了做这一件事:

a[j + i] = a[j + i] + omg[n / l * i] * a[j + i + m];
a[j + i + m] = a[j + i] - omg[n / l * i] * a[j + i + m];

同时不能使这两行互相干涉,所以我们才需要buf[]。
但是如果我们用一个临时变量代替:

cp t = omg[n / l * i] * a[j + i + m];
a[j + i + m] = a[j + i] - t;
a[j + i] = a[j + i] + t;

就没有任何问题啦!!
全样长这样:

cp a[N], b[N], omg[N], inv[N];
int lim;
void init()
{
    while((1 << lim) < n) lim++;
    for(int i = 0; i < n; i++)
    {
    	omg[i] = cp(cos(2 * PI * i / n), sin(2 * PI * i / n));
    	inv[i] = conj(omg[i]);
        // conj这一行等价于 inv[i]=cp(cos(2*PI*i/n),-sin(2*PI*i/n));
    }
}
void fft(cp *a, cp *omg)
{
    for(int i = 0; i < n; i++)
    {
    	int t = 0;
    	for(int j = 0; j < lim; j++)
    	    if((i >> j) & 1) t |= (1 << (lim - j - 1));
    	if(i < t) swap(a[i], a[t]);
        // i < t 的限制使得每对点只被交换一次(否则交换两次相当于没交换)
    }
    for(int l = 2; l <= n; l *= 2)
    {
        int m = l / 2;
        for(cp *p = a; p != a + n; p += l)
            for(int i = 0; i < m; i++)
            {
                cp t = omg[n / l * i] * p[i + m];
                p[i + m] = p[i] - t;
                p[i] += t;
            }
    }
}

现在,这个终极无敌优化版FFT就比原递归FFT快得多了。


后记&版题

终于!笔记写完啦!

本篇文章借鉴了一些内容,是源于这位dalao这篇博客,非常感谢!!

下面我也贴一个板子吧(^^)!

code:

#include<bits/stdc++.h>
using namespace std;

#define ll long long
#define cp complex<double>
#define rp(i,o,p) for(ll i=o;i<=p;++i)
#define pr(i,o,p) for(ll i=o;i>=p;--i)

const ll MAXN=1e6+5;
const double PI=acos(-1.0);

ll n=1;
char s1[MAXN],s2[MAXN];
ll la,lb,ans[MAXN];
cp a[MAXN],b[MAXN],omg[MAXN],inv[MAXN];

void init()
{
    for(ll i=0;i<n;++i)
    {
        omg[i]=cp(cos(2*PI*i/n),sin(2*PI*i/n));
        inv[i]=conj(omg[i]);
        // <=> inv[i]=cp(cos(2*PI*i/n),-sin(2*PI*i/n));
    }
}

void fft(cp *a,cp *omg)
{
    ll lim=0;
    while((1<<lim)<n) ++lim;
    for(ll i=0;i<n;++i)
    {
        ll t=0;
        for(ll j=0;j<lim;++j)
            if(((i>>j)&1)) 
                t|=(1<<(lim-j-1));
        if(i<t)
            swap(a[i],a[t]);
    }
    for(ll l=2;l<=n;l<<=1)
    {
        ll m=l>>1;
        for(cp *p=a;p!=a+n;p+=l)
        {
            for(ll i=0;i<m;++i)
            {
                cp t=omg[n/l*i]*p[i+m];
                p[i+m]=p[i]-t;
                p[i]+=t;
            }
        }
    }
    
}

int main()
{
    scanf("%s%s",s1,s2);
    la=strlen(s1),lb=strlen(s2);
    while(n<la+lb) n<<=1;

    for(ll i=0;i<la;++i)
        a[i].real(s1[la-i-1]-'0');
    for(ll i=0;i<lb;++i) 
        b[i].real(s2[lb-i-1]-'0');
    
    init();
    fft(a,omg);
    fft(b,omg);

    for(ll i=0;i<n;++i)
        a[i]*=b[i];

    fft(a,inv);

    for(ll i=0;i<n;++i)
    {
        ans[i]+=(ll)(a[i].real()/n+0.5);
        ans[i+1]+=ans[i]/10;
        ans[i]%=10;
    }
    ll i;
    for(i=la+lb-1;ans[i]==0&&i>=0;--i)
        if(i==0)
            putchar('0'),i=-1;
    while(i>=0)
        putchar('0'+ans[i--]);
    putchar('\n');

    return 0;
}
posted @ 2023-04-27 21:05  WerChange  阅读(76)  评论(2编辑  收藏  举报