HDU 1402 大数乘法 FFT、NTT
A * B Problem Plus
Time Limit: 2000/1000 MS (Java/Others) Memory Limit: 65536/32768 K (Java/Others)
Total Submission(s): 26874 Accepted Submission(s): 7105
Problem Description
Calculate A * B.
Input
Each line will contain two integers A and B. Process to end of file.
Note: the length of each integer will not exceed 50000.
Note: the length of each integer will not exceed 50000.
Output
For each case, output A * B in one line.
Sample Input
1
2
1000
2
Sample Output
2
2000
解析：FFT 入门题。一个整数可以看成一个以 10 为底的多项式，例如 $1234 = 4\cdot10^0 + 3\cdot10^1 + 2\cdot10^2 + 1\cdot10^3$。两个整数相乘等价于两个多项式的系数做卷积，用 FFT 求出卷积后再从低位到高位逐位进位即可得到乘积。
输入数据可能带有前导零，输出前需要去掉结果中的前导零（结果为 0 时仍要输出一个 0）。
下面是用自定义复数结构体实现的版本，运行速度比较快：
/*
 * HDU 1402 "A * B Problem Plus" -- big-integer multiplication with an
 * iterative radix-2 FFT over a hand-written complex type.
 *
 * Each decimal number is treated as a polynomial in 10 (least significant
 * digit first); the product's digits are the carry-normalised coefficients
 * of the convolution of the two digit sequences.
 */
#include<cstdio>
#include<cstring>
#include<cmath>
#include<cstdlib>
#include<iostream>
#include<algorithm>
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define all(a) (a).begin(), (a).end()
#define fillchar(a, x) memset(a, x, sizeof(a))
#define huan printf("\n");
using namespace std;
typedef long long ll;
const ll maxn = 300020, inf = 0x3f3f3f3f;
const ll mod = 1000000007;
const double PI = acos(-1.0);

/*
 * Minimal complex number. Named `Complex` (capitalised) so that it can never
 * be ambiguous with std::complex under `using namespace std`, whichever
 * standard headers end up being included.
 */
struct Complex
{
    double r, i;   /* real and imaginary parts */

    Complex(double _r = 0.0, double _i = 0.0) : r(_r), i(_i) {}

    Complex operator+(const Complex &b) const
    {
        return Complex(r + b.r, i + b.i);
    }
    Complex operator-(const Complex &b) const
    {
        return Complex(r - b.r, i - b.i);
    }
    Complex operator*(const Complex &b) const
    {
        return Complex(r * b.r - i * b.i, r * b.i + i * b.r);
    }
};

/*
 * Bit-reversal permutation performed before the butterfly passes:
 * element i is swapped with the element whose index is i bit-reversed.
 * len must be a power of two.
 */
void change(Complex y[], int len)
{
    for (int i = 1, j = len / 2; i < len - 1; i++) {
        /* i < j guarantees that each pair is swapped exactly once. */
        if (i < j)
            swap(y[i], y[j]);

        /*
         * i is incremented normally by the loop; here j receives the
         * "bit-reversed +1" so that j stays the bit reversal of i.
         */
        int k = len / 2;
        while (j >= k) {
            j -= k;
            k /= 2;
        }
        if (j < k)
            j += k;
    }
}

/*
 * In-place FFT.
 *   y   - array of len complex values
 *   len - transform size, must be a power of two
 *   on  - 1 for the forward DFT, -1 for the inverse DFT (result scaled by 1/len)
 */
void fft(Complex y[], int len, int on)
{
    change(y, len);
    for (int h = 2; h <= len; h <<= 1) {
        /* Principal h-th root of unity for this pass. */
        Complex wn(cos(-on * 2 * PI / h), sin(-on * 2 * PI / h));
        for (int j = 0; j < len; j += h) {
            Complex w(1, 0);
            for (int k = j; k < j + h / 2; k++) {
                Complex u = y[k];
                Complex t = w * y[k + h / 2];
                y[k] = u + t;
                y[k + h / 2] = u - t;
                w = w * wn;
            }
        }
    }
    if (on == -1) {
        /* Scale both parts so the function is a true inverse transform. */
        for (int i = 0; i < len; i++) {
            y[i].r /= len;
            y[i].i /= len;
        }
    }
}

char num1[maxn], num2[maxn];
Complex x1[maxn], x2[maxn];
int ans[maxn];

int main()
{
    /*
     * Require both operands: a trailing lone token makes scanf return 1,
     * which the former "!= EOF" test accepted and then multiplied against
     * the stale second operand.
     */
    while (scanf("%s%s", num1, num2) == 2) {
        memset(ans, 0, sizeof(ans));

        int len1 = (int)strlen(num1);
        int len2 = (int)strlen(num2);

        /* Smallest power of two that holds every product coefficient. */
        int len = 1;
        while (len < len1 + len2 + 1)
            len <<= 1;

        /* Load the digits least-significant first and zero-pad to len. */
        for (int i = 0; i < len1; i++)
            x1[len1 - 1 - i] = Complex((double)(num1[i] - '0'), 0);
        for (int i = len1; i < len; i++)
            x1[i] = Complex(0, 0);
        fft(x1, len, 1);

        for (int i = 0; i < len2; i++)
            x2[len2 - 1 - i] = Complex((double)(num2[i] - '0'), 0);
        for (int i = len2; i < len; i++)
            x2[i] = Complex(0, 0);
        fft(x2, len, 1);

        /* Pointwise product in the frequency domain == convolution. */
        for (int i = 0; i < len; i++)
            x1[i] = x1[i] * x2[i];
        fft(x1, len, -1);

        /* Round to the nearest integer coefficient, then propagate carries. */
        for (int i = 0; i < len; i++)
            ans[i] = (int)(x1[i].r + 0.5);
        for (int i = 1; i < len; i++) {
            ans[i] += ans[i - 1] / 10;
            ans[i - 1] %= 10;
        }

        /* Drop leading zeros but always keep one digit (ans[len] is 0). */
        while (len > 0 && !ans[len])
            len--;
        for (int i = len; i >= 0; i--)
            putchar(ans[i] + '0');
        puts("");
    }
    return 0;
}
下面是用 std::complex 实现的版本，稍慢一些：init() 会把所有单位根预先算好并存起来，因此初始化既占内存又费时。
/*
 * HDU 1402 "A * B Problem Plus" -- big-integer multiplication with an FFT
 * built on std::complex<double> and a precomputed table of roots of unity.
 */
#include <bits/stdc++.h>
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define all(a) (a).begin(), (a).end()
#define fillchar(a, x) memset(a, x, sizeof(a))
#define huan printf("\n");
#define debug(a,b) cout<<a<<" "<<b<<" ";
using namespace std;
typedef long long ll;
const ll maxn = 200020, inf = 0x3f3f3f3f;
const ll mod = 1000000007;
const double PI = acos(-1);
typedef complex <double> cp;

char sa[maxn], sb[maxn];
int n = 1, lena, lenb, res[maxn];
cp a[maxn], b[maxn], omg[maxn], inv[maxn];

/*
 * Prepare the global state for a transform of size n:
 * clears a[0..n) and b[0..n), fills omg[i] with the i-th n-th root of unity
 * and inv[i] with its conjugate (used by the inverse transform), and zeroes
 * the whole result array.
 */
void init()
{
    for (int i = 0; i < n; i++) {
        a[i] = b[i] = cp(0, 0);
        omg[i] = cp(cos(2 * PI * i / n), sin(2 * PI * i / n));
        inv[i] = conj(omg[i]);   /* complex conjugate */
    }
    memset(res, 0, sizeof(res));
}

/*
 * In-place, unscaled FFT of size n (global, must be a power of two).
 *   x     - data to transform
 *   roots - omg for the forward transform, inv for the inverse one;
 *           the caller divides by n after an inverse transform
 */
void fft(cp *x, const cp *roots)
{
    int lim = 0;
    while ((1 << lim) < n)
        lim++;

    /* Bit-reversal permutation. */
    for (int i = 0; i < n; i++) {
        int t = 0;
        for (int j = 0; j < lim; j++)
            if ((i >> j) & 1)
                t |= (1 << (lim - j - 1));
        /* i < t swaps each pair once; swapping twice would undo it. */
        if (i < t)
            swap(x[i], x[t]);
    }

    /* Butterfly passes over blocks of length l = 2, 4, ..., n. */
    for (int l = 2; l <= n; l *= 2) {
        int m = l / 2;
        for (cp *p = x; p != x + n; p += l) {
            for (int i = 0; i < m; i++) {
                cp t = roots[n / l * i] * p[i + m];
                p[i + m] = p[i] - t;
                p[i] += t;
            }
        }
    }
}

int main()
{
    /*
     * Require both operands: a trailing lone token makes scanf return 1,
     * which the former "!= EOF" test accepted.
     */
    while (scanf("%s%s", sa, sb) == 2) {
        n = 1;
        lena = (int)strlen(sa);
        lenb = (int)strlen(sb);
        while (n < lena + lenb)
            n *= 2;              /* n must be a power of two */
        init();

        /* Load digits least-significant first. */
        for (int i = 0; i < lena; i++)
            a[i].real(sa[lena - 1 - i] - '0');
        for (int i = 0; i < lenb; i++)
            b[i].real(sb[lenb - 1 - i] - '0');

        fft(a, omg);
        fft(b, omg);
        for (int i = 0; i < n; i++)
            a[i] *= b[i];
        fft(a, inv);

        /* Scale by 1/n, round, and propagate carries in one sweep. */
        for (int i = 0; i < n; i++) {
            res[i] += floor(a[i].real() / n + 0.5);
            res[i + 1] += res[i] / 10;
            res[i] %= 10;
        }

        /* Strip leading zeros but always keep one digit. */
        while (n > 0 && !res[n])
            n--;
        for (int i = n; i >= 0; i--)
            putchar('0' + res[i]);
        puts("");
    }
    return 0;
}
这是我改编版本一的代码
/*
 * HDU 1402 "A * B Problem Plus" -- the struct-based FFT solution ported to
 * std::complex<double>.
 */
#include<bits/stdc++.h>
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define all(a) (a).begin(), (a).end()
#define fillchar(a, x) memset(a, x, sizeof(a))
#define huan printf("\n");
using namespace std;
typedef long long ll;
typedef complex<double> complexd;
const ll maxn = 300020, inf = 0x3f3f3f3f;
const ll mod = 1000000007;
const double PI = acos(-1.0);

/*
 * Bit-reversal permutation performed before the butterfly passes:
 * element i is swapped with the element whose index is i bit-reversed.
 * len must be a power of two.
 */
void change(complexd *y, int len)
{
    for (int i = 1, j = len / 2; i < len - 1; i++) {
        /* i < j guarantees that each pair is swapped exactly once. */
        if (i < j)
            swap(y[i], y[j]);

        /*
         * i is incremented normally by the loop; here j receives the
         * "bit-reversed +1" so that j stays the bit reversal of i.
         */
        int k = len / 2;
        while (j >= k) {
            j -= k;
            k /= 2;
        }
        if (j < k)
            j += k;
    }
}

/*
 * In-place FFT.
 *   y   - array of len complex values
 *   len - transform size, must be a power of two
 *   on  - 1 for the forward DFT, -1 for the inverse DFT (result scaled by 1/len)
 */
void fft(complexd *y, int len, int on)
{
    change(y, len);
    for (int h = 2; h <= len; h <<= 1) {
        /* Principal h-th root of unity for this pass. */
        complexd wn(cos(-on * 2 * PI / h), sin(-on * 2 * PI / h));
        for (int j = 0; j < len; j += h) {
            complexd w(1, 0);
            for (int k = j; k < j + h / 2; k++) {
                complexd u = y[k];
                complexd t = w * y[k + h / 2];
                y[k] = u + t;
                y[k + h / 2] = u - t;
                w = w * wn;
            }
        }
    }
    if (on == -1) {
        /* Scale both parts so the function is a true inverse transform. */
        for (int i = 0; i < len; i++)
            y[i] /= (double)len;
    }
}

char num1[maxn], num2[maxn];
complexd x1[maxn], x2[maxn];
int ans[maxn];

int main()
{
    /*
     * Require both operands: a trailing lone token makes scanf return 1,
     * which the former "!= EOF" test accepted.
     */
    while (scanf("%s%s", num1, num2) == 2) {
        memset(ans, 0, sizeof(ans));

        int len1 = (int)strlen(num1);
        int len2 = (int)strlen(num2);

        /* Smallest power of two that holds every product coefficient. */
        int len = 1;
        while (len < len1 + len2 + 1)
            len <<= 1;

        /* Load the digits least-significant first and zero-pad to len. */
        for (int i = 0; i < len1; i++)
            x1[len1 - 1 - i] = complexd((double)(num1[i] - '0'), 0);
        for (int i = len1; i < len; i++)
            x1[i] = complexd(0, 0);
        fft(x1, len, 1);

        for (int i = 0; i < len2; i++)
            x2[len2 - 1 - i] = complexd((double)(num2[i] - '0'), 0);
        for (int i = len2; i < len; i++)
            x2[i] = complexd(0, 0);
        fft(x2, len, 1);

        /* Pointwise product in the frequency domain == convolution. */
        for (int i = 0; i < len; i++)
            x1[i] = x1[i] * x2[i];
        fft(x1, len, -1);

        /* Round to the nearest integer coefficient, then propagate carries. */
        for (int i = 0; i < len; i++)
            ans[i] = (int)(x1[i].real() + 0.5);
        for (int i = 1; i < len; i++) {
            ans[i] += ans[i - 1] / 10;
            ans[i - 1] %= 10;
        }

        /* Drop leading zeros but always keep one digit (ans[len] is 0). */
        while (len > 0 && !ans[len])
            len--;
        for (int i = len; i >= 0; i--)
            putchar(ans[i] + '0');
        puts("");
    }
    return 0;
}
NTT 模意义下的FFT
/*
 * HDU 1402 "A * B Problem Plus" -- big-integer multiplication with a
 * number-theoretic transform (FFT over Z/P, P = 998244353, primitive root 3).
 *
 * Every convolution coefficient is at most 81 * 50000 < P, so the exact
 * integer value is recovered from the value modulo P.
 */
#include<cmath>
#include<ctime>
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<iostream>
#include<algorithm>
#include<iomanip>
#include<vector>
#include<string>
#include<bitset>
#include<queue>
#include<map>
#include<set>
using namespace std;

/* Fast signed-integer reader built on getchar(). */
inline int read()
{
    int x = 0, f = 1;
    char ch = getchar();
    while (ch < '0' || ch > '9') {
        if (ch == '-')
            f = -1;
        ch = getchar();
    }
    while (ch <= '9' && ch >= '0') {
        x = 10 * x + ch - '0';
        ch = getchar();
    }
    return x * f;
}

/* Recursive signed-integer writer built on putchar(). */
void print(int x)
{
    if (x < 0)
        putchar('-'), x = -x;
    if (x >= 10)
        print(x / 10);
    putchar(x % 10 + '0');
}

const int N = 300100, P = 998244353;

/* Returns x^y mod P by binary exponentiation. */
inline int qpow(int x, int y)
{
    int res(1);
    while (y) {
        if (y & 1)
            res = 1ll * res * x % P;
        x = 1ll * x * x % P;
        y >>= 1;
    }
    return res;
}

int r[N];   /* r[i] = bit reversal of i for the current transform size */

/*
 * In-place NTT modulo P.
 *   x   - array of lim residues in [0, P)
 *   lim - transform size, a power of two; r[0..lim) must already hold the
 *         bit-reversal table for this size
 *   opt - 1 for the forward transform, -1 for the inverse transform
 *
 * The inverse is obtained by running the forward butterflies, reversing
 * x[1..lim) and multiplying by lim^{-1} mod P.
 */
void ntt(int *x, int lim, int opt)
{
    for (int i = 0; i < lim; ++i)
        if (r[i] < i)
            swap(x[i], x[r[i]]);

    for (int m = 2; m <= lim; m <<= 1) {
        int k = m >> 1;
        int gn = qpow(3, (P - 1) / m);   /* primitive m-th root of unity */
        for (int i = 0; i < lim; i += m) {
            int g = 1;
            for (int j = 0; j < k; ++j, g = 1ll * g * gn % P) {
                int tmp = 1ll * x[i + j + k] * g % P;
                x[i + j + k] = (x[i + j] - tmp + P) % P;
                x[i + j] = (x[i + j] + tmp) % P;
            }
        }
    }

    if (opt == -1) {
        reverse(x + 1, x + lim);
        int inv = qpow(lim, P - 2);      /* lim^{-1} by Fermat's little theorem */
        for (int i = 0; i < lim; ++i)
            x[i] = 1ll * x[i] * inv % P;
    }
}

int A[N], B[N], C[N];
char a[N], b[N];

int main()
{
    /*
     * Require both operands: a trailing lone token makes scanf return 1,
     * which the former "!= EOF" test accepted.
     */
    while (scanf("%s%s", a, b) == 2) {
        memset(C, 0, sizeof(C));
        memset(A, 0, sizeof(A));
        memset(B, 0, sizeof(B));

        /*
         * Reset the transform size for every test case. It used to be
         * initialised once outside the loop, so after one large input every
         * later (possibly tiny) case still ran a maximum-size transform.
         */
        int lim = 1;

        int n = (int)strlen(a);
        for (int i = 0; i < n; ++i)
            A[i] = a[n - i - 1] - '0';
        while (lim < (n << 1))
            lim <<= 1;

        n = (int)strlen(b);
        for (int i = 0; i < n; ++i)
            B[i] = b[n - i - 1] - '0';
        while (lim < (n << 1))
            lim <<= 1;

        /* Bit-reversal table; r[0] is always 0, entries are built in order. */
        for (int i = 0; i < lim; ++i)
            r[i] = (i & 1) * (lim >> 1) + (r[i >> 1] >> 1);

        ntt(A, lim, 1);
        ntt(B, lim, 1);
        for (int i = 0; i < lim; ++i)
            C[i] = 1ll * A[i] * B[i] % P;
        ntt(C, lim, -1);

        /* Propagate carries while tracking the highest non-zero digit. */
        int len(0);
        for (int i = 0; i < lim; ++i) {
            if (C[i] >= 10) {
                len = i + 1;
                C[i + 1] += C[i] / 10;
                C[i] %= 10;
            }
            if (C[i])
                len = max(len, i);
        }
        while (C[len] >= 10) {
            C[len + 1] += C[len] / 10;
            C[len] %= 10;
            len++;
        }

        for (int i = len; i >= 0; --i)
            putchar(C[i] + '0');
        puts("");
    }
    return 0;
}
FFT原理学习 https://zhuanlan.zhihu.com/p/40505277?utm_source=qq&utm_medium=social&utm_oi=854653490251829248