快速傅里叶变换的倍增实现
本文作者为 JustinRochester。
快速傅里叶变换的倍增实现
傅里叶变换的顺序
我们观察 第五篇 中的算法流程,可以发现:
对于 FFT 的前半部分,我们一直在奇偶分列,然后递归求解;对于后半部分,我们一直在 \(O(n)\) 合并,然后向上返回。
对于这种分治后,子区间之间、子区间的子区间...之间的运算相互独立的情况,我们除了可以按递归树进行先序遍历(DFS、栈的顺序),也可以按递归树进行层序遍历(BFS、队列的顺序)。
因此,一种处理方法便是:
- 从 \([0,2^k)\) 的区间开始奇偶分列,然后分别遍历第二层的 \([0,2^{k-1}), [2^{k-1},2^k)\) 奇偶分裂,再遍历第三层......依此类推。
- 从 \([0,1),[1,2),[2,3),\cdots,[2^k-1, 2^k)\) 的区间开始 \(O(n)\) 合并,然后倒数第二层......再依此类推。
这种遍历方法的好处在于后续的 \(O(n)\) 合并过程中,每一层的步长是一样的,只需要 \(O(\log n)\) 次修改步长即可。
蝴蝶变换
我们先考虑一下这种遍历方法的第一步,递归的奇偶分列。
可以预见的是,这一步的过程中,实际上我们并没有进行任何计算,我们只是在不停的对各个元素的位置进行重新排列。
当然,重新排列的反复复合也是一个重新排列。我们如果能直接知道各个元素重新排列后的最终位置,我们直接 \(O(n)\) 替换到该位置显然是更优的。
我们考虑如何确定每个元素最终的位置。
在第一次的奇偶分列中,我们将偶数放在前面,奇数放在后面。这种分奇偶的考虑方法和 \(\bmod 2\) 的余数是等价的,也即是二进制的最低位。因此,在第一轮的奇偶分列中,二进制末位为 \(0\) 的放在了前面,二进制末位为 \(1\) 的放在了后面。
在后续的递归中,我们将 \(f_0, f_2, f_4, \cdots\) 组织在一起,重新编号为 \(fl_0, fl_1, fl_2, \cdots\) ;而将 \(f_1, f_3, f_5, \cdots\) 组织在一起,重新编号为 \(fr_0, fr_1, fr_2, \cdots\)
可以注意到,由于我们进行了奇偶分组,编号为 \(2t,2t+1\) 的元素被重新编号为了 \(t\) ;这和直接除以 \(2\) 是等价的,也即是二进制舍去最后一位。
于是我们知道了,每一次奇偶分列就相当于将所有数的二进制末位取出来,然后按照二进制末位为第一关键字、剩余数字为第二关键字进行排序。
一个代码实现方法是:
struct position{
int ini_pos, end_sta, tmp_pos;
};//tmp_pos == ini_pos at the beginning.
bool cmp(position &a, position &b) {
if(a.end_sta != b.end_sta)
return a.end_sta < b.end_sta;
return a.tmp_pos < b.tmp_pos;
}
void even_odd_split(position a[N], int l, int r) {
for(int i=l; i<r; ++i) {
a[i].end_sta = tmp_pos&1;
a[i].tmp_pos >>= 1;
}
sort(a+l, a+r, cmp);
}
void get_pos(position a[N], int N) {
for(int len=N; len>=1; len>>=1)
for(int l=0; l<N; l+=len)
even_odd_split(a, l, l+len);
}
而如果我们每次取出的末位都按顺序留起来,排序的时候先按第一次取出的末位排,再按第二次取出的末位排......依此类推,直到按这一次取出的末位排,最后按剩余的数字排。这样排序的结果是等价的。
即修改为:
void even_odd_split(position a[N], int N, int len) {
for(int l=0; l<N; l+=len)
for(int i=l; i<l+len; ++i) {
int current = a[i].tmp_pos & 1;
a[i].tmp_pos >>= 1;
a[i].end_sta = (a[i].end_sta << 1) | current;
}
sort(a, a+N, cmp);
}
void get_pos(position a[N], int N) {
for(int len=N; len>=1; len>>=1)
even_odd_split(a, N, len);
}
等等,那这不是每次把最后一位取出来,放在一个新数的最后面;这不就是二进制位翻转吗?
确实,我们只需要求出 \(i\) 的二进制位翻转 \(rev_i\) ,将 \(a_i\) 和 \(a_{rev_i}\) 交换就得到了最终位置。
这个优美的变换被称之为“蝴蝶变换”。
\(O(n)\) 求解 \(rev_i\)
如果我们暴力求解 \(rev_i\) ,每个都需要 \(O(\log n)\) 的复杂度,那么总复杂度就又回到了 \(O(n\log n)\) 才能完成交换。
虽然不能说效率没有提升,但也是相当有限的。
我们现在就希望能够找到方法,\(O(n)\) 求解 \(rev_i\) 。
我们考虑 \(i\) 去掉最后一位的 \(i>>1\) ,它的翻转是 \(rev_{i>>1}\) ,那么 \(i\) 的二进制翻转就是把最后一位放到最前面 \(((i\&1)<<k-1) | rev_{i>>1}\) 。
但是,这个做法是不对的。原因在于 \(rev_{i>>1}\) 由于也是 \(k\) 位二进制数;而我们需要的是 \(k-1\) 位的二进制翻转。
但比较庆幸的是,\(k\) 位的 \(rev_{i>>1}\) 由于最高位为 \(0\),和 \(k-1\) 位的只相差一个末位 \(0\)。
我们只需要修改为 \(rev_i=((i\&1)<<k-1) | (rev_{i>>1}>>1)\) 即可。
由于求解 \(rev_i\) 时,\(rev_{i>>1}\) 已经计算得到了,于是我们知道了 \(O(n)\) 求解的方法:
int rev[MAXN];
void get_rev(int k) {
for(int i=0; i<1<<k; ++i)
rev[i]=((i&1)<<k-1) | (rev[i>>1]>>1);
}
void butterfly(vir *a, int N) {
for(int i=0; i<N; ++i)
if(i<rev[i])//避免反复翻转
swap(a[i], a[rev[i]]);
}
倍增实现 FFT
类似上文的思想,我们稍加修改就能得到了第二步 \(O(n)\) 合并的倍增实现:
inline void FFT(vir *f, int N) {
for(int i=0; i<N; ++i)
if(i<rev[i])
swap(f[i], f[rev[i]]);
for(int n=2; n<=N; n<<=1) {
int pace=N/n;
vir *o, x, y;
for(int l=0; l<N; l+=n) {
o=w;
for(int i=l; i<l+(n>>1); ++i) {
x=f[i] + *o * f[i+(n>>1)];
y=f[i] - *o * f[i+(n>>1)];
o += pace;
f[i] = x;
f[i+(n>>1)] = y;
}
}
}
}
或稍加修改成为:
inline void FFT(vir *f, int N) {
for(int i=0; i<N; ++i)
if(i<rev[i])
swap(f[i], f[rev[i]]);
for(int n=1, pace=N>>1; n<N; n<<=1, pace>>=1)
for(vir *l=f,*o=w; l!=f+N; l+=n<<1, o=w)
for(vir *x=l, y; x!=l+n; ++x, o+=pace) {
y = x[0] + *o * x[n];
x[n] = x[0] - *o * x[n];
x[0] = y;
}
}