快速傅里叶变换的 numpy 实现

理论

快速傅里叶变换(FFT),即利用计算机计算离散傅里叶变换(DFT)的高效、快速计算方法的统称。而在计算机算法中,也可用来简化多项式的乘法和加法运算。具体方法是通过将多项式转换到变换域中,在变换域中进行乘法和加法,再从变换域转换到多项式。

现有长度为 \(n=2^s\) 序列和相应的多项式为

\[\begin{align*} \pmb a&=(a_0,a_1,\ldots,a_{n-1})\\ f_{\pmb a}(x)&=\sum_{j=0}^{n-1}a_jx^j \end{align*} \]

为方便,取符号 \(\omega_n=e^{-2\pi i/n}\),定义离散傅里叶变换(DFT)为

\[\begin{align*} \pmb y&=(y_0,y_1,\ldots,y_{n-1})\\ y_k&=\left.f_{\pmb a}(x)\right\vert_{x=\omega_n^k}=\sum_{j=0}^{n-1}a_j\omega_n^{kj}\\ \end{align*} \]

如果采用每个 \(y_k\) 单独计算的方法,则需要进行 \(n\)\(f_{\pmb a}(x)\) 的计算。如果采用霍纳法则进行计算,每个 \(f_{pmb a}(x)\) 需要进行 \(n-1\) 次加法和 \(n-1\) 次乘法,总体时间复杂度为 \(\Theta(n^2)\)


如果采用快速傅里叶变换则速度将加快到 \(\Theta(n\lg n)\),这主要是因为利用了 \(\omega_n\) 的周期性。

\[\begin{align*} \omega_{dn}^{dk}&=\omega^k_n\\ \omega_n^{k+n}&=\omega_n^k \end{align*} \]

将序列 \(\pmb a\) 的奇索引和偶索引分离

\[\begin{align*} \pmb a^{[0]}&=(a_0,a_2,\ldots,a_{n-2})\\ \pmb a^{[1]}&=(a_1,a_3,\ldots,a_{n-1})\\ \end{align*} \]

得到相应的傅里叶变换,注意长度的变化,变为了原来的二分之一。

\[\begin{align*} f_{\pmb a^{[0]}}(x)&=\sum_{j=0}^{n/2-1}a_{2j}x^j\\ f_{\pmb a^{[1]}}(x)&=\sum_{j=0}^{n/2-1}a_{2j+1}x^j\\[1em] \pmb y^{[0]}&=(y_0,y_2,\ldots,y_{n-2})\\ \pmb y^{[1]}&=(y_1,y_3,\ldots,y_{n-1})\\ y^{[0]}_k&=f_{\pmb a^{[0]}}(\omega^k_{n/2})\\ y^{[1]}_k&=f_{\pmb a^{[1]}}(\omega^k_{n/2}) \end{align*} \]

\[\begin{align*} f_{\pmb a}(x)&=f_{\pmb a^{[0]}}(x^2)+xf_{\pmb a^{[1]}}(x^2)\\ y_k&=f_{\pmb a}(\omega_n^k)=f_{\pmb a^{[0]}}(\omega_n^{2k})+\omega_n^kf_{\pmb a^{[1]}}(\omega_n^{2k})\\ &=f_{\pmb a^{[0]}}(\omega_{n/2}^k)+\omega_n^kf_{\pmb a^{[1]}}(\omega_{n/2}^k) \end{align*} \]

但是因为 \(\pmb y^{[0,1]}\) 的长度只有 \(n/2\),所以只能计算前 \(n/2\) 个,还需要计算后 \(n/2\) 个的方法。

由于 \(\omega_{n/2}^{k+n/2}=\omega_{n/2}^k\),因此 \(f_{\pmb a^{[0,1]}}(\omega_{n/2}^{k+n/2})=f_{\pmb a^{[0,1]}}(\omega_{n/2}^k)\)。且 \(\omega_n^{k+n/2}=-\omega_n^k\)。因此有对于 \(k\in0..(n/2-1)\)

\[\begin{align*} y_k&=f_{\pmb a^{[0]}}(\omega_{n/2}^k)+\omega_n^kf_{\pmb a^{[1]}}(\omega_{n/2}^k)\\ &=y^{[0]}_k+\omega_n^ky_k^{[1]}\\ y_{k+n/2}&=f_{\pmb a^{[0]}}(\omega_{n/2}^{k+n/2})+\omega_n^{k+n/2}f_{\pmb a^{[1]}}(\omega_{n/2}^{k+n/2})\\ &=f_{\pmb a^{[0]}}(\omega_{n/2}^k)-\omega_n^kf_{\pmb a^{[1]}}(\omega_{n/2}^k)\\ &=y^{[0]}_k-\omega_n^ky_k^{[1]}\\ \end{align*} \]

\(f_{\pmb a^{[0,1]}}(\omega_{n/2}^{k+n/2})\) 只需要计算 \(n/2\) 个,正好和序列 \(\pmb a^{[0,1]}\) 长度相同,也就是计算 \(\pmb y^{[0,1]}\) 即可。因此计算 \(f_{\pmb a^{[0,1]}}\) 的方法和 \(f_{\pmb a}\) 没什么不同,可以采用递归的方法。

这个递归函数为

\[\begin{align*} \pmb\omega_n&=\left(\omega_n^0,\omega_n^1,\ldots,\omega_n^{n/2-1}\right)\\ \pmb y_{0..(n/2-1)}&=f(\pmb a)=f(\pmb a^{[0]})+\pmb\omega_n\cdot f(\pmb a^{[1]})\\ \pmb y_{(n/2)..(n-1)}&=f(\pmb a)=f(\pmb a^{[0]})-\pmb\omega_n\cdot f(\pmb a^{[1]})\\ \end{align*} \]

其中乘法为按位乘法。


注意到在前面的证明中只使用了 \(\omega_n\) 的周期性,显然这对于 \(\omega_n^{-1}\) 也是成立的。

而离散傅里叶逆变换(IDFT)为

\[\begin{align*} a_j&=\frac1n\sum_{k=0}^{n-1}y_k\omega_n^{-kj}=\frac1n\sum_{k=0}^{n-1}y_k(\omega_n^{-1})^{kj} \end{align*} \]

只需要在上述证明和算法中将 \(\omega_n\) 全部替换为 \(\omega_n^{-1}\),再在最后除以 \(n\),则完成了逆运算的证明和算法。

\[\begin{align*} \pmb\omega_n&=\left(\omega_n^{-0},\omega_n^{-1},\ldots,\omega_n^{-(n/2-1)}\right)\\ n\pmb a_{0..(n/2-1)}&=f(\pmb y)=f(\pmb y^{[0]})+\pmb\omega_n\cdot f(\pmb y^{[1]})\\ n\pmb a_{(n/2)..(n-1)}&=f(\pmb y)=f(\pmb y^{[0]})-\pmb\omega_n\cdot f(\pmb y^{[1]})\\ \end{align*} \]

时间复杂度为

\[\begin{align*} T(n)&=\begin{cases} \Theta(1)&n=1\\ 2T(n/2)+\Theta(n)&n=2^s,s\in\mathbb Z^+ \end{cases}\\ &=\Theta(n\lg n) \end{align*} \]

实现

递归实现和迭代实现,只适用于 n 为 2 的幂的情况。

from itertools import count
from sympy import S
import numpy as np


def recursive_fft(a: np.ndarray, dtype=np.complex128) -> np.ndarray:
    n = a.size
    assert n & (n - 1) == 0  # n is power of 2

    if a.dtype != dtype:
        a = a.astype(dtype)
    if n == 1:
        return a.copy()

    w_n = dtype(S('exp(-2 * pi * I / n)').subs('n', n))
    w = w_n ** np.arange(n // 2, dtype=dtype)

    a = a.reshape(-1, 2)  # 用 reshape 不用新建数组
    y0 = recursive_fft(a[:,0].reshape(-1), dtype)
    y1 = recursive_fft(a[:,1].reshape(-1), dtype)

    y1 *= w
    y = np.hstack([
        y0 + y1,
        y0 - y1
    ])

    return y


def recursive_ifft(y: np.ndarray, dtype=np.complex128) -> np.ndarray:
    def _recursive_ifft(y):
        n = y.size
        if n == 1:
            return y.copy()

        w_n = dtype(S('exp(2 * pi * I / n)').subs('n', n))
        w = w_n ** np.arange(n // 2, dtype=dtype)

        y = y.reshape(-1, 2)  # 用 reshape 不用新建数组
        a0 = _recursive_ifft(y[:,0].reshape(-1))
        a1 = _recursive_ifft(y[:,1].reshape(-1))

        a1 *= w
        a = np.hstack([
            a0 + a1,
            a0 - a1
        ])

        return a

    n = y.size
    assert n & (n - 1) == 0  # n is power of 2
    if y.dtype != dtype:
        y = y.astype(dtype)

    return _recursive_ifft(y) / n


def iterative_fft(a: np.ndarray, dtype=np.complex128) -> np.ndarray:
    n = a.size
    assert n & (n - 1) == 0  # n is power of 2

    for s in count():
        m = 1 << s
        if m == n:
            break
        a = a.reshape(m, -1, 2).swapaxes(1, 2)

    if a.dtype != dtype:
        y = a.ravel().astype(dtype)
    else:
        y = a.flatten()

    for s in count():
        m = 1 << s
        if m == n:
            break
        w_m = dtype(S('exp(-pi * I / m)').subs('m', m))
        w = w_m ** np.arange(m, dtype=dtype)

        y = y.reshape(-1, m << 1)
        y0 = y[:,:m].copy()
        y1 = y[:,m:] * w

        y[:,:m] = y0 + y1
        y[:,m:] = y0 - y1

    return y


def iterative_ifft(y: np.ndarray, dtype=np.complex128) -> np.ndarray:
    n = y.size
    assert n & (n - 1) == 0  # n is power of 2

    for s in count():
        m = 1 << s
        if m == n:
            break
        y = y.reshape(m, -1, 2).swapaxes(1, 2)

    if y.dtype != dtype:
        a = y.ravel().astype(dtype)
    else:
        a = y.flatten()

    for s in count():
        m = 1 << s
        if m == n:
            break
        w_m = dtype(S('exp(pi * I / m)').subs('m', m))
        w = w_m ** np.arange(m, dtype=dtype)

        a = a.reshape(-1, m << 1)
        a0 = a[:,:m].copy()
        a1 = a[:,m:] * w

        a[:,:m] = a0 + a1
        a[:,m:] = a0 - a1

    return a / n

在代码中使用了 sympy 的表达式,因此如果 dtype=object 那么可以获得结果的精确表达式,但不一定是最简式。

numpy 实现的要点

reshape()swapaxes()

关于 a.reshape(m, -1, 2).swapaxes(1, 2) 的作用,参考递归实现,每次递归都分离了偶索引和奇索引。

关键在于递归和奇偶分离:

  • mreshape(m) 的作用在于指明递归次数,同时将每次递归分离到不同的坐标上,使每次递归之间(的坐标转换)互不影响。
  • reshape(-1, 2).swapaxes(1, 2) 的作用在于奇偶分离,如果认为数据按先列后行的顺序排列,那么就是首先按奇偶分为两列,在旋转坐标轴将奇偶分为两行。

类型转换、一维化和坐标转换

为了提高效率,需要注意一下类型转换和坐标转换顺序。由于经过多次 reshape()swapaxis(),原有数组的内在坐标被转换了。而后续操作基本上是线性的、连续的操作,因此在转换完坐标后再进行类型转换和一维化 (注意 flatten()ravel() 的使用),可以创建新的线性数组,方便后续操作,且遵循了操作不改变原数组的准则,时间复杂度也会减小。

  • slower 版本先进行类型转换且没有进行一维化,再进行坐标转换。
  • faster 版本先进行坐标转换,再进行类型转换和一维化。
# slower
if y.dtype != dtype:
    a = y.astype(dtype)
else:
    a = y.copy()

for s in count():
    m = 1 << s
    if m == n:
        break
    a = a.reshape(m, -1, 2).swapaxes(1, 2)

# faster
for s in count():
    m = 1 << s
    if m == n:
        break
    y = y.reshape(m, -1, 2).swapaxes(1, 2)

if y.dtype != dtype:
    a = y.ravel().astype(dtype)
else:
    a = y.flatten()

数组的复制

还需要注意数组的复制。一般来说,使用 numpy 库复制数组之后时间会加快(在没有多余操作的前提下),具体看下列代码。

  • slower 版本的关键在于使用原数组记录相关变量,在这个过程中,没有使用新的变量,a0, a1 只是 a 的视图而已,是 a 的一部分。
  • faster 版本使用 copy() 和乘法,一个直接、一个间接地建立了新数组,因此速度加快了。
#slower
a = a.reshape(-1, m << 1)
a0 = a[:,:m]
a1 = a[:,m:]
a1 *= w

a0[:] += a1
a1 *= -2
a1[:] += a0

# faster
a = a.reshape(-1, m << 1)
a0 = a[:,:m].copy()
a1 = a[:,m:] * w

a[:,:m] = a0 + a1
a[:,m:] = a0 - a1

在递归的实现中,因为 hstack() 不改变原数组,所以传入参数时使用 reshape(),而不是新建数组。在这里,选择了减小空间开销而不是减小时间开销。

posted @ 2022-12-07 00:36  Violeshnv  阅读(38)  评论(0编辑  收藏  举报