快速傅里叶变换(FFT)

快速傅里叶变换(FFT)

前言

本文为个人学习笔记,大量参考了 oi-wiki 以及其他博客的内容。

问题

记:

f(x)=c0+c1x+c2x2++cnxng(x)=d0+d1x+d2x2++dmxmh(x)=f(x)×g(x)

O(nlogn) 内解决两个多项式乘法后的系数(即给定 f(x)g(x) 的系数,要你求出 h(x) 的系数)。

分析

暴力显然是 O(n2) 的,优化的想法是先考虑点值表示,再考虑从点值表示转换为系数表示。

具体如下:

点值表示的意思是,你需要求出(ωnk 是什么先忽略,当作是 n 个已知量即可):

f(ωn0),f(ωn1),f(ωnn1)g(ωn0),g(ωn1),g(ωnn1)

那么:

h(ωnk)=f(ωnk)×g(ωnk)

实际上,n 个点的点值表示法也确定了一个 n1 次的多项式,因此,一定存在某个算法能将点值表示法转化为系数表示(这个后面再说)。

至此,FFT 的核心思想已经说清楚了,就是考虑求出 f,g 的点值表示,那么 h 的点值表示就可以在 O(n) 的复杂度内求出,而后再考虑从点值表示转化为系数表示。

问题一:求出某个多项式的点值表示(离散傅里叶变换 DFT)

实际上这个问题的真实含义是:怎么选取 ωnk 这个已知量才能使得在一个优秀的复杂度内求出多项式的点值表示。

ωnk 表示将复数坐标系的单位圆平均分成 n 份,从 x 轴逆时针出发的第 k 条分界箭头的复数表示。

选取这个 ωnk 的原因是它有某些性质,能 " 在一个优秀的复杂度内求出多项式的点值表示 "。

性质:

1) ωnk=ωn2k2       2) ωnk+n2=ωnk

然后开始推式子:

f1(x)=c0+c2x++cn2xn22f2(x)=c1+c3x++cn1xn22

显然有

f(x)=f1(x2)+xf2(x2)

ωnk(0k<n2) 代入有:

f(ωnk)=f1(ωn2k)+ωnkf2(ωn2k)=f1(ωn2k)+ωnkf2(ωn2k)

ωnk+n2(0k<n2) 代入有:

f(ωnk+n2)=f1(ωn2k+n)+ωnk+n2f2(ωn2k+n)=f1(ωn2k)ωnkf2(ωn2k)

递归求解即可,有 logn 层,时间复杂度为 O(nlogn),为了方便处理,一般把 n 处理为 (n+m) 的二次幂,多出来的部分系数补为 0 即可。

具体实现中有以下注意事项:

1、虚数可以使用 C++ STL 库中的 complex 类型;

代码

#include <bits/stdc++.h>
template < typename T >
inline void read(T &cnt) {
    cnt = 0; char ch = getchar(); bool op = 1;
    for (; ! isdigit(ch); ch = getchar())
        if (ch == '-') op = 0;
    for (; isdigit(ch); ch = getchar())
        cnt = cnt * 10 + ch - 48;
    cnt = op ? cnt : - cnt;
}

const int N = (1 << 22) + 5;
const double PI = acos(-1);

inline void FFT(std::complex < double > *A, int n) {
    if (n == 1) return;
    int m = (n >> 1);
    std::complex < double > A0[m], A1[m];
    for (int i = 0; i < m; ++ i) {
        A0[i] = A[i * 2];
        A1[i] = A[i * 2 + 1];
    }
    FFT(A0, m); FFT(A1, m); // 递归处理
    auto W = std::complex < double > (cos(2.0 * PI / n), sin(2.0 * PI / n)),
         w = std::complex < double > (1.0, 0.0); // 从 w_n^0 出发
    for (int i = 0; i < m; ++ i) { // 根据式子计算 A 即可
        A[i] = A0[i] + w * A1[i];
        A[i + m] = A0[i] - w * A1[i];
        w *= W; // 等价于 w_n^k -> w_n^{k + 1}
    }
}

int n, m;
std::complex < double > F[N], G[N];

int main() {
    read(n), read(m);
    for (int i = 0; i <= n; ++ i) {
        int x; read(x);
        F[i] = x;
    }
    for (int i = 0; i <= m; ++ i) {
        int x; read(x);
        G[i] = x;
    }
    int sum = 1;
    while (sum <= n + m) sum *= 2; // 补齐成二次幂
    FFT(F, sum);
    FFT(G, sum);
    for (int i = 0; i < sum; ++ i)
        F[i] *= G[i];
    return 0;   
}

优化

递归实在太慢了!

8 项多项式为例,模拟拆分的过程:

  • 初始序列为 {x0,x1,x2,x3,x4,x5,x6,x7}
  • 一次二分之后 {x0,x2,x4,x6},{x1,x3,x5,x7}
  • 两次二分之后 {x0,x4}{x2,x6},{x1,x5},{x3,x7}
  • 三次二分之后 {x0}{x4}{x2}{x6}{x1}{x5}{x3}{x7}

规律:其实就是原来的那个序列,每个数用二进制表示,然后把二进制翻转对称一下,就是最终那个位置的下标。比如 x1 是 001,翻转是 100,也就是 4,而且最后那个位置确实是 4。我们称这个变换为位逆序置换(bit-reversal permutation),证明留给读者自证。

实际上,位逆序置换可以 O(n) 从小到大递推实现,设 len=2k,其中 k 表示二进制数的长度,设 R(x) 表示长度为 k 的二进制数 x 翻转后的数(高位补 0)。我们要求的是 R(0),R(1),,R(n1)

首先 R(0)=0

我们从小到大求 R(x)。因此在求 R(x) 时,R(x2) 的值是已知的。因此我们把 x 右移一位(除以 2),然后翻转,再右移一位,就得到了 x 除了(二进制)个位之外其它位的翻转结果。

考虑个位的翻转结果:如果个位是 0,翻转之后最高位就是 0。如果个位是 1,则翻转后最高位是 1,因此还要加上 len2=2k1。综上

R(x)=R(x2)2+(xmod2)×len2

举个例子:设 k=5len=(100000)2。为了翻转 (11001)2

  1. 考虑 (1100)2,我们知道 R((1100)2)=R((01100)2)=(00110)2,再右移一位就得到了 (00011)2
  2. 考虑个位,如果是 1,它就要翻转到数的最高位,即翻转数加上 (10000)2=2k1,如果是 0 则不用更改。

蝶形运算优化

已知 f1(ωn/2k)f2(ωn/2k) 后,需要使用下面两个式子求出 f(ωnk)f(ωnk+n/2)

f(ωnk)=f1(ωn/2k)+ωnk×f2(ωn/2k)f(ωnk+n/2)=f1(ωn/2k)ωnk×f2(ωn/2k)

使用位逆序置换后,对于给定的 n,k

  • f1(ωn/2k) 的值存储在数组下标为 k 的位置,f2(ωn/2k) 的值存储在数组下标为 k+n2 的位置。
  • f(ωnk) 的值将存储在数组下标为 k 的位置,f(ωnk+n/2) 的值将存储在数组下标为 k+n2 的位置。

因此可以直接在数组下标为 kk+n2 的位置进行覆写,而不用开额外的数组保存值。此方法即称为 蝶形运算,或更准确的,基 - 2 蝶形运算。

再详细说明一下如何借助蝶形运算完成所有段长度为 n2 的合并操作:

1、令段长度为 s=n2
2、同时枚举序列 {f1(ωn/2k)} 的左端点 lg=0,2s,4s,,N2s 和序列 {f2(ωn/2k)} 的左端点 lh=s,3s,5s,,Ns
3、合并两个段时,枚举 k=0,1,2,,s1,此时 f1(ωn/2k) 存储在数组下标为 lg+k 的位置,f2(ωn/2k) 存储在数组下标为 lh+k 的位置;
4、使用蝶形运算求出 f(ωnk)f(ωnk+n/2),然后直接在原位置覆写。

代码

#include <bits/stdc++.h>
template < typename T >
inline void read(T &cnt) {
    cnt = 0; char ch = getchar(); bool op = 1;
    for (; ! isdigit(ch); ch = getchar())
        if (ch == '-') op = 0;
    for (; isdigit(ch); ch = getchar())
        cnt = cnt * 10 + ch - 48;
    cnt = op ? cnt : - cnt;
}

const int N = (1 << 22) + 5;
const double PI = acos(-1);

int rev[N];

inline void change(std::complex < double > *A, int n) {
    for (int i = 0; i < n; ++ i) { // 求 R 数组
        rev[i] = rev[i >> 1] >> 1;
        if (i & 1) {
            rev[i] |= (n >> 1);
        }
    }

    for (int i = 0; i < n; ++ i) // 将原序列 变为 底层对应的序列
        if (i < rev[i]) std::swap(A[i], A[rev[i]]);
} 
inline void FFT(std::complex < double > *A, int n) {
    change(A, n);
    for (int m = 2; m <= n; m *= 2) { // m 是当前处理的每段长度
        auto W = std::complex < double > 
          (cos(2.0 * PI / m), sin(2.0 * PI / m));  
        for (int x = 0; x < n; x += m) { // x 是每段的开头
            auto w = std::complex < double > (1.0, 0.0);
            for (int i = x; i < x + m / 2; ++ i) { // 求出每段的点值表示 根据公式求即可
                auto A0 = A[i], A1 = A[i + m / 2];
                A[i] = A0 + w * A1;
                A[i + m / 2] = A0 - w * A1;
                w *= W;
            }
        }
    }
}


int n, m;
std::complex < double > F[N], G[N];

int main() {
    change(F, 8);
    read(n), read(m);
    for (int i = 0; i <= n; ++ i) {
        int x; read(x);
        F[i] = x;
    }
    for (int i = 0; i <= m; ++ i) {
        int x; read(x);
        G[i] = x;
    }
    int sum = 1;
    while (sum <= n + m) sum *= 2; // 补齐成二次幂
    FFT(F, sum);
    FFT(G, sum);
    for (int i = 0; i < sum; ++ i)
        F[i] *= G[i];
    FFT(F, sum);
    return 0;   
}

问题二:将点值表示转化为系数表示(傅里叶反变换 IDFT)

点值表示的矩阵形式为:

[f(ωn0)f(ωn1)f(ωn2)f(ωn3)f(ωnn1)]=[111111ωn1ωn2ωn3ωnn11ωn2ωn4ωn6ωn2(n1)1ωn3ωn6ωn9ωn3(n1)1ωnn1ωn2(n1)ωn3(n1)ωn(n1)2][a0a1a2a3an1]

怎么求系数 a 呢?根据线性代数的知识:

Ax=bx=A1b

如果能求出 A1,那么 A1b 也是两个多项式相乘的结果,FFT 即可。

唯一的问题变为怎么求解 A1

根据矩阵的逆的定义,有

A1A=E

V 为原矩阵,G 为逆矩阵,考虑最终落在 E(i,j) 的值:

E(i,j)=k=0n1G(i,k)V(k,j)=k=0n1G(i,k)ωnkj=[i==j]

引理

k 不是 n 的倍数时,

i=0n1ωnki=0

证明如下:

i=0n1ωnki=ωnkn11ωnk=111ωnk=0

G(i,k)=ωnik,则:

k=0n1G(i,k)ωnkj=k=0n1ωnikωnkj=k=0n1ωnk(ji)

ji 不为 n 的倍数(0)时,上式为 0;

反之,有:

k=0n1ωnk(ji)=k=0n1ωn0=n

再前面补个系数 1n 即可,故:

G(i,k)=1nωnik

int main() {
    change(F, 8);
    read(n), read(m);
    for (int i = 0; i <= n; ++ i) {
        int x; read(x);
        F[i] = x;
    }
    for (int i = 0; i <= m; ++ i) {
        int x; read(x);
        G[i] = x;
    }
    int sum = 1;
    while (sum <= n + m) sum *= 2; // 补齐成二次幂
    FFT(F, sum);
    FFT(G, sum);
    for (int i = 0; i < sum; ++ i)
        F[i] *= G[i];
    FFT(F, sum);
    std::reverse(F + 1, F + sum); // 从第一位开始翻转
								  // 翻转后变为 0 1-n, 2-n, ..., -1 
    							  // 实际上等价于 0, 1, 2, ..., n-1
    for (int i = 0; i <= n + m; ++ i) { // 四舍五入
        std::cout << (int)(F[i].real() / sum + 0.5) << ' ';
    }
    return 0;   
}
posted @   chzhc  阅读(83)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· 开源Multi-agent AI智能体框架aevatar.ai,欢迎大家贡献代码
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
levels of contents
点击右上角即可分享
微信分享提示