利用傅立叶变换可以把大数乘法时间控制在:O(n*lgn),首先把整数a1a2...an看作多项式f(x) = a1x + a2x^2 + .... an x^n(其中x取10),把整数乘法转换为多项式乘法,多项式一般乘法也要O(n*n)时间,而多项式点值表示法的乘法只需要O(n)的时间,于是就有了条折返路线,如下:
1.首先把f(x)扩展为2n次多项式,取2n个点求多项式值 (O(nlgn))
2.在对这两的多项式的点值表示进行乘积 (O(n))
3.在对所得的点值进行插值得到结果多项式的系数 (O(nlgn))
过程就是:系数 -> 点值 -> 系数
第一步如果随意取2n个点进行求值,也要O(n*n),而如果利用单位复根exp(iu)的诸多性质就可以在O(nlgn)时间内求的点值对,也可以在O(nlgn)时间内插值,具体可以参考《计算方法引论》中的证明。
代码:
1 #include <cstdio> 2 #include <cstring> 3 #include <cmath> 4 5 using namespace std; 6 7 #define MAXN 500005 8 #define PI acos(-1.0) 9 10 struct vir 11 { 12 double r,i; 13 14 vir(double r = 0.0,double i = 0.0) 15 { 16 this -> r = r; 17 this -> i = i; 18 } 19 vir operator +(const vir & b) 20 { 21 return vir(r + b.r,i + b.i); 22 } 23 vir operator -(const vir & b) 24 { 25 return vir(r - b.r,i - b.i); 26 } 27 vir operator *(const vir & b) 28 { 29 return vir(r * b.r - i * b.i,r * b.i + i * b.r); 30 } 31 }; 32 33 char stra[MAXN],strb[MAXN]; 34 vir vira[MAXN],virb[MAXN]; 35 int ans[MAXN]; 36 37 int init(char *stra,char *strb,vir *vira,vir *virb) 38 { 39 int lena = strlen(stra); 40 int lenb = strlen(strb); 41 int len = 1; 42 while(len < 2 * lena || len < 2 * lenb) len <<= 1; //位数扩展2倍 43 44 int i; 45 for(i = 0;i < lena;i++) 46 { 47 vira[i].r = stra[lena-i-1] - '0'; 48 vira[i].i = 0.0; 49 } 50 for(;i < len;i++) 51 { 52 vira[i].r = vira[i].i = 0.0; 53 } 54 for(i = 0;i < lenb;i++) 55 { 56 virb[i].r = strb[lenb-i-1] - '0'; 57 virb[i].i = 0.0; 58 } 59 for(;i < len;i++) 60 { 61 virb[i].r = virb[i].i = 0.0; 62 } 63 return len; 64 } 65 66 void bitReverseSort(vir * v,int len) //按逆序数排序 67 { 68 for(int i = 0;i < len;i++) 69 { 70 int n = i; 71 int revn = 0; 72 for(int j = 1;j < len;j <<= 1) 73 { 74 revn = (revn << 1) | (n & 1); 75 n >>= 1; 76 } 77 if(revn > i) 78 { 79 vir t = v[i]; 80 v[i] = v[revn]; 81 v[revn] = t; 82 } 83 } 84 } 85 86 void FFT(vir *v,int len,int flag) //flag为1求值,-1插值 87 { 88 bitReverseSort(v,len); 89 90 for(int r = 1;r < len;r <<= 1) 91 { 92 vir wp(cos(PI / r),sin(-flag * PI / r)); 93 vir w(1,0); 94 for(int s = 0;s < r;s++) 95 { 96 for(int k = s;k < len;k += 2 * r) 97 { 98 int l = k + r; 99 vir d = w * v[l]; 100 v[l] = v[k] - d; 101 v[k] = v[k] + d; 102 } 103 w = w * wp; 104 } 105 } 106 107 if(flag == -1) 108 { 109 for(int i = 0;i < len;i++) 110 { 111 v[i].r /= len; 112 } 113 } 114 } 115 116 void mul(vir * vira,vir * virb,int len) 117 { 118 for(int i = 0;i < len;i++) 119 { 120 vira[i] = vira[i] * virb[i]; 121 } 122 } 123 124 void output(vir *vira,int *ans,int len) 125 { 126 for(int i = 0;i < len;i++) 127 { 128 ans[i] = vira[i].r + 0.5; 129 } 130 for(int i=0;i<len;i++) 131 { 132 ans[i+1] += ans[i]/10; 133 ans[i] %= 10; 134 } 135 int pos = len - 1; 136 while(!ans[pos] && pos > 0)pos--; 137 for(int i = pos;i >= 0;i--) 138 { 139 printf("%d",ans[i]); 140 } 141 printf("\n"); 142 } 143 144 int main() 145 { 146 while(scanf("%s%s",stra,strb) != EOF) 147 { 148 int len = init(stra,strb,vira,virb); 149 FFT(vira,len,1); 150 FFT(virb,len,1); 151 mul(vira,virb,len); 152 FFT(vira,len,-1); 153 output(vira,ans,len); 154 } 155 return 0; 156 } 157 158 //9416236 2013-10-26 10:13:27 Accepted 1402 265MS 16512K 2836 B G++ 超级旅行者
参考资料:
《计算方法引论》 第三版 徐萃薇
《算法导论》