算法学习笔记(17): 快速傅里叶变换(FFT)

快速傅里叶变换(FFT)

有趣啊,都已经到NOI的难度了,救命

首先,我们先讲述一下前置知识。已经明白的读者请移步后文

虚数

定义z=a+bi,其中 a,bR  i=1

运算原则

(a+bi)+(c+di)=(a+c)+(b+d)i(a+bi)(c+di)=(acbd)+(ad+bc)i(a+bi)(c+di)=ac+bdc2+d2+bcadc2+d2

重要性质

eix=cosx+isinx

所以说,一个复数也可以写作 z=reiθ 的形式。

其中 r 为它的模,θ 为它的辐角。

于是两个复数相乘也就相当于模相乘,辐角相加:

z1z2=r1r2ei(θ1+θ2)

证明

我们通过欧拉公式在 0 处展开:

ex=1+x+x22!+x33!+cosx=1x22!+x44!x66!+sinx=xx33!+x55!x77!+

那么我们考虑如何把三者扯上关系呢?

由于已知 i2=1,i3=i 那么我们先考虑 eix

eix=1+ixx22!ix33!+

那么,很明显, eix=cosx+isinx

得证。

代码实现

在 C++ 中我们其实可以直接使用 std::complex<double>

文档可以参考 std::complex - cppreference.com

但是毕竟是 stl,其使用细节肯定没有那么顺手,而且实测很慢,所以建议手写复数模板。

考虑到实际中我们几乎不需要用到除法,所以,我们仅实现三则运算。

struct Complex {
    double real, imag;
    Complex() : real(0), imag(0) {}
    Complex(double re, double im) : real(re), imag(im) {};
    inline Complex operator + (const Complex & b) {
        return Complex(real + b.real, imag + b.imag);
    }
    inline Complex operator - (const Complex & b) {
        return Complex(real - b.real, imag - b.imag);
    }
    inline Complex operator * (const Complex & b) {
        return Complex(real * b.real - imag * b.imag, real * b.imag + imag * b.real);
    }
};

单位根

快速傅里叶变换的核心就是利用的单位根的一些独特的性质来快速实现的

单位根的定义:方程 zn=1 在复数范围内的 n 个根。

那么,不经过证明的给出,每一个根应该为 ei2kπn

这里我们记 ωn 为主 n 次单位根, ωnk=ei2kπn

举个例子,主 8 次单位根的 8 个值改写为形如 (r,θ) 的极坐标后,位置类似于下图:

三个引理

  • 消去定理:ωdndk=ωnk

证明:考虑展开即可:ωdndk=ei2dkπdn=ei2kπn=wnk

  • 折半引理:(ωnk+n2)2=(wnk)2=ωn2k

这个引理是快速傅里叶变化的核心

证明:也是考虑展开

ωnk+n2=ωnkωnn2=ωnk(ωnk)2=ωn2k=ωn2k

  • 求和引理:i=0n1(ωnk)i=0

证明

根据等比数列公式

i=0n1(ωnk)i=(wnk)n1wnk1=(wnn)k1wnk1=1k1wnk1=0

得证


多项式

(OI中)一般形式F(x)=a0+a1x+a2x2++anxn

上述多项式为一元多项式。

我们可以改写上式:i=0naixi

我们对于多项式运算定义如下:

A(x)=i=0naixiB(x)=i=0nbixi

  • 加法:

A(x)+B(x)=i=0n(ai+bi)xi

  • 乘法

一般情况下,我们可以通过补零的方式,将两个次数不同的多项式调整到次数相同。这里我们都补充到 n 的长度

ci=j=0iajbijA(x)B(x)=i=02ncixi

我们称这个系数向量 c 为向量 a,b 的卷积,记作 ab

表示方法

  • 系数表示

    它将一个多项式表示成由其系数构成的向量的形式

    例如 A=[a0,a1,a2,,an]T

    加法即为 A1+A2,直接相加即可。时间复杂度 O(n)

    乘法则做向量卷积,为 A1A2。一般来说,时间复杂度为 O(n2)

    如果给定 x 求值,则可以使用霍纳法则或者秦九昭算法。时间复杂度为 O(n)

  • 点值表示

    用至少 n 个多项式上的点来表示

    一般形式如 {(x0,A(x0)),(x1,A(x1),,(xn,A(xn))}

    进行运算是,一般要保证两个多项式在同一位置取值相同,即 xi 相同

    加法运算直接将两点坐标相加即可,时间复杂度为 O(n)

    乘法运算只需要将两点坐标相乘即可。时间复杂度为 O(n),太好了!

    如果我们需要 A(x) ,这个过程叫做插值,可以通过拉格朗日插值公式进行计算,复杂度为 O(n2),这里不展开讲述。

离散傅里叶变换(DFT)

DFT(Discrete Fourier Transform) 是快速傅里叶变换(FFT)的基础,也是快速数论变换(NTT)的基础

变换操作是对于一个向量而言(也就是多项式的系数表示法)

这个变换操作相当于求出这个多项式在 x 为单位根时的取值。

不妨设这个向量为 C=[c0,c1,c2,,cn1]T

我们定义一个变换公式

h(x)=i=0n1cixi

那么变换过后的序列为

[ h(ω0),h(ω1),h(ω2),,h(ωn1) ]T

其中 ω 代表主 n 次的单位根。

或者我们可以通过矩阵来表示:

[h(ω0)h(ω1)h(ω2)h(ωn1)]=[111111ωn1ωn2wn3ωnn11ωn2ωn2×2ωn2×3ωn2×(n1)1ωn(n1)×2ωn(n1)×3ωn(n1)×2ωn(n1)×(n1)][c0c1c2cn1]

所以说,暴力的算法为 O(n2),不如直接算。

值得一提的是这个矩阵是范德蒙德方阵,具有非常良好的性质。


对于上述序列,我们称形如 h(ωk) 的项为 k 次离散傅里叶级数。

我们将每一项展开,那么可以得到下图:

图片来自网络


这个时候,我们变换后的序列就是用点值表示的序列。

两个变换后的序列相乘,一一对应的乘法即可。

为什么?

变换后的序列其实就是函数在 w0,w1,,wn1 的点上的取值。

那么根据点值表示的多项式,两个多项式相乘,即是对应点相乘。

于是得到了两个多项式乘积的点值表示。

换句话来说,DFT 实际上所做的事情是插值,同理,IDFT 也相当于插值回去。


离散傅里叶逆变换(IDFT)

我们声称对于上述序列,其 k 次离散傅里叶变换后的的值恰为 n×ck

证明

我们用 g(ωk) 表示变换后的结果

g(ωk)=h(ω0)ω0k+h(ω1)ωk+h(w2)ω2k++h(ωn1)ω(n1)k=i=0n1h(ωi)ωik=i=0n1ωikj=0n1cjωij=i=0n1j=0n1ω(jk)icj=i=0n1cjj=0n1ω(jk)i

我们分类讨论一下:

  • j=k,那么此时 j=0n1ω(jk)y=j=0n1ω0=n。对于 ck 做出的贡献为 n

  • jk,将 k 看作常数,那么此时 j=0n1ω(jk)i=j=0n1ωj。依据上文中求和引理,其值为 0,也就是对 cj 做出的贡献为 0

综上所述,只有 ck 对于答案做出了 n 次贡献,所以 g(ωk)=n×ck

此部分到横线分割前可以适当的略过。

观察我们实际上做的事情:

[c0c1c2cn1]=[111111ωn1ωn2wn3ωnn11ωn2ωn2×2ωn2×3ωn2×(n1)1ωn(n1)×2ωn(n1)×3ωn(n1)×2ωn(n1)×(n1)]1[h(ω0)h(ω1)h(ω2)h(ωn1)]

是对于单位根的这个范德蒙德矩阵求逆后于原向量相乘。

我们可以直接对其求逆得到上述结论。

重新审视方阵,发现乘上一个范德蒙德方阵相当于带进了 n 个点进行求值,那么乘上其逆矩阵就应该是用 n 个点插值。

考虑拉格朗日插值:

f(x)=yijixxjxixj

比较显然的是:

(V1)i,k=[xk]jixxjxixj

F(x)=j(xxj)=j(xωj),考虑代数基本定理,那么知道 F(x)=xn1

现在我们需要 xn1xωj,考虑除法模拟展开:

fj(x)=F(x)xωj=xn1+ωjxn2+ω2jxn3++ω(n1)jx0

于是

(V1)i,j=[xi]fi(x)fi(ωj)=ωj(n1i)nωj(n1)=ωijn

也就得到了其逆矩阵:

1n[111111ωn1ωn2wn3ωn(n1)1ωn2ωn2×2ωn2×3ωn2×(n1)1ωn(n1)×2ωn(n1)×3ωn(n1)×2ωn(n1)×(n1)]

这就是 IDFT 的本质。


于是,我们可以得出两个多项式卷积的计算方法:

  • 将两个多项式改写为向量的形式,并分别对其做一次离散傅里叶变换 (DFT)

  • 将变换过后的两个序列相乘(点对点相乘)得出一个新的序列

  • 我们再对此序列做一次逆傅里叶变换,也就是将序列变为 g(ω0),g(ω1),g(ω2),,g(ω(n1))

    也就是 n×c0,n×c1,,n×cn1

    最后对于每一项除以 n 即可。

但是

最朴素的 DFT 变化的时间复杂度为 O(n2),三次变换还不如直接暴力计算……所以我们就需要快速傅里叶变换来优化了。

快速傅里叶变换(FFT)

FFT Fast-Fast-TLE

我们对于 f(x)=c0+c1x+c2x2++cn1xn1,分离其奇数项和偶数项,构造出另外两个向量

 feven(x)=c0+c2x+c4x2++cn2xn21fodd(x)=c1+c3x+c5x2++cn1xn21

那么,不难发现:

f(x)=feven(x2)+xfodd(x2)

也就是说

f(ωnk)=feven(ωn2k)+ωnkfodd(ωn2k)f(ωnk+n2)=feven(ωn2k+n)+ωnk+n2fodd(ωn2k+n)

补充一个点:

ωnn2=1

证明:考虑我们在上面画出的图中,ωnn2 所在的位置。就是 -1 了

再根据单位根的消去原理稍微化一下……

f(ωnk)=feven(ωn2k)+ωnkfodd(ωn2k)f(ωnk+n2)=feven(ωn2k)ωnkfodd(ωn2k)

于是,我们就可以递归分治了。

其时间为 T(n)=2T(n/2)+O(n)

故复杂度为 T(n)=O(nlogn)


其实我们还要考虑一点,我们要保证长度为 2k 才能保证可以正确的分治

因为只有长度相等的区间才能合并(考虑此时单位根才一样)。

所以说,我们要把两个多项式通过补 0 的方式补齐到 2k 项,合并之后就是 2k+1 项。

对于模板题:【模板】多项式乘法(FFT) - 洛谷

可以写出如下龟速代码:

#include <iostream>
#include <algorithm>
#include <vector>

struct Complex {
    double real, imag;
    Complex() : real(0), imag(0) {}
    Complex(double re, double im) : real(re), imag(im) {};
    inline Complex operator + (const Complex & b) { return Complex(real + b.real, imag + b.imag); }
    inline Complex operator - (const Complex & b) { return Complex(real - b.real, imag - b.imag); }
    inline Complex operator * (const Complex & b) { return Complex(real * b.real - imag * b.imag, real * b.imag + imag * b.real); }
};
typedef std::vector<Complex> Vector;

const double PI = acos(-1);

void FFT(Vector &v, int n, int inv) {
    if (n == 1) return; // 递归边界,只有一个元素,不做变换

    // 奇偶变化为两个向量
    int mid = n >> 1;
    Vector even(mid), odd(mid);
    for (int i(0); i < n; i += 2) {
        even[i >> 1] = v[i], odd[i >> 1] = v[i + 1]; 
    }
    // 递归操作 
    FFT(even, mid, inv), FFT(odd, mid, inv);

    // 进行合并操作
    // 定义基本 omega
    Complex omega(cos(PI * 2 / n), inv * sin(PI * 2 / n));
    // 当前旋转因子
    Complex w(1, 0);
    for (int i(0); i < mid; ++i, w = w * omega) {
        v[i] = even[i] + w * odd[i];
        v[i + mid] = even[i] - w * odd[i];
    }
}

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(0), std::cout.tie(0);

    int n, m;
    std::cin >> n >> m;

    // 获取最终的长度,必须是 2 的次幂,且比两个向量卷起来要长 
    int O(1);
    while (O <= m + n) O <<= 1;
    // std::cout << PI << " " << O << std::endl;

    Vector A(O), B(O);
    for (int i(0); i <= n; ++i) std::cin >> A[i].real; 
    for (int i(0); i <= m; ++i) std::cin >> B[i].real;

    FFT(A, O, 1);
    FFT(B, O, 1);

    // 我们单点相乘,然后进行逆变换,求出每一项的系数
    for (int i(0); i < O; ++i) A[i] = A[i] * B[i];
    FFT(A, O, -1);

    // 最后进行输出
    // 记得两个东西卷起来之后是 n + m 次的
    for (int i(0); i <= n + m; ++i) {
        // 这里是向上取整? 
        std::cout << (long long)(A[i].real / O + 0.5) << ' ';
    } std::cout << std::flush;
    return 0;
}

我不知道为什么网上这么多写递归版本的都是错的。

例如本题第一篇题解,其实并不是方法有问题,是他的写法错了。

链接:题解 P3803 【【模板】多项式乘法(FFT)】 - attack 的博客

例如知乎上的一篇文章,其在递归时的边界是有问题的。虽然说不影响正确性……

链接:快速傅里叶变换 - 星夜

当然,还是有代码正确,但是代码……

链接:FFT-快速傅里叶变换 - heartbeats

有一个不算优化的优化。考虑每一层所需要的空间是固定的。所以考虑预先分配 O(nlogn) 的空间,然后直接使用即可。

蝶形优化

在运行FFT的递归版本时,可以观察到递归树如下:

如果我们把初始向量 a 中的元素按照其在叶子中出现的次序进行安排,就可以对递归过程进行追踪,不过是自底向上,而非自顶向下。

回顾 FFT 合并时的代码:

for (int i(0); i < mid; ++i, w = w * omega) {
    v[i] = even[i] + w * odd[i];
    v[i + mid] = even[i] - w * odd[i];
}

在这个循环中,我们只用到了两个值:even[i], odd[i],得到了两个值:v[i], v[i + mid]。我们利用一种图示表示:

我们称这个操作为蝴蝶操作

那么整个计算FFT的过程可以通过此图表示

这个地方确实比较绕,难以通过语言表述,不过通过上面 4 张图,请读者停下稍微悟一悟。


其实从第一个递归的图中,我们已经可以发现初始值的规律了:

我们通过观察其二进制得出答案:下标的二进制恰好和目标二进制互为倒序

一共只有 log(n) 位!也就是保证所有数拥有一样多的比特位。

也就是说当 n=8rev[3] = rev[0b011] = 0b110 = 6, log8=4

我们考虑可以通过 DP 在 O(n) 内实现。

假设我们已经处理完了 1n1 的所有 rev

考虑 n=(abcd)2,此时我们已经知道了 rev[(0abc)2]=(cba0)2,需要 rev[(abcd)2]=(dcba)2

通过瞪眼法,我们可以轻易的得出

rev[(abcd)2]=(dcba)2=(rev[(0abc)2]>>1) | ((d&1)<<3)

改写为递推式即是:

dp[x] = (dp[x>>1] >> 1) | ((x&1) << (log2(n) - 1))


参考代码

#include <complex>
#include <iostream>
#include <algorithm>
#include <vector>

struct Complex {
    double real, imag;
    Complex() : real(0), imag(0) {}
    Complex(double re, double im) : real(re), imag(im) {};
    inline Complex operator + (const Complex & b) { return Complex(real + b.real, imag + b.imag); }
    inline Complex operator - (const Complex & b) { return Complex(real - b.real, imag - b.imag); }
    inline Complex operator * (const Complex & b) { return Complex(real * b.real - imag * b.imag, real * b.imag + imag * b.real); }
};

typedef std::vector<Complex> Vector;

const double PI = acos(-1);
int O(1), logO(0);

void FFT(std::vector<int> &rev, Vector &v, int inv) {
    for (int i(0); i < O; ++i) {
        if (i < rev[i]) std::swap(v[i], v[rev[i]]);
    }

    // 第 log(k) 次合并,一共logO次 
    // 合并之后区间的长度为 k
    for (int k(1); k < O; k <<= 1) {
        Complex omega(cos(PI / k), inv * sin(PI / k));
        for (int i(0); i < O; i += (k<<1)) { // 处理行 
            Complex w(1, 0);
            for (int j = 0; j < k; ++j, w = w * omega) {
                Complex s = v[i + j], t = v[i + j + k] * w;
                v[i + j] = s + t, v[i + j + k] = s - t; 
            }
        }
    }

    if (inv == -1) for (int i(0); i < O; ++i) v[i].real /= O;
}


int main() {
    int n, m;
    std::cin >> n >> m;

    while (O <= n + m) O <<= 1, ++logO;

    Vector A(O), B(O);
    for (int i(0); i <= n; ++i) std::cin >> A[i].real; 
    for (int i(0); i <= m; ++i) std::cin >> B[i].real;

    std::vector<int> rev(O);
    for (int i(0); i < O; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (logO - 1));

    FFT(rev, A, 1), FFT(rev, B, 1);

    for (int i(0); i < O; ++i) A[i] = A[i] * B[i];
    FFT(rev, A, -1);

    for (int i(0); i <= n + m; ++i) {
        std::cout << (long long)(A[i].real + 0.1) << ' '; // 向上取整 
    } std::cout << std::flush;
    return 0;
}

那么恭喜你,大概是明白了 FFT 了吧!


快速数论变换(NTT)

快速数论变换我们可以类比快速傅里叶变换。只是将复数优化为了取模实数。

在模 p 意义下,会存在 a(p1)1(modp)

于是我们找到 Zp 中的一个原根 g

可以参考:算法学习笔记(11): 原根 - jeefy - 博客园

满足:

x[1,φ(p)),gx1(modp)gφ(p)1(modp)

的数 g 即是 Zp 中的一个原根。

那么如何构造一个满足单位根性质的数

v=φ(p),考虑:

wn=gvn

显然,有 gn1(modp)

后文中为了方便,就不写后面的 (modp) 了。

如果 n 不能整除 v 怎么办?凉拌

这也就是我们为什么把模数设为 998244353=119223+1 或者 167772161=5225+1

那么我们来证明其满足作为单位根需要的三个性质:

  • 消去wdndk=wnk

    证明:考虑展开即可:wdndk=gvdndk=gnk

  • 折半(wnk+n2)2=wn2k+n=wn2k=wn2k

  • 求和i=0n1(ωnk)i=0

    证明:也考虑等比数列:

    i=0n1(ωnk)i=(wnk)n1wnk1=(wnn)k1wnk1=1k1wnk1=0

    这过程怎么一模一样……

那么我们可以将单位根替换,进行变换了!


说实话,代码几乎是一模一样的……

const int MOD = 998244353, g = 3, ig = 332748118;
int O(1), logO(0);

typedef vector<long long> vec;
typedef vector<int> ivec;

void NTT(ivec &rev, vec &v, int inv) {
    for (int i(0); i < O; ++i) {
        if (rev[i] < i) std::swap(v[i], v[rev[i]]);
    }

    // 第 log(k) 次合并
    for (int k(1); k < O; k <<= 1) {
        long long omega = qpow((long long)(~inv ? g : ig), (MOD - 1) / (k << 1), MOD);
        for (int i(0); i < O; i += (k << 1)) {
            long long w(1);
            for (int j(0); j < k; ++j, w = w * omega % MOD) {
                long long s = v[i + j], t = w * v[i + j + k];
                v[i + j] = (s + t) % MOD, v[i + j + k] = ((s - t) % MOD + MOD) % MOD;
            }
        }
    }

    if (inv == -1) {
        for (int iv(qpow((long long)O, MOD - 2, MOD)), i(0); i < O; ++i) {
            v[i] = v[i] * iv % MOD;
        }
    }
}

而且由于没有了复数,常数会小很多(比FFT快)


FFT 的其他

  • 循环卷积:这里卷积的结果实际上是 ck=i+jk(modn)aibj

  • 任意模数 NTT:模 p 意义下结果值域大概是 O(n2p) 级别的,所以需要三个 >n2p 的模数以及 CRT 完成。

  • 拆系数FFT:也就是把值放在 M 进制下运算,那么值域就减到 O(np) 级别了,非常优秀。

posted @   jeefy  阅读(850)  评论(1编辑  收藏  举报
相关博文:
阅读排行:
· DeepSeek 开源周回顾「GitHub 热点速览」
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· .NET 10首个预览版发布:重大改进与新特性概览!
· AI与.NET技术实操系列(二):开始使用ML.NET
· 单线程的Redis速度为什么快?
点击右上角即可分享
微信分享提示