【学习笔记】任意模数多项式乘法
三模数 NTT
由于多数 NTT 的操作对应值域 \(10^9\),规模 \(10^5\),所以选取三个常用 NTT 模数 \(p_1=998244353\)、\(p_2=1004535809\) 和 \(p_3=469702049\) 做三次乘法也就是九次 NTT。
三个模数的乘积大于结果的理论最大值,所以可以 CRT 合并得到原数再取模。使用 EXCRT 可以不开 __int128
。
EXCRT 具体过程是先把前两个结果 \(h_{1,i}\) 与 \(h_{2,i}\) 在 \(p_1\) 和 \(p_2\) 下合并,解得一个 \(k\),使得 \(h_i\equiv kp_1+h_{1,i}\pmod {p_1p_2}\),之后拿这个值去和 \(h_{3,i}\) 在 \(p_1p_2\) 和 \(p_3\) 下合并,解得一个 \(k'\) 使得 \(h_i\equiv k'p_1p_2+kp_1+h_{1,i}\),这个数对给定的 \(p\) 取模即可。
常数极大。
点击查看代码
inline int q_pow(int A,int B,int P){
int res=1;
while(B){
if(B&1) res=1ll*res*A%P;
A=1ll*A*A%P;
B>>=1;
}
return res;
}
inline ll exgcd(ll A,ll B,ll &X,ll &Y){
if(!B){
X=1,Y=0;
return A;
}
ll D=exgcd(B,A%B,Y,X);
Y-=A/B*X;
return D;
}
int rev[maxn];
int base,w[maxn];
struct Poly{
const static int g=3;
int deg;
vector<ull> f;
ull& operator[](const int &i){return f[i];}
ull operator[](const int &i)const{return f[i];}
inline void set(int L){deg=L;f.resize(L);}
inline void clear(int L,int R){for(int i=L;i<=R;++i)f[i]=0;}
inline void output(int L){for(int i=0;i<L;++i)printf("%llu ",f[i]);printf("\n");}
inline void NTT(int L,bool type,int P){
set(L);
int inv_g=q_pow(g,P-2,P);
for(int i=1;i<L;++i){
rev[i]=(rev[i>>1]>>1)+(i&1?L>>1:0);
if(i<rev[i]) swap(f[i],f[rev[i]]);
}
for(int d=1;d<L;d<<=1){
base=q_pow(type?g:inv_g,(P-1)/(2*d),P);
w[0]=1;
for(int i=1;i<d;++i) w[i]=1ll*w[i-1]*base%P;
for(int i=0;i<L;i+=d<<1){
for(int j=0;j<d;++j){
ull x=f[i+j],y=f[i+d+j]*w[j]%P;
f[i+j]=x+y,f[i+d+j]=x-y+P;
}
}
}
for(int i=0;i<L;++i) f[i]%=P;
if(!type){
int inv_L=q_pow(L,P-2,P);
for(int i=0;i<L;++i) f[i]=f[i]*inv_L%P;
}
}
}F,G,H[3];
int n,m,p;
int a[maxn],b[maxn],c[maxn];
ll mod[3]={998244353,1004535809,469762049};
inline int solve(ll A,ll B,ll C){
ll X1,Y1,X2,Y2;
exgcd(mod[0],mod[1],X1,Y1);
X1=(X1%mod[1]+mod[1])%mod[1];
X1=((B-A)%mod[1]+mod[1])%mod[1]*X1%mod[1];
exgcd(mod[0]*mod[1],mod[2],X2,Y2);
X2=(X2%mod[2]+mod[2])%mod[2];
X2=((C-(X1*mod[0]+A)%(mod[0]*mod[1]))%mod[2]+mod[2])%mod[2]*X2%mod[2];
return (X2%p*mod[0]%p*mod[1]%p+X1%p*mod[0]%p+A%p)%p;
}
int main(){
n=read(),m=read(),p=read();
for(int i=0;i<=n;++i) a[i]=read();
for(int i=0;i<=m;++i) b[i]=read();
int L=1;
while(L<n+m+1) L<<=1;
F.set(L),G.set(L);
for(int i=0;i<=2;++i){
H[i].set(L);
F.clear(0,L-1),G.clear(0,L-1);
for(int j=0;j<=n;++j) F[j]=a[j];
for(int j=0;j<=m;++j) G[j]=b[j];
F.NTT(L,1,mod[i]),G.NTT(L,1,mod[i]);
for(int j=0;j<L;++j) H[i][j]=F[j]*G[j]%mod[i];
H[i].NTT(L,0,mod[i]);
}
for(int i=0;i<=n+m;++i) c[i]=solve((ll)H[0][i],(ll)H[1][i],(ll)H[2][i]);
for(int i=0;i<=n+m;++i) printf("%d ",c[i]);
printf("\n");
return 0;
}
拆系数 FFT
通过把原多项式系数拆开来保证 FFT 的精度。
有一个好写且不掉精度的 \(5\) 次 FFT 做法。
取 \(B=\sqrt{p}\),把两个多项式拆成:\(F(x)=B\times F_1(x)+F_2(x),G(x)=B\times G_1(x)+G_2(x)\)。
这样卷积的结果是 \(H(x)=B^2\times F_1(x)G_1(x)+B\times (F_1(x)G_2(x)+F_2(x)G_1(x))+F_2(x)G_2(x)\)。
设 \(T(x)=G_1(x)+i\times G_2(x)\),那么 \(F_1(x)T(x)\) 和 \(F_2(x)T(x)\) 的实部虚部就分别对应上面的四个卷积结果。
这样只需要对 \(F_1(x),F_2(x),T(x)\) 做 FFT,对 \(F_1(x)T(x)\) 和 \(F_2(x)T(x)\) 做 IFFT,\(5\) 次就可以了。
为了保证精度可以预处理单位根配合 long double
。
点击查看代码
int rev[maxn];
struct Complex{
db a,b;
Complex()=default;
Complex(db a_,db b_):a(a_),b(b_){}
Complex operator+(const Complex &rhs)const{return Complex(a+rhs.a,b+rhs.b);}
Complex operator-(const Complex &rhs)const{return Complex(a-rhs.a,b-rhs.b);}
Complex operator*(const Complex &rhs)const{return Complex(a*rhs.a-b*rhs.b,a*rhs.b+b*rhs.a);}
}base,W[maxn],w[maxn];
struct Poly{
int deg;
vector<Complex> f;
Complex& operator[](const int &i){return f[i];}
Complex operator[](const int &i)const{return f[i];}
inline void set(int L){deg=L;f.resize(L);}
inline void clear(int L,int R){for(int i=L;i<=R;++i)f[i]=Complex(0,0);}
inline void output(int L){for(int i=0;i<L;++i)printf("(%Lf,%Lf) ",f[i].a,f[i].b);printf("\n");}
inline void FFT(int L,bool type){
set(L);
for(int i=1;i<L;++i){
rev[i]=(rev[i>>1]>>1)+(i&1?L>>1:0);
if(i<rev[i]) swap(f[i],f[rev[i]]);
}
for(int d=1;d<L;d<<=1){
for(int i=0,j=0;i<d;++i,j+=L/(2*d)) w[i]=W[type?j:L-j];
for(int i=0;i<L;i+=d<<1){
for(int j=0;j<d;++j){
Complex x=f[i+j],y=w[j]*f[i+d+j];
f[i+j]=x+y,f[i+d+j]=x-y;
}
}
}
if(!type){
for(int i=0;i<L;++i) f[i].a/=L,f[i].b/=L;
}
}
}A,B,T,F,G;
int n,m,p,C;
int main(){
n=read(),m=read(),p=read(),C=sqrt(p);
int L=1;
while(L<n+m+1) L<<=1;
for(int i=0;i<=L;++i) W[i]=Complex(cos(i*2*pi/L),sin(i*2*pi/L));
A.set(L),B.set(L),T.set(L),F.set(L),G.set(L);
for(int i=0;i<=n;++i){
int x=read()%p;
A[i]=Complex(1.0*(x/C),0),B[i]=Complex(1.0*(x%C),0);
}
for(int i=0;i<=m;++i){
int x=read()%p;
T[i]=Complex(1.0*(x/C),1.0*(x%C));
}
A.FFT(L,1),B.FFT(L,1),T.FFT(L,1);
for(int i=0;i<L;++i) F[i]=A[i]*T[i],G[i]=B[i]*T[i];
F.FFT(L,0),G.FFT(L,0);
for(int i=0;i<=n+m;++i){
int now=0;
now=(now+1ll*C*C%p*((ll)(F[i].a+0.5)%p)%p)%p;
now=(now+1ll*C*((ll)(F[i].b+0.5)%p+(ll)(G[i].a+0.5)%p)%p+p)%p;
now=(now+(ll)(G[i].b+0.5)%p+p)%p;
printf("%d ",now);
}
printf("\n");
return 0;
}