UOJ #34 多项式乘法

题目链接:多项式乘法

  保存一发FFT与NTT板子。

  学习链接:从多项式乘法到快速傅里叶变换 FFT NTT

  注意差值回来的时候不取反也是可以的,只不过需要把数组\(reverse\)一下(根据单位复数根的性质应该不难理解)

  代码(FFT):

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<complex>
#define File(s) freopen(s".in","r",stdin),freopen(s".out","w",stdout)
#define C complex<double>
#define maxn 300010
#define pi (acos(-1))

using namespace std;
typedef long long llg;

int n,m,L,R[maxn];
C a[maxn],b[maxn];

int getint(){
	int w=0;bool q=0;
	char c=getchar();
	while((c>'9'||c<'0')&&c!='-') c=getchar();
	if(c=='-') c=getchar(),q=1;
	while(c>='0'&&c<='9') w=w*10+c-'0',c=getchar();
	return q?-w:w;
}

void fft(C *a){
	for(int i=0;i<n;i++) if(i<R[i]) swap(a[i],a[R[i]]);
	for(int i=1;i<n;i<<=1){
		C wn(cos(pi/i),sin(pi/i)),x,y;
		for(int j=0;j<n;j+=(i<<1)){
			C w(1,0);
			for(int k=0;k<i;k++,w*=wn){
				x=a[j+k]; y=w*a[j+i+k];
				a[j+k]=x+y; a[j+i+k]=x-y;
			}
		}
	}
}

int main(){
	File("a");
	n=getint(); m=getint();
	for(int i=0;i<=n;i++) a[i]=getint();
	for(int i=0;i<=m;i++) b[i]=getint();
	m+=n; for(n=1;n<=m;n<<=1) L++;
	for(int i=0;i<n;i++) R[i]=(R[i>>1]>>1)|((i&1)<<(L-1));
	fft(a); fft(b);
	for(int i=0;i<n;i++) a[i]*=b[i];
	fft(a); reverse(a+1,a+n);
	for(int i=0;i<=m;i++) printf("%d ",(int)round(a[i].real()/n));
	return 0;
}

  代码(NTT):

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<complex>
#define File(s) freopen(s".in","r",stdin),freopen(s".out","w",stdout)
#define maxn 300010
#define mod 998244353

using namespace std;
typedef long long llg;

const int g=3;
int n,m,L,R[maxn],N;
int a[maxn],b[maxn];

int getint(){
	int w=0;bool q=0;
	char c=getchar();
	while((c>'9'||c<'0')&&c!='-') c=getchar();
	if(c=='-') c=getchar(),q=1;
	while(c>='0'&&c<='9') w=w*10+c-'0',c=getchar();
	return q?-w:w;
}

int qpow(int x,int y){
	int s=1;
	while(y){
		if(y&1) s=1ll*s*x%mod;
		x=1ll*x*x%mod; y>>=1;
	}
	return s;
}

void ntt(int *a){
	for(int i=0;i<n;i++) if(i<R[i]) swap(a[i],a[R[i]]);
	for(int i=1;i<n;i<<=1){
		int gn=qpow(g,(mod-1)/(i<<1)),x,y;
		for(int j=0;j<n;j+=(i<<1)){
			int g=1;
			for(int k=0;k<i;k++,g=1ll*g*gn%mod){
				x=a[j+k]; y=1ll*g*a[j+i+k]%mod;
				a[j+k]=x+y; if(a[j+k]>=mod) a[j+k]-=mod;
				a[j+i+k]=x-y; if(x<y) a[j+i+k]+=mod;
			}
		}
	}
}

int main(){
	File("a");
	n=getint(); m=getint();
	for(int i=0;i<=n;i++) a[i]=getint();
	for(int i=0;i<=m;i++) b[i]=getint();
	m+=n; for(n=1;n<=m;n<<=1) L++;
	for(int i=0;i<n;i++) R[i]=(R[i>>1]>>1)|((i&1)<<(L-1));
	ntt(a); ntt(b); N=qpow(n,mod-2);
	for(int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mod;
	ntt(a); reverse(a+1,a+n);
	for(int i=0;i<=m;i++) printf("%d ",1ll*a[i]*N%mod);
	return 0;
}
posted @ 2017-02-06 20:34  lcf2000  阅读(227)  评论(0编辑  收藏  举报