快速傅里叶变换的 numpy 实现
理论
快速傅里叶变换(FFT),即利用计算机计算离散傅里叶变换(DFT)的高效、快速计算方法的统称。而在计算机算法中,也可用来简化多项式的乘法和加法运算。具体方法是通过将多项式转换到变换域中,在变换域中进行乘法和加法,再从变换域转换到多项式。
现有长度为 \(n=2^s\) 序列和相应的多项式为
为方便,取符号 \(\omega_n=e^{-2\pi i/n}\),定义离散傅里叶变换(DFT)为
如果采用每个 \(y_k\) 单独计算的方法,则需要进行 \(n\) 次 \(f_{\pmb a}(x)\) 的计算。如果采用霍纳法则进行计算,每个 \(f_{pmb a}(x)\) 需要进行 \(n-1\) 次加法和 \(n-1\) 次乘法,总体时间复杂度为 \(\Theta(n^2)\)。
如果采用快速傅里叶变换则速度将加快到 \(\Theta(n\lg n)\),这主要是因为利用了 \(\omega_n\) 的周期性。
将序列 \(\pmb a\) 的奇索引和偶索引分离
得到相应的傅里叶变换,注意长度的变化,变为了原来的二分之一。
则
但是因为 \(\pmb y^{[0,1]}\) 的长度只有 \(n/2\),所以只能计算前 \(n/2\) 个,还需要计算后 \(n/2\) 个的方法。
由于 \(\omega_{n/2}^{k+n/2}=\omega_{n/2}^k\),因此 \(f_{\pmb a^{[0,1]}}(\omega_{n/2}^{k+n/2})=f_{\pmb a^{[0,1]}}(\omega_{n/2}^k)\)。且 \(\omega_n^{k+n/2}=-\omega_n^k\)。因此有对于 \(k\in0..(n/2-1)\)
故 \(f_{\pmb a^{[0,1]}}(\omega_{n/2}^{k+n/2})\) 只需要计算 \(n/2\) 个,正好和序列 \(\pmb a^{[0,1]}\) 长度相同,也就是计算 \(\pmb y^{[0,1]}\) 即可。因此计算 \(f_{\pmb a^{[0,1]}}\) 的方法和 \(f_{\pmb a}\) 没什么不同,可以采用递归的方法。
这个递归函数为
其中乘法为按位乘法。
注意到在前面的证明中只使用了 \(\omega_n\) 的周期性,显然这对于 \(\omega_n^{-1}\) 也是成立的。
而离散傅里叶逆变换(IDFT)为
只需要在上述证明和算法中将 \(\omega_n\) 全部替换为 \(\omega_n^{-1}\),再在最后除以 \(n\),则完成了逆运算的证明和算法。
时间复杂度为
实现
递归实现和迭代实现,只适用于 n
为 2 的幂的情况。
from itertools import count
from sympy import S
import numpy as np
def recursive_fft(a: np.ndarray, dtype=np.complex128) -> np.ndarray:
n = a.size
assert n & (n - 1) == 0 # n is power of 2
if a.dtype != dtype:
a = a.astype(dtype)
if n == 1:
return a.copy()
w_n = dtype(S('exp(-2 * pi * I / n)').subs('n', n))
w = w_n ** np.arange(n // 2, dtype=dtype)
a = a.reshape(-1, 2) # 用 reshape 不用新建数组
y0 = recursive_fft(a[:,0].reshape(-1), dtype)
y1 = recursive_fft(a[:,1].reshape(-1), dtype)
y1 *= w
y = np.hstack([
y0 + y1,
y0 - y1
])
return y
def recursive_ifft(y: np.ndarray, dtype=np.complex128) -> np.ndarray:
def _recursive_ifft(y):
n = y.size
if n == 1:
return y.copy()
w_n = dtype(S('exp(2 * pi * I / n)').subs('n', n))
w = w_n ** np.arange(n // 2, dtype=dtype)
y = y.reshape(-1, 2) # 用 reshape 不用新建数组
a0 = _recursive_ifft(y[:,0].reshape(-1))
a1 = _recursive_ifft(y[:,1].reshape(-1))
a1 *= w
a = np.hstack([
a0 + a1,
a0 - a1
])
return a
n = y.size
assert n & (n - 1) == 0 # n is power of 2
if y.dtype != dtype:
y = y.astype(dtype)
return _recursive_ifft(y) / n
def iterative_fft(a: np.ndarray, dtype=np.complex128) -> np.ndarray:
n = a.size
assert n & (n - 1) == 0 # n is power of 2
for s in count():
m = 1 << s
if m == n:
break
a = a.reshape(m, -1, 2).swapaxes(1, 2)
if a.dtype != dtype:
y = a.ravel().astype(dtype)
else:
y = a.flatten()
for s in count():
m = 1 << s
if m == n:
break
w_m = dtype(S('exp(-pi * I / m)').subs('m', m))
w = w_m ** np.arange(m, dtype=dtype)
y = y.reshape(-1, m << 1)
y0 = y[:,:m].copy()
y1 = y[:,m:] * w
y[:,:m] = y0 + y1
y[:,m:] = y0 - y1
return y
def iterative_ifft(y: np.ndarray, dtype=np.complex128) -> np.ndarray:
n = y.size
assert n & (n - 1) == 0 # n is power of 2
for s in count():
m = 1 << s
if m == n:
break
y = y.reshape(m, -1, 2).swapaxes(1, 2)
if y.dtype != dtype:
a = y.ravel().astype(dtype)
else:
a = y.flatten()
for s in count():
m = 1 << s
if m == n:
break
w_m = dtype(S('exp(pi * I / m)').subs('m', m))
w = w_m ** np.arange(m, dtype=dtype)
a = a.reshape(-1, m << 1)
a0 = a[:,:m].copy()
a1 = a[:,m:] * w
a[:,:m] = a0 + a1
a[:,m:] = a0 - a1
return a / n
在代码中使用了 sympy
的表达式,因此如果 dtype=object
那么可以获得结果的精确表达式,但不一定是最简式。
numpy
实现的要点
reshape()
和 swapaxes()
关于 a.reshape(m, -1, 2).swapaxes(1, 2)
的作用,参考递归实现,每次递归都分离了偶索引和奇索引。
关键在于递归和奇偶分离:
m
和reshape(m)
的作用在于指明递归次数,同时将每次递归分离到不同的坐标上,使每次递归之间(的坐标转换)互不影响。reshape(-1, 2).swapaxes(1, 2)
的作用在于奇偶分离,如果认为数据按先列后行的顺序排列,那么就是首先按奇偶分为两列,在旋转坐标轴将奇偶分为两行。
类型转换、一维化和坐标转换
为了提高效率,需要注意一下类型转换和坐标转换顺序。由于经过多次 reshape()
和 swapaxis()
,原有数组的内在坐标被转换了。而后续操作基本上是线性的、连续的操作,因此在转换完坐标后再进行类型转换和一维化 (注意 flatten()
和 ravel()
的使用),可以创建新的线性数组,方便后续操作,且遵循了操作不改变原数组的准则,时间复杂度也会减小。
slower
版本先进行类型转换且没有进行一维化,再进行坐标转换。faster
版本先进行坐标转换,再进行类型转换和一维化。
# slower
if y.dtype != dtype:
a = y.astype(dtype)
else:
a = y.copy()
for s in count():
m = 1 << s
if m == n:
break
a = a.reshape(m, -1, 2).swapaxes(1, 2)
# faster
for s in count():
m = 1 << s
if m == n:
break
y = y.reshape(m, -1, 2).swapaxes(1, 2)
if y.dtype != dtype:
a = y.ravel().astype(dtype)
else:
a = y.flatten()
数组的复制
还需要注意数组的复制。一般来说,使用 numpy
库复制数组之后时间会加快(在没有多余操作的前提下),具体看下列代码。
slower
版本的关键在于使用原数组记录相关变量,在这个过程中,没有使用新的变量,a0, a1
只是a
的视图而已,是a
的一部分。faster
版本使用copy()
和乘法,一个直接、一个间接地建立了新数组,因此速度加快了。
#slower
a = a.reshape(-1, m << 1)
a0 = a[:,:m]
a1 = a[:,m:]
a1 *= w
a0[:] += a1
a1 *= -2
a1[:] += a0
# faster
a = a.reshape(-1, m << 1)
a0 = a[:,:m].copy()
a1 = a[:,m:] * w
a[:,:m] = a0 + a1
a[:,m:] = a0 - a1
在递归的实现中,因为 hstack()
不改变原数组,所以传入参数时使用 reshape()
,而不是新建数组。在这里,选择了减小空间开销而不是减小时间开销。