FFT 模板

$\DeclareMathOperator{\rev}{rev}$

FFT(Fast Fourier Transform),确切地说应该称之为 FDFT(Fast Discrete Fourier Transform),因为 FFT 是为了解决 DFT 问题而设计的一种快速算法。在深入讨论之前,有必要特别指出这一点。

DFT 问题:

给定复数域上的 $n-1$ 次多项式 $A(x)$ 的系数表示(coefficient representation)$(a_0, a_1,\dots, a_{n-1})$,求 $A(x)$ 的某个点-值表示(point-value representation):

\[((x_0, y_0), (x_1, y_1), (x_2, y_2), \dots, (x_{n-1}, y_{n-1}))\]

 

 

FFT 的非递归(迭代)实现 Version I 手写 Complex 类(《算法导论》)

 

 1 #include <bits/stdc++.h>
 2 #define rep(i, l, r) for(int i=l; i<r; i++)
 3 using namespace std;
 4 const double PI(acos(-1));
 5 
 6 struct Complex{
 7     double r, i;
 8     Complex(double r, double i):r(r), i(i){}
 9     Complex(int n):r(cos(2*PI/n)), i(sin(2*PI/n)){}    //!!error-prone
10     Complex():r(0), i(0){}    //default constructor
11     Complex &operator*=(const Complex &a){
12         double R=r*a.r-i*a.i, I=r*a.i+a.r*i;
13         r=R, i=I;
14         return *this;
15     }
16     Complex operator+(const Complex a){
17         return Complex(r+a.r, i+a.i);
18     }
19     Complex operator-(const Complex a){
20         return Complex(r-a.r, i-a.i);
21     }
22     Complex operator*(const Complex a){
23         return Complex(r*a.r-i*a.i, r*a.i+a.r*i);
24     }
25     void out(){
26         cout<<r<<' '<<i<<endl;
27     }
28 };
29 
30 const int N(1<<17);
31 int ans[N];
32 Complex a[N], b[N];
33 char s[N], t[N];
34 
35 void bit_reverse_swap(Complex *a, int n){
36     for(int i=1, j=n>>1, k; i<n-1; i++){
37         if(i < j) swap(a[i],a[j]);
38         //tricky
39         for(k=n>>1; j>=k; j-=k, k>>=1);    //inspect the highest "1"
40         j+=k;
41     }
42 }
43 
44 void FFT(Complex* a, int n, int t){
45     bit_reverse_swap(a, n);
46     for(int i=2; i<=n; i<<=1){
47         Complex wi(i*t);
48         for(int j=0; j<n; j+=i){
49             Complex w(1, 0);
50             for(int k=j, h=i>>1; k<j+h; k++){
51                 Complex t=w*a[k+h], u=a[k];
52                 a[k]=u+t;
53                 a[k+h]=u-t;
54                 w*=wi;
55             }
56         }
57     }
58     if(t==-1) rep(i, 0, n) a[i].r/=n;    //!!error-prone
59 }
60 
61 int trans(int x){
62     int i=0;
63     for(; x>1<<i; i++);
64     return 1<<i;
65 }
66 
67 int main(){
68     for(; ~scanf("%s%s", s, t); ){
69         int n=strlen(s), m=strlen(t), l=trans(n+m-1);
70         rep(i, 0, n) a[i]=Complex(s[n-1-i]-'0', 0);
71         rep(i, n, l) a[i]=Complex(0, 0);
72         rep(i, 0, m) b[i]=Complex(t[m-1-i]-'0', 0);
73         rep(i, m, l) b[i]=Complex(0, 0);
74 
75         FFT(a, l, 1), FFT(b, l, 1);
76         rep(i, 0, l) a[i]*=b[i];
77         FFT(a, l, -1);
78         rep(i, 0, l) ans[i]=(int)(a[i].r+0.5); ans[l]=0;    //error-prone
79         rep(i, 0, l) ans[i+1]+=ans[i]/10, ans[i]%=10;
80         int c=l;
81         for(;c && !ans[c]; --c);
82         for(; ~c; putchar(ans[c--]+'0'));    //error-prone
83         puts("");
84     }
85     return 0;
86 }

 

 

 

Comment:

1. 此代码是为 HDU1402 写的。代码中,凡注释 error-prone 处,都应特别小心。我犯的最傻逼的错误是第9行,应当是2*PI,我写成PI了。

2. FFT的数值稳定性(精度)问题,还有待考虑。(UPD)多次做多项式乘法时,精度损失较快,这时将 double 换成 long double 可缓解精度损失。


Comment:

1. bit_reverse_swap()函数是对算法导论上的bit_reverse_copy()的改进,将下标互为bit-reverse 的两元素互换位置,就免去了 copy 所需的空间。

2.bit_revrese_copy()不太好懂,需要一点解释:

1 void bit_reverse_swap(Complex *a, int n){
2     for(int i=1, j=n>>1, k; i<n-1; i++){
3         if(i < j) swap(a[i],a[j]);
4         //tricky
5         for(k=n>>1; j>=k; j-=k, k>>=1);    //inspect the highest "1"
6         j+=k;
7     }
8 }

将$i$的bit-reverse记作$\rev(i)$。

(i). 由于 $\rev(0)=1, \rev(n-1)=n-1$($n$ 是 $2$ 的幂),所以第 2 行的主循环可令 $i$ 从 $1$ 循环到 $n-2$。同时 $j$ 从 $\rev(1)= \frac{n}{2}$,“循环”到 $\rev(n-2)$ 。

(ii). 第 3 行的判断 if(i < j) 避免了重复交换

(iii).第 5 行的循环的作用就是将 $j$ 从 $\rev(i)$ 变成 $\rev(i+1)$:

首先应当注意到,$i$ 的最低位恰是 $rev(i)$ 的最高位。若 $\rev(i)$ 的最高位是 $0$ 那么 $\rev(i+1)$ 就是 $\rev(i)+ \frac{n}{2}$,否则,$i$ 加上 $1$ 后,最低位将变成 $0$,并且向高一位进 $1$ 。相应的,$\rev(i+1)$ 的最高位应置 $0$(即代码中的 j-=k),并且向低一位"进“ $1$(对应代码中的 k>>=1)。这样从高位往低位检查,遇到 $1$(对应代码中的条件 j>=k)就进位,遇到 $0$ 就退出循环。
3. 我写代码时把第 58 行的 == 写成了 =,结果DEBUG 一个多小时。。。

Version II: 用 C++ 标准库中的 complex<double> 类,代码短一些,但也会慢一些:

 1 #include <bits/stdc++.h>
 2 #define rep(i, l, r) for(int i=l; i<r; i++)
 3 using namespace std;
 4 const double PI(acos(-1));
 5 typedef complex<double> C;
 6 
 7 const int N(1<<17);
 8 int ans[N];
 9 C a[N], b[N];
10 char s[N], t[N];
11 
12 void bit_reverse_swap(C *a, int n){
13     for(int i=1, j=n>>1, k; i<n-1; i++){
14         if(i < j) swap(a[i],a[j]);
15         //tricky
16         for(k=n>>1; j>=k; j-=k, k>>=1);    //inspect the highest "1"
17         j+=k;
18     }
19 }
20 
21 void FFT(C* a, int n, int t){
22     bit_reverse_swap(a, n);
23     for(int i=2; i<=n; i<<=1){
24         C wi(cos(t*2*PI/i), sin(t*2*PI/i));
25         for(int j=0; j<n; j+=i){
26             C w(1);
27             for(int k=j, h=i>>1; k<j+h; k++){
28                 C t=w*a[k+h], u=a[k];
29                 a[k]=u+t;
30                 a[k+h]=u-t;
31                 w*=wi;
32             }
33         }
34     }
35     if(t==-1) rep(i, 0, n) a[i]/=n;    //!!error-prone: typo ==/=
36 }
37 
38 int trans(int x){
39     int i=0;
40     for(; x>1<<i; i++);
41     return 1<<i;
42 }
43 
44 int main(){
45     for(; ~scanf("%s%s", s, t); ){
46         int n=strlen(s), m=strlen(t), l=trans(n+m-1);
47         rep(i, 0, n) a[i]=C(s[n-1-i]-'0');
48         rep(i, n, l) a[i]=C(0);
49         rep(i, 0, m) b[i]=C(t[m-1-i]-'0');
50         rep(i, m, l) b[i]=C(0);
51 
52         FFT(a, l, 1), FFT(b, l, 1);
53         rep(i, 0, l) a[i]*=b[i];
54         FFT(a, l, -1);
55         rep(i, 0, l) ans[i]=(int)(a[i].real()+0.5); ans[l]=0;    //error-prone
56         rep(i, 0, l) ans[i+1]+=ans[i]/10, ans[i]%=10;
57         int p=l;
58         for(;p && !ans[p]; --p);
59         for(; ~p; putchar(ans[p--]+'0'));    //error-prone
60         puts("");
61     }
62     return 0;
63 }

 

 

posted @ 2016-05-17 21:53  Pat  阅读(898)  评论(0编辑  收藏  举报