A * B Problem Plus
A * B Problem Plus
题目链接:http://acm.split.hdu.edu.cn/showproblem.php?pid=1402
FFT
(FFT的详细证明参见算法导论第三十章)
一个多项式有两种表达方式:
1.系数表示法,系数表示的多项式相乘,时间复杂度为O(n^2);
2.点值表示法,点值表示的多项式相乘,时间复杂度为O(n).
简单的说,FFT能办到的就是将系数表示的多项式转化为点值表示,其时间复杂度为O(nlgn),而将点值表示的多项式转化为系数表示需要IFFT(FFT的逆运算),其形式与FFT相似,时间复杂度也为O(nlgn).
这道题需要用FFT将两个大数转化为点值表示,相乘后再用IFFT将点值表示转化回系数表示,总时间复杂度为O(nlgn).
代码如下:
1 #include<cstdio> 2 #include<cmath> 3 #include<algorithm> 4 #include<cstring> 5 #include<iostream> 6 #define N 200005 7 using namespace std; 8 const double pi=acos(-1.0); 9 struct Complex{ 10 double r,i; 11 Complex(double r=0,double i=0):r(r),i(i){}; 12 Complex operator + (const Complex &rhs){ 13 return Complex(r+rhs.r,i+rhs.i); 14 } 15 Complex operator - (const Complex &rhs){ 16 return Complex(r-rhs.r,i-rhs.i); 17 } 18 Complex operator * (const Complex &rhs){ 19 return Complex(r*rhs.r-i*rhs.i,i*rhs.r+r*rhs.i); 20 } 21 }a[N],b[N],c[N]; 22 char s1[N],s2[N]; 23 int ans[N],n1,n2,len; 24 inline void sincos(double theta,double &p0,double &p1){ 25 p0=sin(theta); 26 p1=cos(theta); 27 } 28 void FFT(Complex P[], int n, int oper){ 29 for(int i=1,j=0;i<n-1;i++){ 30 for(int s=n;j^=s>>=1,~j&s;); 31 if(i<j)swap(P[i],P[j]); 32 } 33 Complex unit_p0; 34 for(int d=0;(1<<d)<n;d++){ 35 int m=1<<d,m2=m*2; 36 double p0=pi/m*oper; 37 sincos(p0,unit_p0.i,unit_p0.r); 38 for(int i=0;i<n;i+=m2){ 39 Complex unit=1; 40 for(int j=0;j<m;j++){ 41 Complex &P1=P[i+j+m],&P2=P[i+j]; 42 Complex t=unit*P1; 43 P1=P2-t; 44 P2=P2+t; 45 unit=unit*unit_p0; 46 } 47 } 48 } 49 if(oper==-1)for(int i=0;i<len;i++)P[i].r/=len; 50 } 51 void Conv(Complex a[],Complex b[],int len){//求卷积 52 FFT(a,len,1);//FFT 53 FFT(b,len,1);//FFT 54 for(int i=0;i<len;++i)c[i]=a[i]*b[i]; 55 FFT(c,len,-1);//IFFT 56 } 57 void init(char *s1,char *s2){ 58 len=1; 59 n1=strlen(s1),n2=strlen(s2); 60 while(len<2*n1||len<2*n2)len<<=1; 61 int idx; 62 for(idx=0;idx<n1;++idx){ 63 a[idx].r=s1[n1-1-idx]-'0'; 64 a[idx].i=0; 65 } 66 while(idx<len){ 67 a[idx].r=a[idx].i=0; 68 idx++; 69 } 70 for(idx=0;idx<n2;++idx){ 71 b[idx].r=s2[n2-1-idx]-'0'; 72 b[idx].i=0; 73 } 74 while(idx<len){ 75 b[idx].r=b[idx].i=0; 76 idx++; 77 } 78 } 79 int main(void){ 80 while(scanf("%s%s",s1,s2)==2){ 81 init(s1,s2); 82 Conv(a,b,len); 83 for(int i=0;i<len+len;++i)ans[i]=0;//93ms 84 //memset(ans,0,sizeof(ans));//140ms 85 int index; 86 for(index=0;index<len||ans[index];++index){ 87 ans[index]+=(c[index].r+0.5); 88 ans[index+1]+=(ans[index]/10); 89 ans[index]%=10; 90 } 91 while(index>0&&!ans[index])index--; 92 for(;index>=0;--index)printf("%d",ans[index]); 93 printf("\n"); 94 } 95 }