FFT 模板
FFT(Fast Fourier Transform),确切地说应该称之为 FDFT(Fast Discrete Fourier Transform),因为 FFT 是为了解决 DFT 问题而设计的一种快速算法。在深入讨论之前,有必要特别指出这一点。
DFT 问题:
给定复数域上的 $n-1$ 次多项式 $A(x)$ 的系数表示(coefficient representation)$(a_0, a_1,\dots, a_{n-1})$,求 $A(x)$ 的某个点-值表示(point-value representation):
\[((x_0, y_0), (x_1, y_1), (x_2, y_2), \dots, (x_{n-1}, y_{n-1}))\]
FFT 的非递归(迭代)实现 Version I 手写 Complex 类(《算法导论》)
1 #include <bits/stdc++.h> 2 #define rep(i, l, r) for(int i=l; i<r; i++) 3 using namespace std; 4 const double PI(acos(-1)); 5 6 struct Complex{ 7 double r, i; 8 Complex(double r, double i):r(r), i(i){} 9 Complex(int n):r(cos(2*PI/n)), i(sin(2*PI/n)){} //!!error-prone 10 Complex():r(0), i(0){} //default constructor 11 Complex &operator*=(const Complex &a){ 12 double R=r*a.r-i*a.i, I=r*a.i+a.r*i; 13 r=R, i=I; 14 return *this; 15 } 16 Complex operator+(const Complex a){ 17 return Complex(r+a.r, i+a.i); 18 } 19 Complex operator-(const Complex a){ 20 return Complex(r-a.r, i-a.i); 21 } 22 Complex operator*(const Complex a){ 23 return Complex(r*a.r-i*a.i, r*a.i+a.r*i); 24 } 25 void out(){ 26 cout<<r<<' '<<i<<endl; 27 } 28 }; 29 30 const int N(1<<17); 31 int ans[N]; 32 Complex a[N], b[N]; 33 char s[N], t[N]; 34 35 void bit_reverse_swap(Complex *a, int n){ 36 for(int i=1, j=n>>1, k; i<n-1; i++){ 37 if(i < j) swap(a[i],a[j]); 38 //tricky 39 for(k=n>>1; j>=k; j-=k, k>>=1); //inspect the highest "1" 40 j+=k; 41 } 42 } 43 44 void FFT(Complex* a, int n, int t){ 45 bit_reverse_swap(a, n); 46 for(int i=2; i<=n; i<<=1){ 47 Complex wi(i*t); 48 for(int j=0; j<n; j+=i){ 49 Complex w(1, 0); 50 for(int k=j, h=i>>1; k<j+h; k++){ 51 Complex t=w*a[k+h], u=a[k]; 52 a[k]=u+t; 53 a[k+h]=u-t; 54 w*=wi; 55 } 56 } 57 } 58 if(t==-1) rep(i, 0, n) a[i].r/=n; //!!error-prone 59 } 60 61 int trans(int x){ 62 int i=0; 63 for(; x>1<<i; i++); 64 return 1<<i; 65 } 66 67 int main(){ 68 for(; ~scanf("%s%s", s, t); ){ 69 int n=strlen(s), m=strlen(t), l=trans(n+m-1); 70 rep(i, 0, n) a[i]=Complex(s[n-1-i]-'0', 0); 71 rep(i, n, l) a[i]=Complex(0, 0); 72 rep(i, 0, m) b[i]=Complex(t[m-1-i]-'0', 0); 73 rep(i, m, l) b[i]=Complex(0, 0); 74 75 FFT(a, l, 1), FFT(b, l, 1); 76 rep(i, 0, l) a[i]*=b[i]; 77 FFT(a, l, -1); 78 rep(i, 0, l) ans[i]=(int)(a[i].r+0.5); ans[l]=0; //error-prone 79 rep(i, 0, l) ans[i+1]+=ans[i]/10, ans[i]%=10; 80 int c=l; 81 for(;c && !ans[c]; --c); 82 for(; ~c; putchar(ans[c--]+'0')); //error-prone 83 puts(""); 84 } 85 return 0; 86 }
Comment:
1. 此代码是为 HDU1402 写的。代码中,凡注释 error-prone 处,都应特别小心。我犯的最傻逼的错误是第9行,应当是2*PI,我写成PI了。
2. FFT的数值稳定性(精度)问题,还有待考虑。(UPD)多次做多项式乘法时,精度损失较快,这时将 double 换成 long double 可缓解精度损失。
Comment:
1. bit_reverse_swap()函数是对算法导论上的bit_reverse_copy()的改进,将下标互为bit-reverse 的两元素互换位置,就免去了 copy 所需的空间。
2.bit_revrese_copy()不太好懂,需要一点解释:
1 void bit_reverse_swap(Complex *a, int n){
2 for(int i=1, j=n>>1, k; i<n-1; i++){
3 if(i < j) swap(a[i],a[j]);
4 //tricky
5 for(k=n>>1; j>=k; j-=k, k>>=1); //inspect the highest "1"
6 j+=k;
7 }
8 }
将$i$的bit-reverse记作$\rev(i)$。
(i). 由于 $\rev(0)=1, \rev(n-1)=n-1$($n$ 是 $2$ 的幂),所以第 2 行的主循环可令 $i$ 从 $1$ 循环到 $n-2$。同时 $j$ 从 $\rev(1)= \frac{n}{2}$,“循环”到 $\rev(n-2)$ 。
(ii). 第 3 行的判断 if(i < j) 避免了重复交换
(iii).第 5 行的循环的作用就是将 $j$ 从 $\rev(i)$ 变成 $\rev(i+1)$:
首先应当注意到,$i$ 的最低位恰是 $rev(i)$ 的最高位。若 $\rev(i)$ 的最高位是 $0$ 那么 $\rev(i+1)$ 就是 $\rev(i)+ \frac{n}{2}$,否则,$i$ 加上 $1$ 后,最低位将变成 $0$,并且向高一位进 $1$ 。相应的,$\rev(i+1)$ 的最高位应置 $0$(即代码中的 j-=k),并且向低一位"进“ $1$(对应代码中的 k>>=1)。这样从高位往低位检查,遇到 $1$(对应代码中的条件 j>=k)就进位,遇到 $0$ 就退出循环。
3. 我写代码时把第 58 行的 == 写成了 =,结果DEBUG 一个多小时。。。
Version II: 用 C++ 标准库中的 complex<double> 类,代码短一些,但也会慢一些:
1 #include <bits/stdc++.h>
2 #define rep(i, l, r) for(int i=l; i<r; i++)
3 using namespace std;
4 const double PI(acos(-1));
5 typedef complex<double> C;
6
7 const int N(1<<17);
8 int ans[N];
9 C a[N], b[N];
10 char s[N], t[N];
11
12 void bit_reverse_swap(C *a, int n){
13 for(int i=1, j=n>>1, k; i<n-1; i++){
14 if(i < j) swap(a[i],a[j]);
15 //tricky
16 for(k=n>>1; j>=k; j-=k, k>>=1); //inspect the highest "1"
17 j+=k;
18 }
19 }
20
21 void FFT(C* a, int n, int t){
22 bit_reverse_swap(a, n);
23 for(int i=2; i<=n; i<<=1){
24 C wi(cos(t*2*PI/i), sin(t*2*PI/i));
25 for(int j=0; j<n; j+=i){
26 C w(1);
27 for(int k=j, h=i>>1; k<j+h; k++){
28 C t=w*a[k+h], u=a[k];
29 a[k]=u+t;
30 a[k+h]=u-t;
31 w*=wi;
32 }
33 }
34 }
35 if(t==-1) rep(i, 0, n) a[i]/=n; //!!error-prone: typo ==/=
36 }
37
38 int trans(int x){
39 int i=0;
40 for(; x>1<<i; i++);
41 return 1<<i;
42 }
43
44 int main(){
45 for(; ~scanf("%s%s", s, t); ){
46 int n=strlen(s), m=strlen(t), l=trans(n+m-1);
47 rep(i, 0, n) a[i]=C(s[n-1-i]-'0');
48 rep(i, n, l) a[i]=C(0);
49 rep(i, 0, m) b[i]=C(t[m-1-i]-'0');
50 rep(i, m, l) b[i]=C(0);
51
52 FFT(a, l, 1), FFT(b, l, 1);
53 rep(i, 0, l) a[i]*=b[i];
54 FFT(a, l, -1);
55 rep(i, 0, l) ans[i]=(int)(a[i].real()+0.5); ans[l]=0; //error-prone
56 rep(i, 0, l) ans[i+1]+=ans[i]/10, ans[i]%=10;
57 int p=l;
58 for(;p && !ans[p]; --p);
59 for(; ~p; putchar(ans[p--]+'0')); //error-prone
60 puts("");
61 }
62 return 0;
63 }