【uoj34】 多项式乘法
http://uoj.ac/problem/34 (题目链接)
题意
求两个多项式的乘积
Solution
挂个FFT板子。
细节
FFT因为要满足$n$是$2$的幂,所以注意数组大小。
代码
// uoj34 #include<algorithm> #include<iostream> #include<cstdlib> #include<cstring> #include<complex> #include<cstdio> #include<cmath> #define LL long long #define inf 2147483640 #define Pi acos(-1.0) #define free(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout); using namespace std; typedef complex<double> E; const int maxn=300010; E a[maxn],b[maxn]; int n,m; namespace FFT { int rev[maxn],L; void DFT(E *a,int f) { for (int i=0;i<n;i++) if (i<rev[i]) swap(a[i],a[rev[i]]); for (int i=1;i<n;i<<=1) { E wn(cos(Pi/i),f*sin(Pi/i)); for (int p=i<<1,j=0;j<n;j+=p) { E w(1,0); for (int k=0;k<i;k++,w*=wn) { E x=a[j+k],y=w*a[j+k+i]; a[j+k]=x+y;a[j+k+i]=x-y; } } } if (f==-1) for (int i=0;i<n;i++) a[i].real()/=n; } void main() { m=n+m; for (n=1;n<=m;n<<=1) L++; //一定是<=,因为这里的m是最高次幂 for (int i=0;i<n;i++) rev[i]=(rev[i>>1]>>1) | ((i&1)<<(L-1)); DFT(a,1);DFT(b,1); for (int i=0;i<n;i++) a[i]=a[i]*b[i]; DFT(a,-1); } } int main() { scanf("%d%d",&n,&m); for (int i=0,x;i<=n;i++) scanf("%d",&x),a[i]=x; for (int i=0,x;i<=m;i++) scanf("%d",&x),b[i]=x; FFT::main(); for (int i=0;i<=m;i++) printf("%d ",(int)(a[i].real()+0.5)); return 0; }
Solution
${NTT}$,适用于对一些形如 ${p=C*2^k+1}$的数取模,且${2^k>n}$(当然也可以将不取模但结果不会超过某个范围视作取模)的多项式乘法问题。
一些常见的${NTT}$模数:
${998244353=119*2^{23}+1}$,原根为${3}$。
${1004535809=479*2^{21}+1}$,原根为${3}$。
${15*2^{112}+1}$,原根为${1111}$。
详情请见Xlightgod的博客
代码
// uoj34 #include<algorithm> #include<iostream> #include<cstdlib> #include<cstring> #include<cstdio> #include<cmath> #define LL long long #define inf 2147483640 #define MOD 998244353 #define Pi acos(-1.0) #define free(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout); using namespace std; const int maxn=300010; int a[maxn],b[maxn],rev[maxn],n,m,L; int power(int a,int b) { int res=1; while (b) { if (b&1) res=1LL*res*a%MOD; a=1LL*a*a%MOD;b>>=1; } return res; } void NTT(int *a,int f) { for (int i=0;i<n;i++) if (i<rev[i]) swap(a[i],a[rev[i]]); for (int i=1;i<n;i<<=1) { int gn=power(3,(MOD-1)/(i<<1)); //这里除的是i<<1 for (int p=i<<1,j=0;j<n;j+=p) { int g=1; for (int k=0;k<i;k++,g=1LL*g*gn%MOD) { int x=a[k+j],y=1LL*g*a[k+j+i]%MOD; a[k+j]=(x+y)%MOD;a[k+j+i]=(x-y+MOD)%MOD; } } } if (f==-1) { int ev=power(n,MOD-2);reverse(a+1,a+n); //reverse的是[1,n) for (int i=0;i<n;i++) a[i]=1LL*a[i]*ev%MOD; } } int main() { scanf("%d%d",&n,&m); for (int i=0;i<=n;i++) scanf("%d",&a[i]); for (int i=0;i<=m;i++) scanf("%d",&b[i]); m=n+m; for (n=1;n<=m;n<<=1) L++; for (int i=0;i<n;i++) rev[i]=(rev[i>>1]>>1) | ((i&1)<<(L-1)); NTT(a,1);NTT(b,1); for (int i=0;i<n;i++) a[i]=1LL*a[i]*b[i]%MOD; NTT(a,-1); for (int i=0;i<=m;i++) printf("%d ",a[i]); return 0; }
Solution3
听说还有任意模数的${NTT}$,比如说对${1000000007}$取模,那么这显然是不能直接${NTT}$的,直接${FFT}$转成整型取模的时候又会爆LL。我用的是毛爷爷的做法,把一个数${x}$拆成${x=a*M+b}$,${M}$是模数的算术平方根。这样就能避免爆LL了。
具体见上面那个链接:Xlightgod。
代码
// uoj34 #include<algorithm> #include<iostream> #include<cstdlib> #include<cstring> #include<complex> #include<cstdio> #include<cmath> #include<queue> #define LL long long #define inf 1ll<<60 #define MOD 1000000007 #define M 32768 #define Pi acos(-1.0) #define free(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout); using namespace std; typedef complex<double> E; const int maxn=300010; E a[maxn],b[maxn],c[maxn],d[maxn],A[maxn],B[maxn],C[maxn]; int n,m,L,rev[maxn]; void FFT(E *a,int f) { for (int i=0;i<n;i++) if (rev[i]>i) swap(a[i],a[rev[i]]); for (int i=1;i<n;i<<=1) { E wn(cos(Pi/i),f*sin(Pi/i)); for (int p=i<<1,j=0;j<n;j+=p) { E w(1,0); for (int k=0;k<i;k++,w*=wn) { E x=a[j+k],y=w*a[j+k+i]; a[j+k]=x+y;a[j+k+i]=x-y; } } } if (f==-1) for (int i=0;i<n;i++) a[i].real()=a[i].real()/n+0.5; //这里的0.5一定要除了再加上去 } int main() { scanf("%d%d",&n,&m); for (int x,i=0;i<=n;i++) { scanf("%d",&x); a[i]=x>>15;b[i]=x&(M-1); } for (int x,i=0;i<=m;i++) { scanf("%d",&x); c[i]=x>>15;d[i]=x&(M-1); } m=n+m; for (n=1,L=-1;n<=m;n<<=1) L++; for (int i=0;i<n;i++) rev[i]=(rev[i>>1]>>1) | ((i&1)<<L); FFT(a,1);FFT(b,1);FFT(c,1);FFT(d,1); for (int i=0;i<n;i++) { A[i]=a[i]*c[i]; B[i]=a[i]*d[i]+b[i]*c[i]; C[i]=b[i]*d[i]; } FFT(A,-1);FFT(B,-1);FFT(C,-1); for (int i=0;i<=m;i++) { LL x=(LL)A[i].real()%MOD,y=(LL)B[i].real()%MOD,z=(LL)C[i].real()%MOD; printf("%lld ",((x<<30)+(y<<15)+z)%MOD); } return 0; }
This passage is made by MashiroSky.