[模板] 任意模数多项式乘法
一、题目
二、解法
任意模数 \(\tt NTT\) 就是找三个常见的大模数,然后用中国剩余定理合并,建议用下面的模数:
\[998244353,1004535809,469762049
\]
假设求出了三个答案是 \(x_1,x_2,x_3\) ,由于模数是质数我们的合并时可以用逆元的:
\[x=x_1\mod A
\]
\[x=x_2\mod B
\]
\[x=x_3\mod C
\]
直接把一式带进二式中:
\[x_1+k_1A=x_2\mod B
\]
\[k_1=\frac{x_2-x_1}{A}\mod B
\]
那么新的 \(x=x_1+k_1A\mod AB\) ,记 \(x_4=x_1+k_1A\) ,把他带进第三个柿子:
\[x_4+k_4AB=x_3\mod C
\]
\[k_4=\frac{x_3-x_4}{AB}\mod C
\]
所以真实的 \(x=x_4+k_4AB\mod ABC\) ,因为答案 \(<ABC\) ,所以就对了。要注意由于答案是正的,我们得到的 \(x\) 也要是正的,不能再模的过程中产生负数了,我的常数大是个未解之谜。
#include <cstdio>
#include <iostream>
using namespace std;
const int M = 300005;
#define int long long
int read()
{
int x=0,f=1;char c;
while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;c=getchar();}
while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
return x*f;
}
int n,m,p,lg,len,a[M],b[M],c[M][3],A[M],B[M],rev[M];
int md[3],MOD,inv1,inv2;
int qkpow(int a,int b,const int mod=MOD)
{
int r=1;
while(b>0)
{
if(b&1) r=r*a%mod;
a=a*a%mod;
b>>=1;
}
return r;
}
void NTT(int *a,const int len,int tmp)
{
const int mod = MOD;
for(int i=0;i<len;i++)
{
rev[i]=(rev[i>>1]>>1)|((i&1)<<(lg-1));
if(i<rev[i]) swap(a[i],a[rev[i]]);
}
for(int s=2;s<=len;s<<=1)
{
int t=s/2,w=(tmp==1)?qkpow(3,(mod-1)/s):qkpow(3,(mod-1)-(mod-1)/s);
for(int i=0;i<len;i+=s)
{
int x=1;
for(int j=0;j<t;j++,x=x*w%mod)
{
int fe=a[i+j],fo=a[i+j+t];
a[i+j]=(fe+x*fo)%mod;
a[i+j+t]=((fe-fo*x)%mod+mod)%mod;
}
}
}
if(tmp==1) return ;
int inv=qkpow(len,mod-2);
for(int i=0;i<len;i++) a[i]=a[i]*inv%mod;
}
void zy(int t)
{
MOD=md[t];
lg=0;len=1;
while(len<=n+m+1) len<<=1,lg++;
for(int i=0;i<len;i++) A[i]=B[i]=0;//一定要清0
for(int i=0;i<=n;i++) A[i]=a[i];
for(int i=0;i<=m;i++) B[i]=b[i];
NTT(A,len,1);NTT(B,len,1);
for(int i=0;i<len;i++) A[i]=A[i]*B[i]%MOD;
NTT(A,len,-1);
for(int i=0;i<=n+m;i++) c[i][t]=A[i];
}
int jzm(int w)
{
int x1=c[w][0],x2=c[w][1],x3=c[w][2];
int k1=((x2-x1)*inv1%md[1]+md[1])%md[1];
int x4=(x1+k1*md[0]);//这里不能直接模C,因为他是在模AB意义下的
int k4=(x3-x4%md[2]+md[2])%md[2]*inv2%md[2];
//一定要一直模成正数,因为换模数的时候一定要是正数才是真实数
return (x4+k4*md[0]%p*md[1]%p)%p;
}
signed main()
{
md[0]=998244353;
md[1]=1004535809;
md[2]=469762049;
inv1=qkpow(md[0],md[1]-2,md[1]);
inv2=qkpow(md[0]*md[1]%md[2],md[2]-2,md[2]);
n=read();m=read();p=read();
for(int i=0;i<=n;i++) a[i]=read();
for(int i=0;i<=m;i++) b[i]=read();
zy(0);zy(1);zy(2);
for(int i=0;i<=n+m;i++)
printf("%lld ",(jzm(i)+p)%p);
}