常数比较小码量不大的 MTT(4次FFT)/任意模数多项式乘法

先根据 Prean 的题解 写出一个常数较小的 5 次 FFT 写法。

// Converts a floating-point coefficient back to an integer residue:
// adds 0.5, truncates to ll, then reduces modulo the global `mod`.
inline ll get(const double x){
	const ll rounded=ll(x+0.5);
	return rounded%mod;
}
inline void MTT(const int *A,const int *B,int *C,int n,int m){
	static Comp F[maxn],G[maxn],T[maxn];len=1;while(len<=n+m)len<<=1;init(len);
	memset(F+n+1,0,(len-n-1)*sizeof(Comp)),memset(G+n+1,0,(len-n-1)*sizeof(Comp));
	memset(T+m+1,0,(len-m-1)*sizeof(Comp));
	for(reg i=0;i<=n;++i)F[i]=Comp(A[i]&32767,0),G[i]=Comp(A[i]>>15,0);
	for(reg i=0;i<=m;++i)T[i]=Comp(B[i]&32767,B[i]>>15);
	FFT(F,len),FFT(G,len),FFT(T,len);
	for(reg i=0;i<len;++i)F[i]*=T[i];
	for(reg i=0;i<len;++i)G[i]*=T[i];
	IFFT(F,len),IFFT(G,len);
	for(reg i=0;i<len;++i)C[i]=(get(F[i].real)+(get(F[i].imag+G[i].real)<<15)+(get(G[i].imag)<<30))%mod;
}


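又发现:设 $A(x),B(x)$ 为实系数多项式,令 $P(x)=A(x)+iB(x)$,并记 $\hat P_k=\mathrm{DFT}(P)_k$,则

$$\mathrm{DFT}(A)_k=\frac{\hat P_k+\overline{\hat P_{(n-k)\bmod n}}}{2},\qquad \mathrm{DFT}(B)_k=\frac{\hat P_k-\overline{\hat P_{(n-k)\bmod n}}}{2i}$$

也就是说,只需一次 FFT 就能同时得到两个实序列的 DFT。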

于是根据 Kewth 的题解,写出了 4 次 FFT 版本。

// Computes the DFTs of two real sequences with a single complex FFT.
// On entry a[i].real and b[i].real hold the two inputs; a[i].imag is
// overwritten and b[i].imag is never read. On return a and b hold the
// length-n transforms of the first and second sequence respectively.
inline void FFT2(Comp *a,Comp *b,int n){
	// Pack the second sequence into the imaginary part of the first.
	for(int i=0;i<n;++i)a[i].imag=b[i].real;
	FFT(a,n);
	// b[i] = conjugate of a[(n-i) mod n]. Written out by hand because no conj()
	// overload for Comp is declared in this file.
	for(int i=0;i<n;++i){
		const Comp &mirrored=a[i?n-i:0];
		b[i]=Comp(mirrored.real,-mirrored.imag);
	}
	for(int i=0;i<n;++i){
		const Comp p=a[i],q=b[i];
		a[i]=(p+q)*0.5;               // transform of the real-part sequence
		b[i]=(q-p)*0.5*Comp(0,1);     // (p-q)/(2i): transform of the imag-part sequence
	}
}
// Modulus used by get() and MTT().
int mod;
// Adds 0.5 to x, truncates to ll, and reduces the result modulo `mod`.
inline ll get(const double x){
	const ll nearest=ll(x+0.5);
	return nearest%mod;
}
inline void MTT(const int *A,const int *B,int *C,int n,int m){
	static Comp F[maxn],G[maxn],T[maxn];len=1;while(len<=n+m)len<<=1;init(len);
	memset(F+n+1,0,(len-n-1)*sizeof(Comp)),memset(G+n+1,0,(len-n-1)*sizeof(Comp));
	memset(T+m+1,0,(len-m-1)*sizeof(Comp));
	for(reg i=0;i<=n;++i)F[i]=Comp(A[i]&32767,0),G[i]=Comp(A[i]>>15,0);
	for(reg i=0;i<=m;++i)T[i]=Comp(B[i]&32767,B[i]>>15);
	FFT2(F,G,len),FFT(T,len);
	for(reg i=0;i<len;++i)F[i]*=T[i];
	for(reg i=0;i<len;++i)G[i]*=T[i];
	IFFT(F,len),IFFT(G,len);
	for(reg i=0;i<len;++i)C[i]=(get(F[i].real)+(get(F[i].imag+G[i].real)<<15)+(get(G[i].imag)<<30))%mod;
}

优化一下常数,最终卡进了最优解第二页(最底端)。

#include<bits/stdc++.h>
#define EL puts("Elaina")
#define reg register int
typedef long long ll;
using namespace std;
// Buffered reader over stdin for unsigned decimal integers.
namespace IO{
	const int siz=1<<18;char buf[siz],*p1,*p2;
	// Returns the next input byte as a value in [0,255], or EOF once stdin is
	// exhausted. Refills `buf` with fread whenever the buffer has been consumed.
	inline int nextByte(){
		if(p1==p2){
			p2=(p1=buf)+fread(buf,1,siz,stdin);
			if(p1==p2)return EOF;
		}
		return static_cast<unsigned char>(*p1++);
	}
	// Kept for callers that want a char; EOF is narrowed to char as before.
	inline char getc(){return static_cast<char>(nextByte());}
	// Skips bytes until a digit, then parses a run of decimal digits.
	// Returns 0 if the input ends before any digit is found; the previous
	// version looped forever in that case. Digits are detected with explicit
	// range checks on an int, so negative char values are never passed to
	// isdigit (which would be undefined behaviour).
	inline int read(){
		int x=0,ch=nextByte();
		while(ch!=EOF&&(ch<'0'||ch>'9'))ch=nextByte();
		while(ch>='0'&&ch<='9')x=x*10+(ch^48),ch=nextByte();
		return x;
	}
}using IO::read;
// Capacity of every work array in this program.
const int maxn=400003;
// The circle constant, obtained as acos(-1).
const double Pi=std::acos(-1.0);
// Minimal complex number with public real/imaginary parts.
// `w` is the global table that init() fills and FFT() reads.
struct Comp{
	double real,imag;
	// Leaves both parts unset; objects with static storage start at zero.
	Comp(){}
	Comp(double a,double b):real(a),imag(b){}
	// Component-wise sum.
	inline Comp operator +(const Comp &rhs)const{
		return Comp(real+rhs.real,imag+rhs.imag);
	}
	// Component-wise difference.
	inline Comp operator -(const Comp &rhs)const{
		return Comp(real-rhs.real,imag-rhs.imag);
	}
	// Complex product.
	inline Comp operator *(const Comp &rhs)const{
		return Comp(real*rhs.real-imag*rhs.imag,imag*rhs.real+real*rhs.imag);
	}
	// Scaling by a real factor.
	inline Comp operator *(const double &k)const{
		return Comp(real*k,imag*k);
	}
	inline void operator+=(const Comp &rhs){
		real+=rhs.real;
		imag+=rhs.imag;
	}
	// Both parts are computed before either is stored, so x*=x is safe.
	inline void operator*=(const Comp &rhs){
		const Comp product=(*this)*rhs;
		real=product.real;
		imag=product.imag;
	}
}w[maxn];
// rev: bit-reversal permutation used by FFT(); len: current transform size.
int rev[maxn],len;
// Fills `rev` and the twiddle table `w` for a transform of size `len`
// (the parameter shadows the global of the same name).
// After the call, w[half..len-1] holds cos/sin of 2*Pi*k/len for
// k = 0..half-1, and every lower slot i holds a copy of w[2*i].
inline void init(int len){
	const int half=len>>1;
	// rev[i] is rev[i/2] shifted right, with the top bit set when i is odd.
	for(int i=0;i<len;++i){
		rev[i]=rev[i>>1]>>1;
		if(i&1)rev[i]|=half;
	}
	w[half]=Comp(1,0);
	for(int k=1;k<half;++k){
		const double angle=k*2*Pi/len;
		w[half+k]=Comp(cos(angle),sin(angle));
	}
	// Propagate downwards so each smaller layer reuses every other entry.
	for(int i=half-1;i>0;--i)w[i]=w[i<<1];
}
// In-place forward FFT of a[0..n-1], using the `rev` and `w` tables that
// init() prepared for this n.
inline void FFT(Comp *a,int n){
	// Reorder by the bit-reversal permutation; swap each pair once.
	for(int i=0;i<n;++i){
		const int j=rev[i];
		if(i<j)swap(a[i],a[j]);
	}
	// Butterfly passes: `half` is the half-width, `span` the full block width.
	for(int half=1,span=2;span<=n;half=span,span<<=1){
		for(int base=0;base<n;base+=span){
			for(int k=0;k<half;++k){
				Comp &upper=a[base+k];
				Comp &lower=a[base+k+half];
				// The twiddles for this pass live at w[half..span-1].
				const Comp t=w[half+k]*lower;
				lower=upper-t;
				upper+=t;
			}
		}
	}
}
// In-place inverse FFT of a[0..n-1]: runs the forward transform, reverses
// a[1..n-1], and divides every component by n.
inline void IFFT(Comp *a,int n){
	FFT(a,n);
	reverse(a+1,a+n);
	for(int i=n-1;i>=0;--i){
		a[i].real/=n;
		a[i].imag/=n;
	}
}
// Modulus read from the input; used by get() and MTT().
int mod;
// Adds 0.5 to x, truncates to ll, and reduces the result modulo `mod`.
inline ll get(const double x){
	const ll nearest=ll(x+0.5);
	return nearest%mod;
}
inline void MTT(const int *A,const int *B,int *C,int n,int m){
	static Comp F[maxn],G[maxn],T[maxn];len=1;while(len<=n+m)len<<=1;init(len);
	memset(F+n+1,0,(len-n-1)*sizeof(Comp)),memset(T+m+1,0,(len-m-1)*sizeof(Comp));
	for(reg i=0;i<=n;++i)F[i]=Comp(A[i]&32767,A[i]>>15);
	for(reg i=0;i<=m;++i)T[i]=Comp(B[i]&32767,B[i]>>15);
	FFT(F,len),FFT(T,len),G[0]=F[0];
	for(reg i=1;i<len;++i)G[i]=F[len-i];
	for(reg i=0;i<len;++i){
		Comp p=F[i],q=G[i];
		F[i]=Comp(p.real+q.real,p.imag-q.imag)*T[i]*0.5;
		G[i]=Comp(p.imag+q.imag,q.real-p.real)*T[i]*0.5;
	}
	IFFT(F,len),IFFT(G,len);
	for(reg i=0;i<len;++i)C[i]=(get(F[i].real)+(get(F[i].imag+G[i].real)<<15)+(get(G[i].imag)<<30))%mod;
}
// Degrees of the two input polynomials and their coefficient arrays.
int n,m,F[maxn],G[maxn];
// Reads n, m and mod, then n+1 coefficients into F and m+1 into G.
// Multiplies them with MTT (result stored back into F) and prints
// coefficients 0..n+m separated by spaces, followed by a newline.
inline void MyDearMoments(){
	n=read();
	m=read();
	mod=read();
	for(int i=0;i<=n;++i)F[i]=read();
	for(int j=0;j<=m;++j)G[j]=read();
	MTT(F,G,F,n,m);
	const int top=n+m;
	for(int k=0;k<=top;++k)printf("%d ",F[k]);
	puts("");
}
// Program entry point: runs the solver once and exits with status 0.
int main(){
	MyDearMoments();
	return 0;
}
posted @ 2023-01-05 09:07  Muel_imj  阅读(93)  评论(0)  编辑  收藏  举报