多项式模板

\(1.1\) \(\text{FFT}\)&&预处理单位复数根

别用 \(\text{STL}\) 自带复数......手写一个也不要多久。直接上模板。

#include <cstdio>
#include <cstring>
#include <cmath>
#include <iostream>
#include <algorithm>
#include <queue>
#include <set>
#include <vector>
#include <stack>
#include <map>
#include <bitset>
#define ri register
#define inf 0x7fffffff
#define E (1)
#define mk make_pair
//#define int long long
//#define double long double
using namespace std; const int N=4000010; const double pi=acos(-1.0);
inline int read()
{
    int s=0, w=1; ri char ch=getchar();
    while(ch<'0'||ch>'9') {if(ch=='-') w=-1; ch=getchar(); }
    while(ch>='0'&&ch<='9') s=(s<<3)+(s<<1)+(ch-'0'), ch=getchar();
    return s*w;
}
void print(int x) { if(x<0) x=-x, putchar('-'); if(x>9) print(x/10); putchar(x%10+'0'); }
int n,m,T,rev[N];
inline void Get_Rev() { for(ri int i=0;i<T;i++) rev[i]=(rev[i>>1]>>1)|((i&1)?(T>>1):0); }
struct Complex
{
    double x,y;
    inline Complex (double pp=0, double qq=0) { x=pp, y=qq; }
}a[N],b[N];
inline Complex operator + (Complex a,Complex b) { return Complex(a.x+b.x,a.y+b.y); }
inline Complex operator - (Complex a,Complex b) { return Complex(a.x-b.x,a.y-b.y); }
inline Complex operator * (Complex a,Complex b) { return Complex(a.x*b.x-a.y*b.y,a.x*b.y+b.x*a.y); }
inline void FFT(Complex *s,int type)
{
    for(ri int i=0;i<N;i++) if(i<rev[i]) swap(s[i],s[rev[i]]);
    for(ri int mid=1;mid<T;mid<<=1)
    {
        Complex wn(cos(pi/mid),type*sin(pi/mid));
        for(ri int r=mid<<1, j=0;j<T;j+=r)
        {
            Complex w(1,0);
            for(ri int k=0;k<mid;k++, w=w*wn)
            {
                Complex x=s[j+k], y=w*s[j+mid+k];
                s[j+k]=x+y, s[j+mid+k]=x-y;
            }
        }
    }
}
signed main()
{
    n=read(), m=read();
    for(ri int i=0;i<=n;i++) a[i].x=read();
    for(ri int i=0;i<=m;i++) b[i].x=read();
    T=1; while(T<=(n+m)) T<<=1;
    Get_Rev();
    FFT(a,1), FFT(b,1);
    for(ri int i=0;i<T;i++) a[i]=a[i]*b[i];
    FFT(a,-1);
    for(ri int i=0;i<=n+m;i++) printf("%d ",(int)(a[i].x/T+0.5));
    puts("");
    return 0;
}

系数比较大的时候精度会不够用......考虑预处理单位复数根,然后 \(\text{FFT}\) 时如果为 \(\text{IDFT}\) 显然要把虚部取反。

#include <cstdio>
#include <cstring>
#include <cmath>
#include <iostream>
#include <algorithm>
#include <queue>
#include <set>
#include <vector>
#include <stack>
#include <map>
#include <bitset>
#define ri register
#define inf 0x7fffffff
#define E (1)
#define mk make_pair
//#define int long long
//#define double long double
using namespace std; const int N=4000010; const double pi=acos(-1.0);
inline int read()
{
    int s=0, w=1; ri char ch=getchar();
    while(ch<'0'||ch>'9') {if(ch=='-') w=-1; ch=getchar(); }
    while(ch>='0'&&ch<='9') s=(s<<3)+(s<<1)+(ch-'0'), ch=getchar();
    return s*w;
}
void print(int x) { if(x<0) x=-x, putchar('-'); if(x>9) print(x/10); putchar(x%10+'0'); }
int n,m,rev[N],T;
struct Complex
{
    double x,y;
    inline Complex (double pp=0, double qq=0) { x=pp, y=qq; }
}a[N],b[N],w[N];
inline void Get_Rev() { for(ri int i=0;i<T;i++) rev[i]=(rev[i>>1]>>1)|((i&1)?(T>>1):0); }
inline Complex operator + (Complex a,Complex b) { return Complex(a.x+b.x,a.y+b.y); };
inline Complex operator - (Complex a,Complex b) { return Complex(a.x-b.x,a.y-b.y); }
inline Complex operator * (Complex a,Complex b) { return Complex(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x); }
inline void FFT(Complex *s,int type)
{
    for(ri int i=0;i<T;i++) if(i<rev[i]) swap(s[i],s[rev[i]]);
    for(ri int mid=1;mid<T;mid<<=1)
    {
        for(ri int j=0, r=mid<<1;j<T;j+=r)
        {
            for(ri int k=0;k<mid;k++)
            {
                Complex qwq=w[T/mid*k];
                if(type==-1) qwq.y=-qwq.y;
                Complex x=s[j+k], y=qwq*s[j+mid+k];
                s[j+k]=x+y, s[j+mid+k]=x-y;
            }
        }
    }
}
signed main()
{
    n=read(), m=read();
    for(ri int i=0;i<=n;i++) a[i].x=read();
    for(ri int i=0;i<=m;i++) b[i].x=read();
    T=1; while(T<=(n+m)) T<<=1;
    Get_Rev();
    for(ri int i=0;i<T;i++) w[i]=Complex(cos(pi/T*i),sin(pi/T*i));
    FFT(a,1), FFT(b,1);
    for(ri int i=0;i<T;i++) a[i]=a[i]*b[i];
    FFT(a,-1);
    for(ri int i=0;i<=n+m;i++) printf("%d ",(int)(a[i].x/T+0.5));
    puts("");
    return 0;
}

\(1.2\) \(\text{NTT}\)&&预处理原根幂

直接上模板。只适用于模数为 \(\text{NTT}\) 模数的情况。

#include <cstdio>
#include <cstring>
#include <cmath>
#include <iostream>
#include <algorithm>
#include <queue>
#include <set>
#include <vector>
#include <stack>
#include <map>
#include <bitset>
#define ri register
#define inf 0x7fffffff
#define E (1)
#define mk make_pair
//#define int long long
//#define double long double
using namespace std; const int N=4000010, Mod=998244353;
inline int read()
{
    int s=0, w=1; ri char ch=getchar();
    while(ch<'0'||ch>'9') {if(ch=='-') w=-1; ch=getchar(); }
    while(ch>='0'&&ch<='9') s=(s<<3)+(s<<1)+(ch-'0'), ch=getchar();
    return s*w;
}
void print(int x) { if(x<0) x=-x, putchar('-'); if(x>9) print(x/10); putchar(x%10+'0'); }
int n,m,a[N],b[N],T=1,rev[N]; char s[N];
inline void Get_Rev() { for(ri int i=0;i<T;i++) rev[i]=(rev[i>>1]>>1)|((i&1)?(T>>1):0); }
inline int ksc(int x,int p) { int res=1; for(;p;p>>=1, x=1ll*x*x%Mod) if(p&1) res=1ll*res*x%Mod; return res; }
inline void NTT(int *s,int type)
{
    for(ri int i=0;i<T;i++) if(i<rev[i]) swap(s[i],s[rev[i]]);
    for(ri int i=2,cnt=1;i<=T;cnt++, i<<=1)
    {
        int wn=ksc(3,(Mod-1)/i);
        if(!type) wn=ksc(wn,Mod-2);
        for(ri int j=0, mid=(i>>1);j<T;j+=i)
        {
            for(ri int k=0, w=1;k<mid;k++, w=1ll*w*wn%Mod)
            {
                int x=s[j+k], y=1ll*w*s[j+mid+k]%Mod;
                s[j+k]=(x+y)%Mod, s[j+mid+k]=(x-y+Mod)%Mod;
            }
        }
    }
    if(!type) for(ri int i=0,inv=ksc(T,Mod-2);i<T;i++) s[i]=1ll*s[i]*inv%Mod;
}
signed main()
{
    n=read(), m=read();
    for(ri int i=0;i<=n;i++) a[i]=read();
    for(ri int i=0;i<=m;i++) b[i]=read();
    T=1; while(T<=(n+m)) T<<=1;
    Get_Rev();
    NTT(a,1), NTT(b,1);
    for(ri int i=0;i<T;i++) a[i]=1ll*a[i]*b[i]%Mod;
    NTT(a,0);
    for(ri int i=0;i<=n+m;i++) printf("%d ",a[i]);
    puts("");
    return 0;
}

考虑多次 \(\text{NTT}\) 时可以预处理原根幂。

#include <cstdio>
#include <cstring>
#include <cmath>
#include <iostream>
#include <algorithm>
#include <queue>
#include <set>
#include <vector>
#include <stack>
#include <map>
#include <bitset>
#define ri register
#define inf 0x7fffffff
#define E (1)
#define mk make_pair
//#define int long long
//#define double long double
using namespace std; const int N=4000010, Mod=998244353;
inline int read()
{
    int s=0, w=1; ri char ch=getchar();
    while(ch<'0'||ch>'9') {if(ch=='-') w=-1; ch=getchar(); }
    while(ch>='0'&&ch<='9') s=(s<<3)+(s<<1)+(ch-'0'), ch=getchar();
    return s*w;
}
void print(int x) { if(x<0) x=-x, putchar('-'); if(x>9) print(x/10); putchar(x%10+'0'); }
int n,m,a[N],b[N],T=1,rev[N],f[25][2]; char s[N];
inline void Get_Rev() { for(ri int i=0;i<T;i++) rev[i]=(rev[i>>1]>>1)|((i&1)?(T>>1):0); }
inline int ksc(int x,int p) { int res=1; for(;p;p>>=1, x=1ll*x*x%Mod) if(p&1) res=1ll*res*x%Mod; return res; }
inline void NTT(int *s,int type)
{
    for(ri int i=0;i<T;i++) if(i<rev[i]) swap(s[i],s[rev[i]]);
    for(ri int i=2,cnt=1;i<=T;cnt++, i<<=1)
    {
        int wn=f[cnt][type];
        for(ri int j=0, mid=(i>>1);j<T;j+=i)
        {
            for(ri int k=0, w=1;k<mid;k++, w=1ll*w*wn%Mod)
            {
                int x=s[j+k], y=1ll*w*s[j+mid+k]%Mod;
                s[j+k]=(x+y)%Mod, s[j+mid+k]=(x-y+Mod)%Mod;
            }
        }
    }
    if(!type) for(ri int i=0,inv=ksc(T,Mod-2);i<T;i++) s[i]=1ll*s[i]*inv%Mod;
}
signed main()
{
    f[23][1]=ksc(3,119), f[23][0]=ksc(332748118,119);
    for(ri int i=22;~i;i--) f[i][1]=1ll*f[i+1][1]*f[i+1][1]%Mod, f[i][0]=1ll*f[i+1][0]*f[i+1][0]%Mod;
    n=read(), m=read();
    for(ri int i=0;i<=n;i++) a[i]=read();
    for(ri int i=0;i<=m;i++) b[i]=read();
    T=1; while(T<=(n+m)) T<<=1;
    Get_Rev();
    NTT(a,1), NTT(b,1);
    for(ri int i=0;i<T;i++) a[i]=1ll*a[i]*b[i]%Mod;
    NTT(a,0);
    for(ri int i=0;i<=n+m;i++) printf("%d ",a[i]);
    puts("");
    return 0;
}

\(1.3\) \(\text{MTT}\)&&拆系数\(\text{FFT}\)

考虑到数据范围较大,例如,当两个长度为 \(10^{5}\) 级别的序列卷积且模数为 \(10^{9}\) 级别(不为 \(\text{NTT}\) 模数)。肯定不能直接 \(\text{NTT}\) 了。直接 \(\text{FFT}\) 的话,每个数的结果大约为 \(10^{5}\times 10^{9} \times 10^{9}=10^{23}\),超出了 \(2^{64}\),浮点数会出现较大的误差。故可以考虑使用拆系数 \(\text{FFT}\) 和三模数 \(\text{NTT}\)\(2.3\) 中将简略讲明拆系数 \(\text{FFT}\)

\(P\) 为模数大小。把 \(A(x)\)\(B(x)\) 的每一项分别拆成 \(aQ+b\)。记 \(A(x)=A1(x)Q+A2(x)\)\(B(x)=B1(x)Q+B2(x)\)。则显然有:

\[\qquad A(x)*B(x)=A1(x)B1(x)Q^{2}+A1(x)B2(x)Q+A2(x)B1(x)Q+A2(x)B2(x) \qquad \]

\(4\)\(\text{DFT}\)\(4\)\(\text{IDFT}\) 即可解决。考虑令 \(Q=\sqrt{x}\),则卷积结果大约在 \(10^{14}\)。发现此时需要预处理单位复数根。

#include <cstdio>
#include <cstring>
#include <cmath>
#include <iostream>
#include <algorithm>
#include <queue>
#include <set>
#include <vector>
#include <stack>
#include <map>
#include <bitset>
#define ri register
#define inf 0x7fffffff
#define E (1)
#define mk make_pair
//#define int long long
//#define double long double
using namespace std; const int N=400010; const double pi=acos(-1.0);
inline int read()
{
    int s=0, w=1; ri char ch=getchar();
    while(ch<'0'||ch>'9') {if(ch=='-') w=-1; ch=getchar(); }
    while(ch>='0'&&ch<='9') s=(s<<3)+(s<<1)+(ch-'0'), ch=getchar();
    return s*w;
}
void print(int x) { if(x<0) x=-x, putchar('-'); if(x>9) print(x/10); putchar(x%10+'0'); }
int n,m,P,T,rev[N],a[N],b[N],blk,ans[N];
struct Complex
{
    double x,y;
    inline Complex ( double pp=0, double qq=0) { x=pp, y=qq; }
}f1[N],f2[N],g1[N],g2[N],w[N],d1[N],d2[N],d3[N],d4[N];
inline Complex operator + (Complex a,Complex b) { return Complex(a.x+b.x,a.y+b.y); }
inline Complex operator - (Complex a,Complex b) { return Complex(a.x-b.x,a.y-b.y); }
inline Complex operator * (Complex a,Complex b) { return Complex(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x); }
inline void Get_Rev() { for(ri int i=0;i<T;i++) rev[i]=(rev[i>>1]>>1)|((i&1)?(T>>1):0); }
inline void Init()
{
    T=1;
    while(T<=n+m) T<<=1;
    Get_Rev();
    for(ri int i=0;i<T;i++) w[i]=Complex(cos(pi/T*i),sin(pi/T*i));
}
inline void FFT(Complex *s,int type)
{
    for(ri int i=0;i<T;i++) if(i<rev[i]) swap(s[i],s[rev[i]]);
    for(ri int mid=1;mid<T;mid<<=1)
    {
        for(ri int j=0, r=mid<<1;j<T;j+=r)
        {
            for(ri int k=0;k<mid;k++)
            {
                Complex qwq=w[T/mid*k];
                if(type==-1) qwq.y=-qwq.y;
                Complex x=s[j+k], y=qwq*s[j+mid+k];
                s[j+k]=x+y, s[j+mid+k]=x-y;
            }
        }
    }
}
signed main()
{
    n=read(), m=read(), P=read();
    Init();
    for(ri int i=0;i<=n;i++) a[i]=read();
    for(ri int i=0;i<=m;i++) b[i]=read();
    blk=32768;
    for(ri int i=0;i<=n;i++) f1[i]=Complex(a[i]/blk,0), f2[i]=Complex(a[i]%blk,0);
    for(ri int i=0;i<=m;i++) g1[i]=Complex(b[i]/blk,0), g2[i]=Complex(b[i]%blk,0);
    FFT(f1,1), FFT(f2,1), FFT(g1,1), FFT(g2,1);
    for(ri int i=0;i<T;i++) d1[i]=f1[i]*g1[i]; FFT(d1,-1);
    for(ri int i=0;i<T;i++) d2[i]=f1[i]*g2[i]; FFT(d2,-1);
    for(ri int i=0;i<T;i++) d3[i]=f2[i]*g1[i]; FFT(d3,-1);
    for(ri int i=0;i<T;i++) d4[i]=f2[i]*g2[i]; FFT(d4,-1);
    for(ri int i=0;i<T;i++)
    {
        ans[i]=
        (1ll*(long long)(d1[i].x/T+0.5)%P*blk%P*blk%P+
        1ll*(long long)(d2[i].x/T+0.5)%P*blk%P)%P;
        ans[i]=(ans[i]+1ll*(long long)(d3[i].x/T+0.5)%P*blk%P)%P;
        ans[i]=(ans[i]+1ll*(long long)(d4[i].x/T+0.5)%P)%P;
    }
    for(ri int i=0;i<=n+m;i++) printf("%d ",ans[i]);
    puts("");
    return 0;
}

\(1.4\) \(\text{MTT}\)&&三模数\(\text{NTT}\)

\(2.4\) 中将简略讲明三模数 \(\text{NTT}\)

考虑找 \(3\) 个大小在 \(10^{9}\) 级别的 \(\text{NTT}\) 模数,如 \(\{469762049,998244353,1004535809\}\),它们的原根都是 \(3\),实际实现起来也非常方便,并且它们的乘积比 \(10^{23}\) 要大。

考虑分别在 \(3\) 个模数意义下求卷积结果,然后中国剩余定理合并。但是发现 \(p_{1}\times p_{2}\times p_{3}>10^{23}\),显然它们会爆 \(long\) \(long\)。这个问题可以用 __\(int128\) 或者手写高精解决。

#include <cstdio>
#include <cstring>
#include <cmath>
#include <iostream>
#include <algorithm>
#include <queue>
#include <set>
#include <vector>
#include <stack>
#include <map>
#include <bitset>
#define ri register
#define inf 0x7fffffff
#define E (1)
#define mk make_pair
//#define int long long
//#define double long double
#define int __int128
using namespace std; const int N=400010, Mod1=469762049, Mod2=998244353, Mod3=1004535809;
inline int read()
{
    int s=0, w=1; ri char ch=getchar();
    while(ch<'0'||ch>'9') {if(ch=='-') w=-1; ch=getchar(); }
    while(ch>='0'&&ch<='9') s=(s<<3)+(s<<1)+(ch-'0'), ch=getchar();
    return s*w;
}
void print(int x) { if(x<0) x=-x, putchar('-'); if(x>9) print(x/10); putchar(x%10+'0'); }
int n,m,P,rev[N],T,a[4][N],b[4][N],f[N];
inline void Get_Rev() { for(ri int i=0;i<T;i++) rev[i]=(rev[i>>1]>>1)|((i&1)?(T>>1):0); }
inline int ksc(int x,int p,int Mod) { int res=1; for(;p;p>>=1, x=1ll*x*x%Mod) if(p&1) res=1ll*res*x%Mod; return res; }
inline void NTT(int *s,int type,int Mod)
{
    for(ri int i=0;i<T;i++) if(i<rev[i]) swap(s[i],s[rev[i]]);
    for(ri int i=2;i<=T;i<<=1)
    {
        int wn=ksc(3,(Mod-1)/i,Mod);
        if(!type) wn=ksc(wn,Mod-2,Mod);
        for(ri int j=0, mid=(i>>1);j<T;j+=i)
        {
            for(ri int k=0, w=1;k<mid;k++, w=1ll*w*wn%Mod)
            {
                int x=s[j+k], y=1ll*w*s[j+mid+k]%Mod;
                s[j+k]=(x+y)%Mod, s[j+mid+k]=(x-y+Mod)%Mod;
            }
        }
    }
    if(!type) for(ri int i=0, inv=ksc(T,Mod-2,Mod);i<T;i++) s[i]=1ll*s[i]*inv%Mod;
}
int exgcd(int a,int b,int &x,int &y)
{
    if(!b) { x=1, y=0; return a; }
    int tt=exgcd(b,a%b,y,x);
    y-=a/b*x; return tt;
}
signed main()
{
    n=read(), m=read(), P=read();
    for(ri int i=0;i<=n;i++) a[1][i]=a[2][i]=a[3][i]=read();
    for(ri int i=0;i<=m;i++) b[1][i]=b[2][i]=b[3][i]=read();
    T=1; while(T<=n+m) T<<=1; Get_Rev();
    NTT(a[1],1,Mod1), NTT(b[1],1,Mod1);
    for(ri int i=0;i<T;i++) a[1][i]=1ll*a[1][i]*b[1][i]%Mod1;
    NTT(a[1],0,Mod1);
    NTT(a[2],1,Mod2), NTT(b[2],1,Mod2);
    for(ri int i=0;i<T;i++) a[2][i]=1ll*a[2][i]*b[2][i]%Mod2;
    NTT(a[2],0,Mod2);
    NTT(a[3],1,Mod3), NTT(b[3],1,Mod3);
    for(ri int i=0;i<T;i++) a[3][i]=1ll*a[3][i]*b[3][i]%Mod3;
    NTT(a[3],0,Mod3);
    //for(ri int i=0;i<=n+m;i++) printf("%d %d %d\n",a[1][i],a[2][i],a[3][i]);
    for(ri int i=0;i<T;i++)
    {
        int gt=Mod1, res=a[1][i];
        for(ri int j=2;j<=3;j++)
        {
            int md=(j==2)?Mod2:Mod3;
            int dd=(a[j][i]-res%md+md)%md;
            int x,y; x=y=0;
            int c=exgcd(gt,md,x,y);
            x=1ll*x*(dd/c)%(md/c);
            res=1ll*(res+1ll*x*gt);
            gt=1ll*gt*md/c;
            res=(res%gt+gt)%gt;
        }
        f[i]=(res%gt+gt)%gt;
        f[i]%=P;
    }
    for(ri int i=0;i<=n+m;i++) print(f[i]), putchar(' ');
    puts("");
    return 0;
}

但是,__\(int128\) 的常数非常大,而且在实际比赛时也一般无法使用。考虑有什么更高明的做法。

考虑先只合并前两个模数。设最终答案为 \(s\),且有

\[\qquad s\equiv s_{1}\pmod {p_{1}} \qquad \]

\[\qquad s\equiv s_{2}\pmod {p_{2}} \qquad \]

\[\qquad s\equiv s_{3}\pmod {p_{3}} \qquad \]

前两项显然可以在 \(long\) \(long\) 范围内合并,记

\[\qquad s\equiv s_{4}\pmod {p_{1}p_{2}} \qquad \]

存在 \(s=s_{4}+w_{1}p_{1}p_{2}=s_{3}+w_{2}p_{3}\)

显然有 \(w_{1}<p_{3}\),则有

\[\qquad w_{1}\equiv (s_{3}-s_{4})(p_{1}p_{2})^{-1}\pmod {p_{3}} \qquad \]

那么可以求出 \(w_{1}\) 的真实值,那么通过上式,即可在模 \(P\) 意义下求出 \(s\) 的值。但是常数非常大。

#include <cstdio>
#include <cstring>
#include <cmath>
#include <iostream>
#include <algorithm>
#include <queue>
#include <set>
#include <vector>
#include <stack>
#include <map>
#include <bitset>
#define ri register
#define inf 0x7fffffff
#define E (1)
#define mk make_pair
//#define int long long
//#define double long double
using namespace std; const int N=400010, Mod1=469762049, Mod2=998244353, Mod3=1004535809;
inline int read()
{
    int s=0, w=1; ri char ch=getchar();
    while(ch<'0'||ch>'9') {if(ch=='-') w=-1; ch=getchar(); }
    while(ch>='0'&&ch<='9') s=(s<<3)+(s<<1)+(ch-'0'), ch=getchar();
    return s*w;
}
void print(int x) { if(x<0) x=-x, putchar('-'); if(x>9) print(x/10); putchar(x%10+'0'); }
int n,m,P,rev[N],T,a[4][N],b[4][N];
long long f[N];
inline void Get_Rev() { for(ri int i=0;i<T;i++) rev[i]=(rev[i>>1]>>1)|((i&1)?(T>>1):0); }
inline int ksc(int x,int p,int Mod) { int res=1; for(;p;p>>=1, x=1ll*x*x%Mod) if(p&1) res=1ll*res*x%Mod; return res; }
inline void NTT(int *s,int type,int Mod)
{
    for(ri int i=0;i<T;i++) if(i<rev[i]) swap(s[i],s[rev[i]]);
    for(ri int i=2;i<=T;i<<=1)
    {
        int wn=ksc(3,(Mod-1)/i,Mod);
        if(!type) wn=ksc(wn,Mod-2,Mod);
        for(ri int j=0, mid=(i>>1);j<T;j+=i)
        {
            for(ri int k=0, w=1;k<mid;k++, w=1ll*w*wn%Mod)
            {
                int x=s[j+k], y=1ll*w*s[j+mid+k]%Mod;
                s[j+k]=(x+y)%Mod, s[j+mid+k]=(x-y+Mod)%Mod;
            }
        }
    }
    if(!type) for(ri int i=0, inv=ksc(T,Mod-2,Mod);i<T;i++) s[i]=1ll*s[i]*inv%Mod;
}
long long exgcd(long long a,long long b,long long &x,long long &y)
{
    if(!b) { x=1, y=0; return a; }
    int tt=exgcd(b,a%b,y,x);
    y-=a/b*x; return tt;
}
signed main()
{
    n=read(), m=read(), P=read();
    for(ri int i=0;i<=n;i++) a[1][i]=a[2][i]=a[3][i]=read();
    for(ri int i=0;i<=m;i++) b[1][i]=b[2][i]=b[3][i]=read();
    T=1; while(T<=n+m) T<<=1; Get_Rev();
    NTT(a[1],1,Mod1), NTT(b[1],1,Mod1);
    for(ri int i=0;i<T;i++) a[1][i]=1ll*a[1][i]*b[1][i]%Mod1;
    NTT(a[1],0,Mod1);
    NTT(a[2],1,Mod2), NTT(b[2],1,Mod2);
    for(ri int i=0;i<T;i++) a[2][i]=1ll*a[2][i]*b[2][i]%Mod2;
    NTT(a[2],0,Mod2);
    NTT(a[3],1,Mod3), NTT(b[3],1,Mod3);
    for(ri int i=0;i<T;i++) a[3][i]=1ll*a[3][i]*b[3][i]%Mod3;
    NTT(a[3],0,Mod3);
    //for(ri int i=0;i<=n+m;i++) printf("%d %d %d\n",a[1][i],a[2][i],a[3][i]);
    for(ri int i=0;i<T;i++)
    {
        long long gt=Mod1; long long res=a[1][i];
        long long md=Mod2, qwq=1ll*Mod1*Mod2;
        long long dd=(a[2][i]-res%md+md)%md;
        long long x,y; x=y=0;
        long long c=exgcd(gt,md,x,y);
        x=x*(dd/c)%(md/c);
        res=(res+x*gt);
        gt=gt*(md/c);
        res=(res%gt+gt)%gt;
        long long yy=a[3][i];
        long long k=((yy-res)%Mod3+Mod3)%Mod3*ksc(qwq%Mod3,Mod3-2,Mod3)%Mod3;
        f[i]=(res+k%P*Mod1%P*Mod2%P)%P;
    }
    for(ri int i=0;i<=n+m;i++) print(f[i]), putchar(' ');
    puts("");
    return 0;
}

\(1.5\) 分治 \(\text{FFT}\)

引题:已知 \(f_{0}=1\) 以及给出序列 \(g_{1...n-1}\),求序列 \(f_{0...n-1}\)。其中 \(f_{i}=\sum\limits_{j=1}^{i}f_{i-j}g_{j}\)。答案对 \(998244353\) 取模。

类似于上式中,我们发现它类似于卷积形式,但是求出后面的数字基于求出了前面的式子,无法直接进行 \(\text{FFT}\)\(\text{NTT}\),时间复杂度会退化到 \(O(n^{2})\)。此时考虑使用分治 \(\text{FFT}\)

分治 \(\text{FFT}\) 的核心在于利用分治后利用 \(mid\) 将区间 \([L,R]\) 分为两个互不相交的子区间,然后就可以解决题目的某些限制。在上式中,我们考虑把区间分为 \([L,mid)\)\([mid,R)\)。发现当 \(R-L=1\) 时结束分治。否则,假设已经求出 \([L,mid)\)\(f_{i}\),则它们对 \([mid,R)\) 的贡献显然为:

\[\qquad F_{i}=\sum\limits_{j=L}^{mid-1} f_{j}\times g_{i-j} \qquad \]

这个式子就是卷积的基本形式,直接用 \(\text{FFT}\)\(\text{NTT}\) 求出然后加到对应的数组中,再加到 \(f_{i}\) 上即可。所以本题分治 \(\text{FFT}\) 的做法是:递归左区间,计算左区间对右区间的贡献,递归右区间。

#include <cstdio>
#include <cstring>
#include <cmath>
#include <iostream>
#include <algorithm>
#include <queue>
#include <set>
#include <vector>
#include <stack>
#include <map>
#include <bitset>
#define ri register
#define inf 0x7fffffff
#define E (1)
#define mk make_pair
//#define int long long
//#define double long double
using namespace std; const int MAXN=400010, Mod=998244353;
inline int read()
{
    int s=0, w=1; ri char ch=getchar();
    while(ch<'0'||ch>'9') {if(ch=='-') w=-1; ch=getchar(); }
    while(ch>='0'&&ch<='9') s=(s<<3)+(s<<1)+(ch-'0'), ch=getchar();
    return s*w;
}
void print(int x) { if(x<0) x=-x, putchar('-'); if(x>9) print(x/10); putchar(x%10+'0'); }
int n,f[MAXN],g[MAXN],rev[MAXN],a[MAXN],b[MAXN],r[25][2],T;
inline int ksc(int x,int p) { int res=1; for(;p;p>>=1, x=1ll*x*x%Mod) if(p&1) res=1ll*res*x%Mod; return res; }
inline void Init()
{
    r[23][1]=ksc(3,119), r[23][0]=ksc(332748118,119);
    for(ri int i=22;~i;i--) r[i][1]=1ll*r[i+1][1]*r[i+1][1]%Mod, r[i][0]=1ll*r[i+1][0]*r[i+1][0]%Mod;
}
inline void Get_Rev(int T) { for(ri int i=0;i<T;i++) rev[i]=(rev[i>>1]>>1)|((i&1)?(T>>1):0); }
inline void NTT(int *s,int type)
{
    for(ri int i=0;i<T;i++) if(i<rev[i]) swap(s[i],s[rev[i]]);
    for(ri int i=2, cnt=1;i<=T;cnt++, i<<=1)
    {
        int wn=r[cnt][type];
        for(ri int j=0, mid=(i>>1);j<T;j+=i)
        {
            for(ri int k=0, w=1;k<mid;k++, w=1ll*w*wn%Mod)
            {
                int x=s[j+k], y=1ll*w*s[j+mid+k]%Mod;
                s[j+k]=(x+y)%Mod, s[j+mid+k]=(x-y+Mod)%Mod;
            }
        }
    }
    if(!type) for(ri int i=0, inv=ksc(T,Mod-2);i<T;i++) s[i]=1ll*s[i]*inv%Mod;
}
void DivDFT(int l,int r)
{
    if(l+1>=r) return;
    int mid=((l+r)>>1);
    DivDFT(l,mid);
    T=r-l; Get_Rev(r-l);
    for(ri int i=0;i<mid-l;i++) a[i]=f[i+l];
    for(ri int i=mid-l;i<r-l;i++) a[i]=0;
    for(ri int i=0;i<r-l;i++) b[i]=g[i];
    NTT(a,1), NTT(b,1);
    for(ri int i=0;i<T;i++) a[i]=1ll*a[i]*b[i]%Mod;
    NTT(a,0);
    for(ri int i=mid;i<r;i++) f[i]=(f[i]+a[i-l])%Mod;
    DivDFT(mid,r);
}
signed main()
{
    n=read(); Init(); f[0]=1;
    for(ri int i=1;i<n;i++) g[i]=read();
    T=1;
    while(T<n) T<<=1;
    DivDFT(0,T);
    for(ri int i=0;i<n;i++) printf("%d ",f[i]);
    puts("");
    return 0;
}

例题:题目链接

考虑对于两种运算 \(x+y\)\(x-y\) 分别操作。

先不考虑限制,只考虑如何快速计算出 \(x+y\)\(x-y\)。设 \(f_{i}=\sum\limits_{p=1}^{n}[a_{p}=i]\)\(g_{i}=\sum\limits_{p=1}^{m}[b_{p}=i]\)。计算 \(x+y\) 相同的,就显然有 \(F_{i}=\sum\limits_{p+q=i}f_{p}g_{q}=\sum\limits_{p=0}^{i}f_{p}g_{i-p}\)。这个问题我们可以用 \(\text{FFT}\)\(\text{NTT}\)\(O(n\log n)\) 的时间复杂度内解决。对于 \(x-y\),可以看作是 \(x+(-y)\)。所以可以直接翻转 \(b\) 数组,这样就相当于完成了取相反数的操作。

于是考虑对于 \(x,y\) 的限制。发现分治可以满足这一条件。对于当前区间 \([L,R]\),如果 \(L=R\) 就直接返回

\(f_{L}\times g_{R}\)。否则我们把区间分为 \([L,mid]\)\((mid+1,R]\)。对于 \(x\) 在区间左侧,\(y\) 在区间右侧的情况,即计算 \(x+y\)。对于 \(x\) 在区间右侧,\(y\) 在区间左侧的情况,即计算 \(x-y\)。然后继续分治。其思想类似于分治 \(\text{FFT}\)

#include <cstring>
#include <cmath>
#include <iostream>
#include <algorithm>
#include <queue>
#include <set>
#include <vector>
#include <stack>
#include <map>
#include <bitset>
#define ri register
#define inf 0x7fffffff
#define E (1)
#define mk make_pair
#define int long long
//#define double long double
using namespace std; const int N=400010; const double pi=acos(-1.0);
inline int read()
{
    int s=0, w=1; ri char ch=getchar();
    while(ch<'0'||ch>'9') {if(ch=='-') w=-1; ch=getchar(); }
    while(ch>='0'&&ch<='9') s=(s<<3)+(s<<1)+(ch-'0'), ch=getchar();
    return s*w;
}
void print(int x) { if(x<0) x=-x, putchar('-'); if(x>9) print(x/10); putchar(x%10+'0'); }
int n,m,Q,rev[N],T,p[N],q[N],ans[N];
struct Complex
{
    double x,y;
    inline Complex (double pp=0, double qq=0 ) { x=pp, y=qq; }
}a[N],b[N];
inline Complex operator + (Complex a,Complex b) { return Complex(a.x+b.x,a.y+b.y); }
inline Complex operator - (Complex a,Complex b) { return Complex(a.x-b.x,a.y-b.y); }
inline Complex operator * (Complex a,Complex b) { return Complex(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x); }
inline void Get_Rev() { for(ri int i=0;i<T;i++) rev[i]=(rev[i>>1]>>1)|((i&1)?(T>>1):0); }
inline void FFT(Complex *s,int type)
{
    for(ri int i=0;i<T;i++) if(i<rev[i]) swap(s[i],s[rev[i]]);
    for(ri int mid=1;mid<T;mid<<=1)
    {
        Complex wn(cos(pi/mid),type*sin(pi/mid));
        for(ri int j=0, r=mid<<1;j<T;j+=r)
        {
            Complex w(1,0);
            for(ri int k=0;k<mid;k++, w=w*wn)
            {
                Complex x=s[j+k], y=w*s[j+mid+k];
                s[j+k]=x+y, s[j+mid+k]=x-y;
            }
        }
    }
}
#define mid (L+R>>1)
void DivideWithFFT(int L,int R)
{
    if(L==R) { ans[0]+=p[L]*q[R]; return; }
    T=1; while(T<=R-L) T<<=1; Get_Rev();
    for(ri int i=0;i<T;i++) a[i]=b[i]=Complex(0,0);
    for(ri int i=L;i<=mid;i++) a[i-L]=Complex(p[i],0);
    for(ri int i=mid+1;i<=R;i++) b[i-mid-1]=Complex(q[i],0);
    FFT(a,1), FFT(b,1);
    for(ri int i=0;i<T;i++) a[i]=a[i]*b[i];
    FFT(a,-1);
    for(ri int i=0;i<T;i++) ans[i+L+mid+1]+=(int)(a[i].x/T+0.5);
    for(ri int i=0;i<T;i++) a[i]=b[i]=Complex(0,0);
    for(ri int i=mid+1;i<=R;i++) a[i-mid-1]=Complex(p[i],0);
    for(ri int i=L;i<=mid;i++) b[mid-i]=Complex(q[i],0);
    FFT(a,1), FFT(b,1);
    for(ri int i=0;i<T;i++) a[i]=a[i]*b[i];
    FFT(a,-1);
    for(ri int i=0;i<T;i++) ans[i+1]+=(int)(a[i].x/T+0.5);
    DivideWithFFT(L,mid), DivideWithFFT(mid+1,R);
}
#undef mid
signed main()
{
    for(ri int cas=read();cas;cas--)
    {
        n=read(), m=read(), Q=read();
        for(ri int i=1;i<=n;i++)
        {
            int x=read();
            p[x]++;
        }
        for(ri int i=1;i<=m;i++)
        {
            int x=read();
            q[x]++;
        }
        DivideWithFFT(0,50000);
        for(;Q;Q--)
        {
            int x=read();
            printf("%lld\n",ans[x]);
        }
        for(ri int i=0;i<=100000;i++) p[i]=q[i]=ans[i]=0;
    }
    return 0;
}

posted @ 2020-08-07 18:34  zkdxl  阅读(139)  评论(7编辑  收藏  举报