FFT学习笔记

快速傅里叶变换

多项式

定义

不严谨定义:形如 \(f(x) = \sum \limits _{i=0}^{n} a_ix^i\) 的式子为多项式。

定义(from OIWiki):对于求和式 \(\sum a_nx^n\),如果是有限项相加,称为多项式,记作 \(f(x)=\sum \limits_{n=0}^m a_nx^n\)

次数:对于多项式 \(F(x) = \sum \limits _{i=0}^{n} a_ix^i\),该多项式的次数为 \(n\)。表示为 \(\text{degree} (F) = n\)

次数界:任何一个大于多项式次数的整数。

基本运算

系数表示法:

\[\mathbf{a} = [a_1, a_2,\cdots, a_n] ^\top \]

加法

\[\mathbf{c} = \mathbf{a} + \mathbf{b} \]

\[c_i = \sum \limits_{i = 0}^{n} (a_i + b_i) \]

时间复杂度 \(O(n)\)

乘法

\[A(x) = \sum \limits_{i = 0}^n a_ix^i \ \ B(x) = \sum \limits_{i = 0}^n b_ix^i \]

\[\text{degree} (C) = \text{degree} (A) + \text{degree} (B) \]

\[c_i = \sum \limits_{j = 0}^i a_j b_{i - j} \]

记为 \(\mathbf{c} = \mathbf{a} \otimes \mathbf{b}\)

时间复杂度 \(O(n^2)\)

求值

秦九韶算法:

\[\begin{aligned} f(x) ={}& a_nx^n+a_{n-1}x^{n-1}+\cdots a_1x+a_0 \\ ={}& (a_nx^{n-1}+a_{n-1}x^{n-2}+\cdots a_1)x+a_0 \\ ={}& ((\cdots(a_nx+a_{n-1})x+a_{n-2})x+\cdots))x+a_0 \end{aligned} \]

时间复杂度 \(O(n)\)

点值表示法:

\[\{(x_0, A(x_0)), (x_1, A(x_1)), \cdots(x_n, A(x_n))\} \]

加法

\[C(x_i) = A(x_i) + B(x_i) \]

时间复杂度 \(O(n)\)

乘法

\[C(x_i) = A(x_i)B(x_i) \]

时间复杂度 \(O(n)\)

求值

拉格朗日插值法:

\[A(x)=\sum \limits_{i=1}^n A(x_i) \frac{\prod \limits _{j\ne i}^{n}(x- x_j)}{\prod \limits_{j \ne i}^{n}(x_i-x_j) } \]

时间复杂度 \(O(n^2)\)

代码实现:

double res = 0;
for(int i = 1; i <= n; i ++ )
{
    double s1 = 1, s2 = 1;
    for(int j = 1; j <= n; j ++ )
        if(j != i)
        {
            s1 = s1 * (k - x[j]);
            s2 = s2 * (x[i] - x[j]);
        }

    res = res + y[i] * s1 / s2;
}

复数

基本运算

高中数学基础,不多介绍

\[(a+bi) + (c+di) = (a+c) + (b + d)i \]

\[(a + bi) - (c + di) = (a - c) + (b - d)i \]

\[(a + bi)(c + di) = (ac - bd) + (ad + bc)i \]

\[\frac{a + bi}{c + di} = \frac{ac+bd}{c^2 + d^2} + \frac{bc-ad}{c^2+d^2}i \]

欧拉定理

欧拉定理:\(e^{i\theta} = \cos \theta + i\sin \theta\)

根据这个定理,我们可以把复数转化成在复平面上的一堆点来进行一些公式的推导。

算术基本定理

\(n\) 次代数方程 \(f(x)=a_nx^n+a_{n-1}x^{n-1}+\cdots+a_1x+a_0 = 0(a_n \ne 0 且 n > 0)\) 在复数域上恰好有 \(n\) 个根。

根据这个定理,我们可以找出一个特殊的式子:\(x^n=1\) 的根,在复数域中恰好有 \(n\) 个,设 \(\omega _n = e^{\frac{2\pi}{n}i}\),则 \(x^n = 1\) 的解集可以记作 \(\{ \omega _n^k|k=0,1,2,\cdots,n-1 \}\)

单位根

单位根具有一些性质,在 FFT 的时候会有很大用处。

折半引理

\[\omega _{2n}^{k+n} = -\omega _{2n}^{k} \]

这个结论在复平面中很好证明,当然也可以用公式来证明:

\[\begin{aligned} \omega _{2n}^{k+n} &= e^{i\frac{2\pi}{2n}(k+n)} \\ &=e^{i\frac{2\pi}{2n}k+i\pi} \\ &= e^{i\frac{2\pi}{2n}k}\times(-1) \\ &= -\omega _{2n}^k \end{aligned} \]

消去引理

\[\omega _{dn}^{dk} = \omega _{n}^{k} \]

证明:

\[\begin{aligned} \omega _{dn}^{dk} &= e^{i\frac{2\pi}{dn}dk } \\ &=e^{i\frac{2\pi}{n} k} \\ &= \omega ^{k}_{n} \end{aligned} \]

求和引理

\[\sum \limits_{i=0}^{n-1} \omega_n^i= 0 \]

证明:

\[\begin{aligned} \sum \limits_{i=0}^{n-1} \omega_n^i &= \frac{(\omega_n^k)^n -1}{\omega _n^k-1} \\ &=\frac{(\omega _n^n)^k-1}{\omega _n^k-1} \\ &= 0 \end{aligned} \]

FFT

在计算多项式乘法时,系数表示法的复杂度为 \(O(n^2)\),而点值表示法的复杂度仅为 \(O(n)\),我们可以尝试将系数表示法转化为点值表示法来进行运算。

我们可以对多项式进行如下变换,将奇数项和偶数项分开:

(前提:\(n\) 为奇数且 \(n+1\)\(2\) 的幂)

\[\begin{aligned} A(x) &= a_nx^n+a_{n-1}x^{n-1}+\cdots+a_1x+a_0\\ &= a_nx^n+a_{n-2}x^{n-2}+\cdots+a_1x+a_{n-1}x^{n-1}+a_{n-3}x^{n-3}+\cdots+a_0\\ \end{aligned} \]

将奇数项的系数构成的多项式记为 \(A^{[1]}\),偶数项记为 \(A^{[0]}\),就可以得到下面的式子。

\[\begin{aligned} A(x)&= a_nx^n+a_{n-2}x^{n-2}+\cdots+a_1x+a_{n-1}x^{n-1}+a_{n-3}x^{n-3}+\cdots+a_0\\ &= A^{[0]}(x^2)+xA^{[1]}(x^2) \end{aligned} \]

这时候我们想到,可以对两边的多项式分治处理,但是实数中没有可以把这个平方消掉的数,但是我们可以从复数中找到。

\(\omega _{2n}^k\) 带入 \(A\) 中,就可以得到:

\[\begin{aligned} A(\omega _{2n}^k)&= A^{[0]}(\omega _{2n}^{2k})+\omega _{2n}^kA^{[1]}(\omega _{2n}^{2k}) \\ &= A^{[0]}(\omega _n^k)+\omega _{2n}^kA^{[1]}(\omega _{n}^{k}) \end{aligned} \]

我们就可以对 \(A^{[0]}\)\(A^{[1]}\) 进行递归处理。

同时,在复平面中,\(\omega _{2n}^{n+k}\)\(\omega _{2n}^{k}\) 在多项式 \(A^{[0]}(x^2)\)\(A^{[1]}(x^2)\) 的值相同,因此我们也可以算出该点的值。

\[\begin{aligned} A(\omega _{2n}^{n+k})&= A^{[0]}(\omega _n^k)+\omega _{2n}^{n+k}A^{[1]}(\omega _{n}^{k})\\ &=A^{[0]}(\omega _n^k)-\omega _{2n}^{k}A^{[1]}(\omega ^k_n) \end{aligned} \]

这就是 FFT 的实现原理。做完 FFT 之后,我们就得到了一个多项式在复数域中 \(2n\) 个点的点值,这时再将两个多项式乘起来,就可以得到乘积的点值表示。

但是如果我们需要的是系数表示,我们则需要做一遍快速傅里叶逆变换,将点值转化为系数。

我们可以列出刚才求点值时的矩阵表示:

\[\begin{bmatrix}y_0 \\ y_1 \\ y_2 \\ y_3 \\ \vdots \\ y_{n-1} \end{bmatrix}=\begin{bmatrix}1 & 1 & 1 & 1 & \cdots & 1 \\1 & \omega_n^1 & \omega_n^2 & \omega_n^3 & \cdots & \omega_n^{n-1} \\1 & \omega_n^2 & \omega_n^4 & \omega_n^6 & \cdots & \omega_n^{2(n-1)} \\1 & \omega_n^3 & \omega_n^6 & \omega_n^9 & \cdots & \omega_n^{3(n-1)} \\\vdots & \vdots & \vdots & \vdots & \ddots & \vdots \\1 & \omega_n^{n-1} & \omega_n^{2(n-1)} & \omega_n^{3(n-1)} & \cdots & \omega_n^{(n-1)^2} \end{bmatrix}\begin{bmatrix} a_0 \\ a_1 \\ a_2 \\ a_3 \\ \vdots \\ a_{n-1} \end{bmatrix} \]

我们现在只需要求出逆矩阵即可。由于这个矩阵元素非常特殊(没记错应该叫范德蒙德矩阵),它的逆矩阵即为每一项的倒数再除以变换的长度。

为了取倒数,我们可以做变换:

\[\begin{aligned} \omega _k^{-1}&= e^{-\frac{2\pi}{k}i } \\ &= \cos (-\frac{2\pi}{k})+\sin(-\frac{2\pi}{k})\\ &=\cos(\frac{2\pi}{k})-\sin(\frac{2\pi}{k}) \end{aligned} \]

我们可以看到与 FFT 时仅差了一个负号,因此我们可以将 FFT 和 IDFT 结合为一题,传参时传入一个 \(opt\),为 \(1\) 时执行 FFT,为 \(-1\) 时执行 IDFT。

代码实现:

#include <bits/stdc++.h>
using namespace std;
const double PI = acos(-1);
const int N = 3e6 + 10;
struct Complex
{
    double a, b;
    Complex() {a = 0, b = 0;}

    Complex(double real, double imag): a(real), b(imag) {}

    Complex operator + (const Complex& x) const
    {
        return Complex(a + x.a, b + x.b);
    }

    Complex operator - (const Complex& x) const
    {
        return Complex(a - x.a, b - x.b);
    }

    Complex operator * (const Complex& x) const
    {
        return Complex(a * x.a - b * x.b, a * x.b + b * x.a);
    }
}F[N], G[N];

void FFT(Complex a[], int lim, int opt)
{
    if(lim == 1) return;
    Complex a1[lim >> 1], a2[lim >> 1];
    for(int i = 0; i < lim; i += 2)
        a1[i >> 1] = a[i], a2[i >> 1] = a[i + 1];
    FFT(a1, lim >> 1, opt);
    FFT(a2, lim >> 1, opt);
    Complex wn = Complex(cos(2.0 * PI / lim), opt * sin(2.0 * PI / lim));
    Complex w = Complex(1, 0);
    for(int k = 0; k < (lim >> 1); k ++ )
    {
        a[k] = a1[k] + w * a2[k];
        a[k + lim / 2] = a1[k] - w * a2[k];
        w = w * wn;
    }
}

int n, m;

int main()
{
    n = read(), m = read();
    for(int i = 0; i <= n; i ++ ) scanf("%lf", &F[i].a);
    for(int i = 0; i <= m; i ++ ) scanf("%lf", &G[i].a);

    int lim = 1;
    while(lim <= n + m) lim <<= 1;
    FFT(F, lim, 1), FFT(G, lim, 1);

    for(int i = 0; i <= lim; i ++ ) F[i] = F[i] * G[i];
    FFT(F, lim, -1);

    for(int i = 0; i <= n + m; i ++ ) cout << (int)(F[i].a / lim + 0.5) << ' ';

    return 0;
}

优化

BitReverse

上面我们使用了递归实现 FFT,我们可以思考一下能否用迭代来实现 FFT。

我们可以观察一下每次递归,都是哪些元素被分到了一组。接下来拿 \(0\sim 7\) 来举例。

\[\begin{aligned}(0,1,2,3,4,5,6,7)\\(0,2,4,6)(1,3,5,7)\\(0,4)(2,6)(1,5)(3,7)\\0,4,2,6,1,5,3,7\end{aligned} \]

乍一看好像毫无规律,但是我们把最后分的一组中二进制表示出来:

\[000,100,010,110,001,101,011,111 \]

(以下所说的数字全为二进制)

我们可以发现,\(000\) 处在了第 \(000\) 号位,\(100\) 处在了第 \(001\) 号位,\(010\) 处在了第 \(010\) 号位,\(110\) 处在了第 \(011\) 号位......

结论已经显然易见了,把二进制位翻转一下即为 FFT 时的排列顺序,因此我们可以通过这个操作来增加效率,我们把这个操作称为位逆序置换。

不难发现其中有 \(4\) 条规律:

  1. 奇数位二进制首位为 \(1\),偶数位二进制首位为 \(0\)
  2. 每两个二进制数除了首位都相同。
  3. 每两个数除首位外是上一级子问题的解
  4. 前一半数除末位外是上一级子问题的解

因此我们可以写出下面的递推代码:rev[i] = (rev[i >> 1] >> 1) | (i & 1) << (len - 1)

蝴蝶操作

在代码中,有一段操作:

a[k] = a1[k] + w * a2[k];
a[k + lim / 2] = a1[k] - w * a2[k];

其中都包含了共同的一项 w * a2[k],我们可以将这一项单独存储,减小常数。

同时我们也可以将 a1,a2 数组省去,直接在原数组上进行操作。

代码实现:

#define LOCAL
#include <bits/stdc++.h>
using namespace std;

const double PI = acos(-1);
int lim, len;
int rev[N];

void FFT(Complex a[], int opt)
{
    for(int i = 0; i < lim; i ++ ) 
        if(i < rev[i])
            swap(a[i], a[rev[i]]);
    int up = log2(lim);
    for(int dep = 1; dep <= up; dep ++ )
    {
        int m = 1 << dep;
        Complex wn = Complex(cos(2.0 * PI / m), opt * sin(2.0 * PI / m));
        for(int k = 0; k < lim; k += m)
        {
            Complex w = Complex(1, 0);
            for(int j = 0; j < m / 2; j ++ )
            {
                Complex t = w * a[k + j + m / 2];
                Complex u = a[k + j];
                a[k + j] = u + t;
                a[k + j + m / 2] = u - t;
                w = w * wn;
            }
        }
    }
    if(opt == -1)
    {
        for(int i = 0; i < lim; i ++ ) a[i].a /= lim;
    }
}

int n, m;

int main()
{
    n = read(), m = read();
    for(int i = 0; i <= n; i ++ ) scanf("%lf", &F[i].a);
    for(int i = 0; i <= m; i ++ ) scanf("%lf", &G[i].a);

    lim = 1;
    while(lim <= n + m) lim <<= 1, len ++;
    for(int i = 0; i < lim; i ++ )  
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (len - 1));

    FFT(F, 1), FFT(G, 1);

    for(int i = 0; i <= lim; i ++ ) F[i] = F[i] * G[i];
    FFT(F, -1);

    for(int i = 0; i <= n + m; i ++ ) cout << (int)(F[i].a + 0.5) << ' ';

    return 0;
}
posted @ 2023-07-10 18:33  crimson000  阅读(13)  评论(0编辑  收藏  举报