Luogu4245 【模板】任意模数NTT
https://www.luogu.com.cn/problem/P4245
三模数\(NTT\)
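每个系数的真实值最大约为 \(p^2 \max\{n,m\} \approx 10^{23}\),超出了任何单个 NTT 模数的范围;而三个模数之积 \(p_1p_2p_3 \approx 4.7\times 10^{26} > 10^{23}\),所以分别在三个模数下做 NTT 后,可以通过中国剩余定理唯一确定系数的真实值,再对 \(p\) 取模。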
具体操作流程:
\[\begin{cases}
x \equiv x_1 \bmod p_1\\
x \equiv x_2 \bmod p_2\\
x \equiv x_3 \bmod p_3\\
\end{cases}
\\
M=p_1 p_2\\
A = \left(x_1 \cdot p_2 \cdot \left(p_2^{-1} \bmod p_1\right) + x_2 \cdot p_1 \cdot \left(p_1^{-1} \bmod p_2\right)\right) \bmod M\\
\begin{cases}
x \equiv A \bmod M\\
x \equiv x_3 \bmod p_3\\
\end{cases}
\\
x = kM+A \;\Rightarrow\; x_3 \equiv kM+A \pmod{p_3}\\
k \equiv (x_3-A)M^{-1} \pmod{p_3}\\
x = kM+A
\]
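第一步合并时 \(M=p_1p_2\approx 4.7\times 10^{17}\),模 \(M\) 的数再乘以约 \(10^9\) 的逆元会超出 long long 的范围,所以用龟速乘(按二进制拆分、逐次加倍相加的乘法)来计算。求出 \(k\) 后,最终答案就是 \((kM+A) \bmod p\);计算时先把 \(M\) 对 \(p\) 取模再乘 \(k\),即可避免溢出。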
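常用 NTT 模数(原根均为 \(3\),与代码注释一致):\(469762049=7\times 2^{26}+1\)、\(998244353=119\times 2^{23}+1\)、\(1004535809=479\times 2^{21}+1\)。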
\(Code:\)
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#define N 400005
#define ll long long
using namespace std;
// Input coefficient buffers read by main: f holds the first polynomial,
// g the second. They are handed to each ntt instance through ntt::Set.
int f[N],g[N];
// n, m: coefficient counts of f and g (main reads the degrees, then
// increments them); p: the modulus the answer is printed under.
int n,m,p;
// NTT-based polynomial multiplier bound to a single NTT-friendly prime.
// main() builds three instances (one per prime) and merges their
// per-coefficient results with the Chinese remainder theorem.
struct ntt
{
    /*
    Primes used by main (primitive root gi = 3 for all of them):
    469762049=7*2^26+1 gi=3
    998244353=119*2^23+1 gi=3
    1004535809=479*2^21+1 gi=3
    */
    int f[N],g[N];   // working copies of the inputs; after solve() f holds the product mod p
    int n,m,p,tw,gi; // n, m: coefficient counts; p: prime; tw: p-1 = c*2^tw; gi: primitive root
    int G[2][35];    // G[0][i]: 2^i-th root of unity mod p, G[1][i]: its inverse
    int s,l,rev[N];  // s = 2^l: transform length; rev: bit-reversal permutation

    // Bind the prime parameters and copy cx coefficients from a and cy from b.
    // Coefficients are reduced into [0,p) here because add/del and the
    // butterflies assume their operands already lie in that range, while raw
    // int inputs may exceed the prime (the smallest one is 469762049).
    void Set(int cx,int cy,int x,int y,int Gi,int *a,int *b)
    {
        n=cx,m=cy,p=x,tw=y,gi=Gi;
        for (int i=0;i<cx;++i)
            f[i]=(a[i]%p+p)%p;
        for (int i=0;i<cy;++i)
            g[i]=(b[i]%p+p)%p;
    }
    // In-place modular helpers; operands are expected in [0,p).
    void Add(int &x,int y)
    {
        x=(x+y)%p;
    }
    void Del(int &x,int y)
    {
        x=(x-y+p)%p;
    }
    void Mul(int &x,int y)
    {
        x=(ll)x*y%p;
    }
    // Value-returning modular helpers; operands are expected in [0,p).
    int add(int x,int y)
    {
        return (x+y)%p;
    }
    int del(int x,int y)
    {
        return (x-y+p)%p;
    }
    int mul(int x,int y)
    {
        return (ll)x*y%p;
    }
    // Fast exponentiation: x^y mod p (64-bit intermediate products).
    int ksm(int x,int y)
    {
        int ans=1;
        while (y)
        {
            if (y & 1)
                Mul(ans,x);
            Mul(x,x);
            y >>=1;
        }
        return ans;
    }
    // Precompute the 2^i-th roots of unity (G[0]) and their inverses (G[1])
    // for i = tw down to 1, by repeated squaring of the 2^tw-th root.
    void Pre()
    {
        G[0][tw]=ksm(gi,(p-1)/(1 << tw));
        G[1][tw]=ksm(G[0][tw],p-2);   // Fermat inverse, p is prime
        for (int i=tw-1;i;--i)
        {
            G[0][i]=mul(G[0][i+1],G[0][i+1]);
            G[1][i]=mul(G[1][i+1],G[1][i+1]);
        }
    }
    // Iterative in-place NTT of a[0..s): t = 0 forward, t = 1 inverse
    // (the 1/s scaling is applied by solve()).
    void NTT(int *a,int t)
    {
        for (int i=0;i<s;++i)
            if (i<rev[i])
                swap(a[i],a[rev[i]]);
        // level o merges blocks of size 2*mid = 2^o using the 2^o-th root G[t][o]
        for (int mid=1,o=1;mid<s;mid <<=1,++o)
            for (int j=0;j<s;j+=mid << 1)
            {
                int w=1;   // running twiddle factor (named w to avoid shadowing member g)
                for (int k=0;k<mid;++k,Mul(w,G[t][o]))
                {
                    int x=a[j+k],y=mul(w,a[j+k+mid]);
                    a[j+k]=add(x,y),a[j+k+mid]=del(x,y);
                }
            }
    }
    // Multiply f by g modulo p; the n+m-1 product coefficients end up in f.
    void solve()
    {
        Pre();
        // s >= n+m > n+m-1 (product length), so the cyclic convolution does not
        // wrap; since n+m >= 2 this also guarantees l >= 1 for the shift below.
        s=1,l=0;
        while (s<n+m)
            s <<=1,++l;
        for (int i=0;i<s;++i)
            rev[i]=(rev[i >> 1] >> 1) | ((i & 1) << (l-1));
        NTT(f,0),NTT(g,0);
        for (int i=0;i<s;++i)
            Mul(f[i],g[i]);
        NTT(f,1);
        int invs=ksm(s,p-2);
        for (int i=0;i<s;++i)
            Mul(f[i],invs);
    }
}NTT1,NTT2,NTT3;
// x*y mod p by binary decomposition of y ("slow multiplication"): only
// additions and doublings modulo p are performed. This avoids the 64-bit
// overflow of a direct product. main calls it with p = p1*p2 (~4.7e17).
// Assumes 0 <= x < p so that no intermediate exceeds 2*p.
ll low_mul(ll x,ll y,ll p)
{
    ll acc=0;
    for (ll base=x; y!=0; y>>=1)
    {
        if (y&1)
            acc=(acc+base)%p;
        base=(base<<1)%p;
    }
    return acc;
}
int main()
{
scanf("%d%d%d",&n,&m,&p),++n,++m;
for (int i=0;i<n;++i)
scanf("%d",&f[i]);
for (int i=0;i<m;++i)
scanf("%d",&g[i]);
int p1=469762049,p2=998244353,p3=1004535809;
NTT1.Set(n,m,p1,26,3,f,g);
NTT2.Set(n,m,p2,23,3,f,g);
NTT3.Set(n,m,p3,21,3,f,g);
NTT1.solve(),NTT2.solve(),NTT3.solve();
int inv1=NTT1.ksm(p2,p1-2),inv2=NTT2.ksm(p1,p2-2);
ll M=(ll)p1*p2;
int inv3=NTT3.ksm(M%p3,p3-2);
for (int i=0;i<n+m-1;++i)
{
ll x=(low_mul((ll)NTT1.f[i]*p2%M,inv1,M)+low_mul((ll)NTT2.f[i]*p1%M,inv2,M))%M;
ll k=(ll)inv3*(NTT3.f[i]-x%p3+p3)%p3;
int ans=(M%p*k%p+x%p+p)%p;
printf("%d ",ans);
}
putchar('\n');
return 0;
}
拆系数\(FFT\)
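把每个系数 \(x\) 拆成 \(x = 2^{15}a + b\),其中 \(b\) 是二进制下的低 \(15\) 位,\(a\) 是其余的高位(代码中 \(a = x \gg 15\),\(b = x \,\&\, (2^{15}-1)\))。对拆出的多项式分别做 \(FFT\),再按下式合并。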
\[(2^{15} a_i+b_i)(2^{15}c_i+d_i)=\\
2^{30}a_i c_i+2^{15} (a_i d_i +b_i c_i)+b_i d_i
\]
计算一下就好了。
注意,本题对精度要求很高,应尽量减少 \(FFT\) 的次数以降低累积误差(这里只做了 \(4\) 次正变换和 \(3\) 次逆变换)。我一开始爆了 \(0\),还以为 \(FFT\) 写挂了,卡了好久 \(QAQ\)。
最好用 \(\text{long double}\)。虽然 \(\text{double}\) 也能通过且运行更快,但精度余量很小,非常容易 \(WA\)。
\(Code:\)
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cmath>
#define N 400005
#define D long double
#define ll long long
using namespace std;
// n, m: degrees of the two input polynomials; p: output modulus;
// q: scratch for reading each coefficient; s = 2^l: FFT length;
// rev: bit-reversal permutation used by FFT.
int n,m,p,q,l,s,rev[N];
// acos(-1.0L) selects the long double overload. acos(-1.0) would compute pi
// in double precision only and silently cap the accuracy of every twiddle
// factor, defeating the point of doing the FFT in long double (D).
const D Pi=acos(-1.0L);
// Per-coefficient scratch values for the recombination in main.
ll g1,g2,g3,ans;
struct virt
{
D x,y;
virt (D xx=0.0,D yy=0.0)
{
x=xx,y=yy;
}
virt operator + (virt b)
{
return virt(x+b.x,y+b.y);
}
virt operator - (virt b)
{
return virt(x-b.x,y-b.y);
}
virt operator * (virt b)
{
return virt(x*b.x-y*b.y,x*b.y+y*b.x);
}
}q1,q2,q3,q4,a[N],b[N],c[N],d[N],g[2][25],W[N];
// Iterative in-place FFT of a[0..s) (s, rev and W are globals).
// Expects W[mid+k] = e^{i*pi*k/mid}, as filled in main.
// t == 1: forward transform. t == -1: inverse, obtained by running the
// forward pass, reversing a[1..s) and scaling by 1/s.
void FFT(virt *a,int t)
{
    for (int i=0;i<s;i++)
        if (i<rev[i])
            swap(a[i],a[rev[i]]);
    for (int mid=1;mid<s;mid <<=1)
        for (int j=0;j<s;j+=(mid << 1))
            for (int k=0;k<mid;k++)
            {
                virt x=a[j+k],y=W[mid+k]*a[j+k+mid];
                a[j+k]=x+y;
                a[j+k+mid]=x-y;
            }
    if (t==-1)
    {
        // forward DFT followed by reversing indices 1..s-1 equals s * inverse DFT
        reverse(a+1,a+s);
        const D inv=1.0L/s;   // s is a power of two, so this is exact
        for (int i=0;i<s;i++)
        {
            a[i].x*=inv;
            a[i].y*=inv;
        }
    }
}
int main()
{
scanf("%d%d%d",&n,&m,&p);
for (int i=0;i<=n;i++)
scanf("%d",&q),a[i].x=q >> 15,b[i].x=q & ((1 << 15) - 1);
for (int i=0;i<=m;i++)
scanf("%d",&q),c[i].x=q >> 15,d[i].x=q & ((1 << 15) - 1);
l=0,s=1;
while (s<=n+m)
{
s <<=1;
l++;
}
for (int i=0;i<s;i++)
rev[i]=(rev[i >> 1] >> 1) | ((i & 1) << (l - 1));
for(int i=1;i<s;i<<=1)
for (int k=0;k<i;k++)
W[i+k]=virt(cos(Pi*k/i),sin(Pi*k/i));
FFT(a,1);
FFT(b,1);
FFT(c,1);
FFT(d,1);
for (int i=0;i<s;i++)
{
q1=a[i],q2=b[i],q3=c[i],q4=d[i];
a[i]=q1*q3;
b[i]=q1*q4+q2*q3;
c[i]=q2*q4;
}
FFT(a,-1);
FFT(b,-1);
FFT(c,-1);
for (int i=0;i<=n+m;i++)
{
g1=(ll)(a[i].x+0.5);
g2=(ll)(b[i].x+0.5);
g3=(ll)(c[i].x+0.5);
g1=((g1%p) << 30)%p;
g2=((g2%p) << 15)%p;
g3=g3%p;
ans=((g1+g2+g3)%p+p)%p;
printf("%lld ",ans);
}
putchar('\n');
return 0;
}