UOJ #34. 多项式乘法
#34. 多项式乘法
这是一道模板题。
给你两个多项式,请输出乘起来后的多项式。
输入格式
第一行两个整数 nn 和 mm,分别表示两个多项式的次数。
第二行 n+1n+1 个整数,表示第一个多项式的 00 到 nn 次项系数。
第三行 m+1m+1 个整数,表示第二个多项式的 00 到 mm 次项系数。
输出格式
一行 n+m+1n+m+1 个整数,表示乘起来后的多项式的 00 到 n+mn+m 次项系数。
样例一
input
1 2 1 2 1 2 1
output
1 4 5 2
explanation
(1+2x)⋅(1+2x+x2)=1+4x+5x2+2x3(1+2x)⋅(1+2x+x2)=1+4x+5x2+2x3。
限制与约定
0≤n,m≤1050≤n,m≤105,保证输入中的系数大于等于 00 且小于等于 99。
时间限制:1s1s
空间限制:256MB
分析
FFT/NTT模板题。
code
递归:
1 #include<cstdio> 2 #include<algorithm> 3 #include<cstring> 4 #include<cmath> 5 #include<iostream> 6 7 using namespace std; 8 const int N = 300100; 9 const double eps = 1e-8; 10 const double pi = acos(-1.0); 11 typedef long long LL; 12 13 struct Complex { 14 double x,y; 15 Complex() {x=0,y=0;} 16 Complex(double xx,double yy) {x=xx,y=yy;} 17 18 }A[N],B[N]; 19 20 Complex operator + (Complex a,Complex b) { 21 return Complex(a.x+b.x,a.y+b.y); 22 } 23 Complex operator - (Complex a,Complex b) { 24 return Complex(a.x-b.x,a.y-b.y); 25 } 26 Complex operator * (Complex a,Complex b) { 27 return Complex(a.x*b.x-a.y*b.y,a.x*b.y+b.x*a.y); 28 } 29 30 void FFT(Complex *a,int n,int ty) { 31 if (n==1) return ; 32 Complex a1[n>>1],a2[n>>1]; 33 for (int i=0; i<=n; i+=2) { 34 a1[i>>1] = a[i],a2[i>>1] = a[i+1]; 35 } 36 FFT(a1,n>>1,ty); 37 FFT(a2,n>>1,ty); 38 Complex w1 = Complex(cos(2.0*pi/n),ty*sin(2.0*pi/n)); 39 Complex w = Complex(1.0,0.0); 40 for (int i=0; i<(n>>1); i++) { 41 Complex t = w * a2[i]; 42 a[i+(n>>1)] = a1[i] - t; 43 a[i] = a1[i] + t; 44 w = w * w1; 45 } 46 } 47 int main() { 48 int n,m; 49 scanf("%d%d",&n,&m); 50 for (int i=0; i<=n; ++i) scanf("%lf",&A[i].x); 51 for (int i=0; i<=m; ++i) scanf("%lf",&B[i].x); 52 int fn = 1; 53 while (fn <= n+m) fn <<= 1; 54 FFT(A,fn,1); 55 FFT(B,fn,1); 56 for (int i=0; i<=fn; ++i) 57 A[i] = A[i] * B[i]; 58 FFT(A,fn,-1); 59 for (int i=0; i<=n+m; ++i) 60 printf("%d ",(int)(A[i].x/fn+0.5)); 61 return 0; 62 }
非递归
1 #include<cstdio> 2 #include<algorithm> 3 #include<cstring> 4 #include<cmath> 5 #include<iostream> 6 7 using namespace std; 8 const int N = 300100; 9 const double eps = 1e-8; 10 const double Pi = acos(-1.0); 11 typedef long long LL; 12 13 struct Complex { 14 double x,y; 15 Complex() {x=0,y=0;} 16 Complex(double xx,double yy) {x=xx,y=yy;} 17 18 }A[N],B[N]; 19 20 Complex operator + (Complex a,Complex b) { 21 return Complex(a.x+b.x,a.y+b.y); 22 } 23 Complex operator - (Complex a,Complex b) { 24 return Complex(a.x-b.x,a.y-b.y); 25 } 26 Complex operator * (Complex a,Complex b) { 27 return Complex(a.x*b.x-a.y*b.y,a.x*b.y+b.x*a.y); 28 } 29 void FFT(Complex *a,int n,int ty) { 30 // 按递归时最底层的顺序翻转 31 for (int i=0,j=0; i<n; ++i) { 32 if (i < j) swap(a[i],a[j]); 33 for (int k=n>>1; (j^=k)<k; k>>=1); 34 } 35 // 当前正在求次数界为m的多项式。m=1的已经在上面求出来了,所以现在在求次数界为2的多项式。 36 for (int m=2; m<=n; m<<=1) { 37 Complex w1 = Complex(cos(2*Pi/m),ty*sin(2*Pi/m)); 38 for (int i=0; i<n; i+=m) { // 当前求的多项式下标为[i,i+m-1] 39 Complex w = Complex(1,0); 40 for (int k=0; k<(m>>1); ++k) { // 由[i,i+(m/2)-1]和[i+(m/2),i+m-1] 求出[i,i+m-1]的多项式 41 Complex t = w * a[i+k+(m>>1)]; 42 Complex u = a[i+k]; 43 a[i+k] = u + t; 44 a[i+k+(m>>1)] = u - t; 45 w = w * w1; 46 } 47 } 48 } 49 } 50 int main() { 51 int n,m; 52 scanf("%d%d",&n,&m); 53 for (int i=0; i<=n; ++i) scanf("%lf",&A[i].x); 54 for (int i=0; i<=m; ++i) scanf("%lf",&B[i].x); 55 int fn = 1; 56 while (fn <= n+m) fn <<= 1; 57 FFT(A,fn,1); 58 FFT(B,fn,1); 59 for (int i=0; i<=fn; ++i) 60 A[i] = A[i] * B[i]; 61 FFT(A,fn,-1); 62 for (int i=0; i<=n+m; ++i) 63 printf("%d ",(int)(A[i].x/fn+0.5)); 64 return 0; 65 }
NTT
注意代码中的longlong,取模。
1 #include<cstdio> 2 #include<algorithm> 3 #include<cstring> 4 #include<cmath> 5 #include<iostream> 6 7 using namespace std; 8 9 typedef long long LL; 10 const int N = 300100; 11 const int P = 998244353; 12 int A[N],B[N]; 13 14 int ksm(int a,int b) { 15 int ans = 1; 16 while (b) { 17 if (b & 1) ans = (1ll * ans * a) % P; 18 a = (1ll * a * a) % P; 19 b >>= 1; 20 } 21 return ans % P; 22 } 23 void NTT(int *a,int n,int ty) { 24 // 按递归时最底层的顺序翻转 25 for (int i=0,j=0; i<n; ++i) { 26 if (i < j) swap(a[i],a[j]); 27 for (int k=n>>1; (j^=k)<k; k>>=1); 28 } 29 // 当前正在求次数界为m的多项式。m=1的已经在上面求出来了,所以现在在求次数界为2的多项式。 30 for (int m=2; m<=n; m<<=1) { 31 int w1 = ksm(3,(P-1)/m); 32 if (ty == -1) w1 = ksm(w1,P-2); 33 for (int i=0; i<n; i+=m) { // 当前求的多项式下标为[i,i+m-1] 34 int w = 1; 35 for (int k=0; k<(m>>1); ++k) { // 由[i,i+(m/2)-1]和[i+(m/2),i+m-1] 求出[i,i+m-1]的多项式 36 int t = 1ll * w * a[i+k+(m>>1)] % P; 37 int u = a[i+k]; 38 a[i+k] = (u + t) % P; 39 a[i+k+(m>>1)] = (u - t + P) % P; 40 w = 1ll * w * w1 % P; 41 } 42 } 43 } 44 } 45 int main() { 46 int n,m; 47 scanf("%d%d",&n,&m); 48 for (int i=0; i<=n; ++i) scanf("%d",&A[i]),A[i] = (A[i] + P) % P; 49 for (int i=0; i<=m; ++i) scanf("%d",&B[i]),B[i] = (B[i] + P) % P; 50 int len = 1; 51 while (len <= n+m) len <<= 1; 52 NTT(A,len,1); 53 NTT(B,len,1); 54 for (int i=0; i<len; ++i) A[i] = 1ll * A[i] * B[i] % P; 55 NTT(A,len,-1); 56 int inv = ksm(len,P-2); 57 for (int i=0; i<len; ++i) A[i] = 1ll * A[i] * inv % P; 58 for (int i=0; i<=n+m; ++i) printf("%d\n",A[i]); 59 return 0; 60 }