多项式

不知不觉,就开始学习多项式了,好快啊

重新拾起我之前学的百嘛不是的\(FFT\),新学\(NTT/FWT\)

开始吧学习博客

FWT博客

FFT:快速傅里叶变换

总的来说就是快速求取两个多项式的乘积

把板子粘在这里Luogu3803 多项式乘法

code
#include<bits/stdc++.h>
using namespace std;
#define fo(i,x,y) for(int i=(x);i<=(y);i++)
#define fu(i,x,y) for(int i=(x);i>=(y);i--)
const double pi=acos(-1.0);
const int N=1<<21;
struct coex{
    double r,i;
    coex(){}
    coex(double x,double y){r=x;i=y;}
    coex operator + (coex a){return coex(r+a.r,i+a.i);}
    coex operator - (coex a){return coex(r-a.r,i-a.i);}
    coex operator * (coex a){return coex(r*a.r-i*a.i,r*a.i+i*a.r);}
}a[N],b[N],w[N];
int af[N],la,lb,lim,len;
void fft(coex *a,int lim){
    fo(i,0,lim-1)if(af[i]>i)swap(a[i],a[af[i]]);
    for(int t=lim>>1,d=1;d<lim;d<<=1,t>>=1)
        for(int i=0;i<lim;i+=(d<<1))
            fo(j,0,d-1){
                coex tmp=w[t*j]*a[i+j+d];
                a[i+j+d]=a[i+j]-tmp;
                a[i+j]=a[i+j]+tmp;
            }
}
signed main(){
    scanf("%d%d",&la,&lb);
    fo(i,0,la)scanf("%lf",&a[i].r);
    fo(i,0,lb)scanf("%lf",&b[i].r);
    for(lim=1,len=0;lim<=la+lb;lim<<=1,len++);
    fo(i,0,lim-1){
        af[i]=(af[i>>1]>>1)|((i&1)<<(len-1));
        w[i]=coex(cos(2.0*i*pi/lim),sin(2.0*i*pi/lim));
    }
    fft(a,lim);fft(b,lim);
    fo(i,0,lim-1)a[i]=a[i]*b[i],w[i].i=-w[i].i;
    fft(a,lim);
    fo(i,0,la+lb)printf("%d ",(int)(a[i].r/lim+0.5));
}

需要注意这么几点:

1、\(lim\)要大于多项式次数,严格大于

2、搞清楚\(d\)\(t\)的关系,以及各种2倍

3、背板子,背就完事了。

4、最后答案别忘了除以\(lim\),预处理单位根别忘了\(2*i*pi/lim\),别忘了除以\(lim\)

NTT:快速数论变换

code
void ntt(int *a,int lim){
    fo(i,0,lim-1)if(af[i]>i)swap(a[i],a[af[i]]);
    for(int t=lim>>1,d=1;d<lim;d<<=1,t>>=1)
        for(int i=0;i<lim;i+=(d<<1))
            fo(j,0,d-1){
                int tmp=g[t*j]*a[i+j+d]%mod;
                a[i+j+d]=(a[i+j]-tmp+mod)%mod;
                a[i+j]=(a[i+j]+tmp)%mod;
            }
}
int bas[M],ans[M];
int a[M],b[M];
void mul(int *x,int *y,int *z){
    fo(i,0,lim-1)a[i]=y[i];
    fo(i,0,lim-1)b[i]=z[i];
    g[0]=1;g[1]=ksm(3,(mod-1)/lim,mod);
    fo(i,2,lim-1)g[i]=g[i-1]*g[1]%mod;
    ntt(a,lim);ntt(b,lim);
    g[0]=1;g[1]=ksm(g[1],mod-2,mod);
    fo(i,2,lim-1)g[i]=g[i-1]*g[1]%mod;
    fo(i,0,lim-1)a[i]=a[i]*b[i]%mod;
    ntt(a,lim);int inv=ksm(lim,mod-2,mod);
    fo(i,0,lim-1)a[i]=a[i]*inv%mod;
    fo(i,0,m-2)a[i]=(a[i]+a[i+m-1])%mod;
    fo(i,0,m-2)x[i]=a[i];
}

注意:

1、注意模数,以及自己找原根

2、注意最后出来的时候除以\(lim\)

3、别忘了变换原根

FWT:快速沃尔什变换

发现我们所做的卷积都是\(c=\sum_{i=1}^{n}\sum_{j=1}^{i}a_j*b_{i-j}\)

也就是\(FFT\)\(NTT\)能处理的加法卷积,如果是与?或?异或呢?

就要用到\(FWT\),也就是快速位运算卷积

|:或卷积

这个要是直接卷的话,好像有点困难,但是高维前缀和就是或卷积,所以我们可以用高维前缀和解决或卷积的问题

然而高维前缀和还是没有\(FWT\)来的直接

\(FWT\)也和\(FFT\)一样有变换和逆变换,这个直接看最上面我给出的那个博客吧

这次变换不能叫点值表达式了,应该叫它子集和表达式

相当于是枚举每一个二进制位来求子集和,逆变换拆回去的时候是一样的

code
void ort(int *a,int lim,int tp){//tp=1表示顺变换,tp=-1表示逆变换
    for(int d=1;d<lim;d<<=1)
        for(int i=0;i<lim;i+=(d<<1))
            fo(j,0,d-1){
                a[i+j+d]=(a[i+j+d]+a[i+j]*tp)%mod;
            }
}

&:与卷积

这应该叫做高维后缀和吧,我们可以用高维后缀和解决与卷积问题

变换就是后缀和,逆变换一样

code
void andt(int *a,int lim,int tp){//tp=1表示顺变换,tp=-1表示逆变换
    for(int d=1;d<lim;d<<=1)
        for(int i=0;i<lim;i+=(d<<1))
            fo(j,0,d-1){
                a[i+j]=(a[i+j]+a[i+j+d]*tp)%mod;
            }
}

^:异或卷积

这个应该叫高维后缀差分

好像有点难以概括,额,不概括了,直接看码吧

code
void xort(int *a,int lim,int tp){//tp=1表示顺变换,tp=0.5表示逆变换
    for(int d=1;d<lim;d<<=1)
        for(int i=0;i<lim;i+=(d<<1))
            fo(j,0,d-1){
                int tmp=a[i+j+d];
                a[i+j+d]=(a[i+j]-tmp+mod)*tp%mod;
                a[i+j]=(a[i+j]+tmp)*tp%mod;
            }
}

高维前缀和的求法,枚举每一维,做一遍前缀和

在位运算上也就是枚举每一个二进制位,将这一位为\(0\)的加到这一位为\(1\)的上面

这样做高维前缀和虽然朴素,但是有正确性的保证,保证了不重不漏

Luogu4717

code
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define fo(i,x,y) for(int i=(x);i<=(y);i++)
#define fu(i,x,y) for(int i=(x);i>=(y);i--)
int read(){
    int s=0,t=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')t=-1;ch=getchar();}
    while(isdigit(ch)){s=s*10+ch-'0';ch=getchar();}
    return s*t;
}
const int mod=998244353;
const int N=1<<18;
int ksm(int x,int y){
    int ret=1;
    while(y){
        if(y&1)ret=ret*x%mod;
        x=x*x%mod;y>>=1;
    }return ret;
}
int n,lim,a[N],b[N],c[N],aa[N],bb[N];
void init(int lim){fo(i,0,lim-1)a[i]=aa[i],b[i]=bb[i];}
void ort(int *a,int lim,int tp){
    for(int d=1;d<lim;d<<=1)
        for(int i=0;i<lim;i+=(d<<1))
            fo(j,0,d-1){
                a[i+j+d]=(a[i+j+d]+a[i+j]*tp)%mod;
            }
}
void andt(int *a,int lim,int tp){
    for(int d=1;d<lim;d<<=1)
        for(int i=0;i<lim;i+=(d<<1))
            fo(j,0,d-1){
                a[i+j]=(a[i+j]+a[i+j+d]*tp)%mod;
            }
}
void xort(int *a,int lim,int tp){
    for(int d=1;d<lim;d<<=1)
        for(int i=0;i<lim;i+=(d<<1))
            fo(j,0,d-1){
                int tmp=a[i+j+d];
                a[i+j+d]=(a[i+j]-tmp+mod)*tp%mod;
                a[i+j]=(a[i+j]+tmp)*tp%mod;
            }
}
void mix(int lim){fo(i,0,lim-1)c[i]=a[i]*b[i]%mod;}
void pt(int *a,int lim){fo(i,0,lim-1)printf("%lld ",a[i]);printf("\n");}
signed main(){
    n=read();lim=1<<n;
    fo(i,0,lim-1)aa[i]=read();
    fo(i,0,lim-1)bb[i]=read();
    init(lim);ort(a,lim,1);ort(b,lim,1);mix(lim);ort(c,lim,mod-1);pt(c,lim);
    init(lim);andt(a,lim,1);andt(b,lim,1);mix(lim);andt(c,lim,mod-1);pt(c,lim);
    init(lim);xort(a,lim,1);xort(b,lim,1);mix(lim);xort(c,lim,ksm(2,mod-2));pt(c,lim);
}

分治FFT/NTT/FWT

这个东东就是把FFT和CDQ套在一起了

先处理左区间,处理当前区间,然后是右区间

code
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define fo(i,x,y) for(int i=(x);i<=(y);i++)
#define fu(i,x,y) for(int i=(x);i>=(y);i++)
int read(){
    int s=0,t=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')t=-1;ch=getchar();}
    while(isdigit(ch)){s=s*10+ch-'0';ch=getchar();}
    return s*t;
}
const int mod=998244353;
const int N=1<<20;
int ksm(int x,int y){
    int ret=1;
    while(y){
        if(y&1)ret=ret*x%mod;
        x=x*x%mod;y>>=1;
    }return ret;
}
int af[N],lim,len,w[N];
void ntt(int *a,int lim){
    fo(i,0,lim-1)if(af[i]>i)swap(a[i],a[af[i]]);
    for(int t=lim>>1,d=1;d<lim;d<<=1,t>>=1)
        for(int i=0;i<lim;i+=(d<<1))
            fo(j,0,d-1){
                int tmp=w[t*j]*a[i+j+d]%mod;
                a[i+j+d]=(a[i+j]-tmp+mod)%mod;
                a[i+j]=(a[i+j]+tmp)%mod;
            }
}
int n,f[N],g[N],a[N],b[N],c[N];
void mul(int n,int m,int v){
    for(lim=1,len=0;lim<n+m;lim<<=1,len++);
    fo(i,0,lim-1)af[i]=(af[i>>1]>>1)|((i&1)<<(len-1));
    w[0]=1;w[1]=ksm(3,(mod-1)/lim);
    fo(i,2,lim-1)w[i]=w[i-1]*w[1]%mod;
    ntt(a,lim);ntt(b,lim);
    w[0]=1;w[1]=ksm(w[1],mod-2);
    fo(i,2,lim-1)w[i]=w[i-1]*w[1]%mod;
    fo(i,0,lim-1)c[i]=a[i]*b[i]%mod;
    ntt(c,lim);int iv=ksm(lim,mod-2);
    fo(i,0,lim-1)c[i]=c[i]*iv%mod,a[i]=b[i]=0;
}
void sol(int l,int r){
    // cout<<l<<" "<<r<<endl;
    if(l==r)return ;
    int mid=l+r>>1,ln=mid-l+1,lm=r-l+1;
    sol(l,mid);
    // cout<<"S "<<l<<" "<<r<<endl;
    fo(i,l,mid)a[i-l]=f[i];//cout<<a[i]<<" ";cout<<endl;
    fo(i,l,r)b[i-l]=g[i-l];//cout<<b[i]<<" ";cout<<endl;
    mul(ln,lm,lm);
    fo(i,ln,lm-1)f[i+l]=(f[i+l]+c[i])%mod;
    sol(mid+1,r);
}
signed main(){
    n=read();f[0]=1;
    fo(i,1,n-1)g[i]=read();
    sol(0,n-1);
    fo(i,0,n-1)printf("%lld ",f[i]);
}

多项式求逆

就是求某一个多项式的逆元\(mod\ x^n\)意义下

主要吧,这个板子还挺难背的

递推上来的证明过程

所以结论就是\(B=2B'-AB'^2\)(\(B'\)表示\(mod\ x^{\lceil\frac{n}{2}\rceil}\)意义下的逆元)

code
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define fo(i,x,y) for(int i=(x);i<=(y);i++)
#define fu(i,x,y) for(int i=(x);i>=(y);i--)
int read(){
    int s=0,t=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')t=-1;ch=getchar();}
    while(isdigit(ch)){s=s*10+ch-'0';ch=getchar();}
    return s*t;
}
const int mod=998244353;
const int N=1<<18;
int ksm(int x,int y){
    int ret=1;
    while(y){
        if(y&1)ret=ret*x%mod;
        x=x*x%mod;y>>=1;
    }return ret;
}
int af[N],g[N],lim,len;
void ntt(int *a,int lim){
    fo(i,0,lim-1)if(af[i]>i)swap(a[i],a[af[i]]);
    for(int t=lim>>1,d=1;d<lim;d<<=1,t>>=1)
        for(int i=0;i<lim;i+=(d<<1))
            fo(j,0,d-1){
                int tmp=g[t*j]*a[i+j+d]%mod;
                a[i+j+d]=(a[i+j]-tmp+mod)%mod;
                a[i+j]=(a[i+j]+tmp)%mod;
            }
}
int a[N],b[N],c[N];
void calc(int deg){//deg表示当前mod的是x^deg
    if(deg==1){return b[0]=ksm(a[0],mod-2),void();}
    calc((deg+1)>>1);
    for(lim=1,len=0;lim<deg*2;lim<<=1,len++);
    fo(i,0,lim-1)af[i]=(af[i>>1]>>1)|((i&1)<<(len-1));
    g[0]=1;g[1]=ksm(3,(mod-1)/lim);
    fo(i,2,lim-1)g[i]=g[i-1]*g[1]%mod;
    fo(i,0,deg-1)c[i]=a[i];//注意这里只赋值前deg-1位,因为这次乘法只能用到这些位
    //这里a有deg-1位,而b只有deg+1>>1位,因为b是要平方的,并且这里是求a的deg-1位的逆元
    ntt(c,lim);ntt(b,lim);
    g[0]=1;g[1]=ksm(g[1],mod-2);
    fo(i,2,lim-1)g[i]=g[i-1]*g[1]%mod;
    fo(i,0,lim-1)b[i]=(2-c[i]*b[i]%mod+mod)*b[i]%mod;
    ntt(b,lim);int iv=ksm(lim,mod-2);
    fo(i,0,lim-1)b[i]=b[i]*iv%mod;
    fo(i,deg,lim-1)b[i]=0;//把后边的清空,防止影响下一次运算
    fo(i,0,lim-1)c[i]=0;//清空
}
int n;
signed main(){
    n=read();
    fo(i,0,n-1)a[i]=read();
    calc(n);
    fo(i,0,n-1)printf("%lld ",b[i]);
}

多项式开根

这个要用到上面的多项式求逆

这个就是求一个多项式的根号

于是乎我们仍然是递归求解

\(B=\frac{A+B^2}{2B}\) 证明

code
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define fo(i,x,y) for(int i=(x);i<=(y);i++)
#define fu(i,x,y) for(int i=(x);i>=(y);i--)
int read(){
    int s=0,t=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')t=-1;ch=getchar();}
    while(isdigit(ch)){s=s*10+ch-'0';ch=getchar();}
    return s*t;
}
const int mod=998244353;
const int N=1<<20;
int ksm(int x,int y){
    int ret=1;
    while(y){
        if(y&1)ret=ret*x%mod;
        x=x*x%mod;y>>=1;
    }return ret;
}
int af[N],lim,len,g[N],iv2;
void ntt(int *a,int lim){
    fo(i,0,lim-1)if(af[i]>i)swap(a[i],a[af[i]]);
    for(int t=lim>>1,d=1;d<lim;d<<=1,t>>=1)
        for(int i=0;i<lim;i+=(d<<1))
            fo(j,0,d-1){
                int tmp=g[t*j]*a[i+j+d]%mod;
                a[i+j+d]=(a[i+j]-tmp+mod)%mod;
                a[i+j]=(a[i+j]+tmp)%mod;
            }
}
int a[N],b[N],c[N];
void inv(int *x,int *y,int n){
    if(n==1){return y[0]=ksm(x[0],mod-2),void();}
    inv(x,y,n+1>>1);
    for(lim=1,len=0;lim<2*n;lim<<=1,len++);
    fo(i,0,lim-1)af[i]=(af[i>>1]>>1)|((i&1)<<(len-1));
    g[0]=1;g[1]=ksm(3,(mod-1)/lim);
    fo(i,2,lim-1)g[i]=g[i-1]*g[1]%mod;
    fo(i,0,n-1)a[i]=x[i];
    ntt(a,lim);ntt(y,lim);
    g[0]=1;g[1]=ksm(g[1],mod-2);
    fo(i,2,lim-1)g[i]=g[i-1]*g[1]%mod;
    fo(i,0,lim-1)y[i]=(2ll-a[i]*y[i]%mod+mod)*y[i]%mod;
    ntt(y,lim);int iv=ksm(lim,mod-2);
    fo(i,0,lim-1)y[i]=y[i]*iv%mod,a[i]=0;
    fo(i,n,lim-1)y[i]=0;
}
void sqt(int *x,int *y,int n){
    if(n==1){return y[0]=1,void();}
    sqt(x,y,n+1>>1);inv(y,c,n);
    for(lim=1,len=0;lim<2*n;lim<<=1,len++);
    fo(i,0,lim-1)af[i]=(af[i>>1]>>1)|((i&1)<<(len-1));
    g[0]=1;g[1]=ksm(3,(mod-1)/lim);
    fo(i,2,lim-1)g[i]=g[i-1]*g[1]%mod;
    fo(i,0,n-1)a[i]=x[i];
    fo(i,0,lim-1)c[i]=c[i]*iv2%mod;
    ntt(a,lim);ntt(y,lim);ntt(c,lim);
    g[0]=1;g[1]=ksm(g[1],mod-2);
    fo(i,2,lim-1)g[i]=g[i-1]*g[1]%mod;
    fo(i,0,lim-1)y[i]=(a[i]+y[i]*y[i]%mod)*c[i]%mod;
    ntt(y,lim);int iv=ksm(lim,mod-2);
    fo(i,0,lim-1)y[i]=y[i]*iv%mod,a[i]=b[i]=c[i]=0;
    fo(i,n,lim-1)y[i]=0;
}
int f[N],h[N];
signed main(){
    int n=read();iv2=ksm(2,mod-2);
    fo(i,0,n-1)f[i]=read();
    sqt(f,h,n);
    fo(i,0,n-1)printf("%lld ",h[i]);
}

多项式ln

求导然后积分回去,完事

code
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define fo(i,x,y) for(int i=(x);i<=(y);i++)
#define fu(i,x,y) for(int i=(x);i>=(y);i--)
int read(){
    int s=0,t=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')t=-1;ch=getchar();}
    while(isdigit(ch)){s=s*10+ch-'0';ch=getchar();}
    return s*t;
}
const int mod=998244353;
const int N=1<<20;
int ksm(int x,int y){
    int ret=1;
    while(y){
        if(y&1)ret=ret*x%mod;
        x=x*x%mod;y>>=1;
    }return ret;
}
int af[N],lim,len,g[N];
void ntt(int *a,int lim){
    fo(i,0,lim-1)if(af[i]>i)swap(a[i],a[af[i]]);
    for(int t=lim>>1,d=1;d<lim;d<<=1,t>>=1)
        for(int i=0;i<lim;i+=(d<<1))
            fo(j,0,d-1){
                int tmp=g[t*j]*a[i+j+d]%mod;
                a[i+j+d]=(a[i+j]-tmp+mod)%mod;
                a[i+j]=(a[i+j]+tmp)%mod;
            }
}
int a[N],b[N],c[N];
void inv(int *x,int *y,int n){
    if(n==1)return y[0]=ksm(x[0],mod-2),void();
    inv(x,y,n+1>>1);
    for(lim=1,len=0;lim<n*2;lim<<=1,len++);
    fo(i,0,lim-1)af[i]=(af[i>>1]>>1)|((i&1)<<(len-1));
    g[0]=1;g[1]=ksm(3,(mod-1)/lim);
    fo(i,2,lim-1)g[i]=g[i-1]*g[1]%mod;
    fo(i,0,n-1)a[i]=x[i];
    ntt(a,lim);ntt(y,lim);
    g[0]=1;g[1]=ksm(g[1],mod-2);
    fo(i,2,lim-1)g[i]=g[i-1]*g[1]%mod;
    fo(i,0,lim-1)y[i]=(2ll-a[i]*y[i]%mod+mod)*y[i]%mod;
    ntt(y,lim);int iv=ksm(lim,mod-2);
    fo(i,0,lim-1)y[i]=y[i]*iv%mod,a[i]=0;
    fo(i,n,lim-1)y[i]=0;
}
int f[N],h[N],cf,ch;
void dec(int *x,int n){
    fo(i,0,n-1)x[i]=x[i+1]*(i+1)%mod;
}
void sum(int *x,int n){
    fu(i,n,1)x[i]=x[i-1]*ksm(i,mod-2)%mod;x[0]=0;
}
signed main(){
    cf=read();
    fo(i,0,cf-1)f[i]=read();
    inv(f,h,cf);dec(f,cf);
    for(lim=1,len=0;lim<=cf*2-1;lim<<=1,len++);
    fo(i,0,lim-1)af[i]=(af[i>>1]>>1)|((i&1)<<(len-1));
    g[0]=1;g[1]=ksm(3,(mod-1)/lim);
    fo(i,2,lim-1)g[i]=g[i-1]*g[1]%mod;
    ntt(h,lim);ntt(f,lim);
    g[0]=1;g[1]=ksm(g[1],mod-2);
    fo(i,2,lim-1)g[i]=g[i-1]*g[1]%mod;
    fo(i,0,lim-1)h[i]=h[i]*f[i]%mod;
    ntt(h,lim);int iv=ksm(lim,mod-2);
    fo(i,0,lim-1)h[i]=h[i]*iv%mod;
    sum(h,cf);fo(i,0,cf-1)printf("%lld ",h[i]);
}

多项式exp

牛顿迭代

code
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define fo(i,x,y) for(int i=(x);i<=(y);i++)
#define fu(i,x,y) for(int i=(x);i>=(y);i--)
int read(){
    int s=0,t=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')t=-1;ch=getchar();}
    while(isdigit(ch)){s=s*10+ch-'0';ch=getchar();}
    return s*t;
}
const int mod=998244353;
const int N=1<<20;
int ksm(int x,int y){
    int ret=1;
    while(y){
        if(y&1)ret=ret*x%mod;
        x=x*x%mod;y>>=1;
    }return ret;
}
int af[N],lim,len,w[N];
void ntt(int *a,int lim){
    fo(i,0,lim-1)if(af[i]>i)swap(a[i],a[af[i]]);
    for(int t=lim>>1,d=1;d<lim;d<<=1,t>>=1)
        for(int i=0;i<lim;i+=(d<<1))
            fo(j,0,d-1){
                int tmp=w[t*j]*a[i+j+d]%mod;
                a[i+j+d]=(a[i+j]-tmp+mod)%mod;
                a[i+j]=(a[i+j]+tmp)%mod;
            }
}
int a[N],b[N],c[N],d[N],e[N];
void inv(int *x,int *y,int n){
    if(n==1)return y[0]=ksm(x[0],mod-2),void();
    inv(x,y,n+1>>1);
    for(lim=1,len=0;lim<n*2;lim<<=1,len++);
    fo(i,0,lim-1)af[i]=(af[i>>1]>>1)|((i&1)<<(len-1));
    w[0]=1;w[1]=ksm(3,(mod-1)/lim);
    fo(i,2,lim-1)w[i]=w[i-1]*w[1]%mod;
    fo(i,0,n-1)a[i]=x[i];
    ntt(a,lim);ntt(y,lim);
    w[0]=1;w[1]=ksm(w[1],mod-2);
    fo(i,2,lim-1)w[i]=w[i-1]*w[1]%mod;
    fo(i,0,lim-1)y[i]=(2ll-a[i]*y[i]%mod+mod)*y[i]%mod;
    ntt(y,lim);int iv=ksm(lim,mod-2);
    fo(i,0,lim-1)y[i]=y[i]*iv%mod,a[i]=0;
    fo(i,n,lim-1)y[i]=0;
}
void dec(int *a,int n){
    fo(i,0,n-2)a[i]=a[i+1]*(i+1)%mod;
    a[n-1]=0;
}
void sum(int *a,int n){
    fu(i,n-1,1)a[i]=a[i-1]*ksm(i,mod-2)%mod;
    a[0]=0;
}
void ln(int *x,int *y,int n){
    fo(i,0,n-1)c[i]=x[i];
    dec(c,n);inv(x,b,n);
    for(lim=1,len=0;lim<n*2;lim<<=1,len++);
    fo(i,0,lim-1)af[i]=(af[i>>1]>>1)|((i&1)<<(len-1));
    w[0]=1;w[1]=ksm(3,(mod-1)/lim);
    fo(i,2,lim-1)w[i]=w[i-1]*w[1]%mod;
    ntt(c,lim);ntt(b,lim);
    w[0]=1;w[1]=ksm(w[1],mod-2);
    fo(i,2,lim-1)w[i]=w[i-1]*w[1]%mod;
    fo(i,0,lim-1)y[i]=c[i]*b[i]%mod;
    ntt(y,lim);int iv=ksm(lim,mod-2);
    fo(i,0,lim-1)y[i]=y[i]*iv%mod,c[i]=b[i]=0;
    sum(y,n);fo(i,n,lim-1)y[i]=0;
}
void exp(int *x,int *y,int n){
    if(n==1)return y[0]=1,void();
    exp(x,y,n+1>>1);ln(y,d,n);
    fo(i,0,n-1)e[i]=x[i];
    for(lim=1,len=0;lim<n*2;lim<<=1,len++);
    fo(i,0,lim-1)af[i]=(af[i>>1]>>1)|((i&1)<<(len-1));
    w[0]=1;w[1]=ksm(3,(mod-1)/lim);
    fo(i,2,lim-1)w[i]=w[i-1]*w[1]%mod;
    ntt(e,lim);ntt(d,lim);ntt(y,lim);
    w[0]=1;w[1]=ksm(w[1],mod-2);
    fo(i,2,lim-1)w[i]=w[i-1]*w[1]%mod;
    fo(i,0,lim-1)y[i]=(1ll-d[i]+e[i]+2*mod)*y[i]%mod;
    ntt(y,lim);int iv=ksm(lim,mod-2);
    fo(i,0,lim-1)y[i]=y[i]*iv%mod,e[i]=d[i]=0;
    fo(i,n,lim-1)y[i]=0;
    // cout<<n<<" ";fo(i,0,n-1)cout<<y[i]<<" ";cout<<endl;
}
int n,f[N],g[N];
signed main(){
    n=read();
    fo(i,0,n-1)f[i]=read();
    exp(f,g,n);
    fo(i,0,n-1)printf("%lld ",g[i]);
}

多项式除法

翻转多项式之后可以取模干掉余数,于是可以直接求逆然后带回去

code
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define fo(i,x,y) for(int i=(x);i<=(y);i++)
#define fu(i,x,y) for(int i=(x);i>=(y);i--)
int read(){
    int s=0,t=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')t=-1;ch=getchar();}
    while(isdigit(ch)){s=s*10+ch-'0';ch=getchar();}
    return s*t;
}
const int mod=998244353;
const int N=1<<19;
int ksm(int x,int y){
    int ret=1;
    while(y){
        if(y&1)ret=ret*x%mod;
        x=x*x%mod;y>>=1;
    }return ret;
}
int af[N],lim,len,w[N];
void ntt(int *a,int lim){
    fo(i,0,lim-1)if(af[i]>i)swap(a[af[i]],a[i]);
    for(int t=lim>>1,d=1;d<lim;d<<=1,t>>=1)
        for(int i=0;i<lim;i+=(d<<1))
            fo(j,0,d-1){
                int tmp=w[t*j]*a[i+j+d]%mod;
                a[i+j+d]=(a[i+j]-tmp+mod)%mod;
                a[i+j]=(a[i+j]+tmp)%mod;
            }
}
int a[N],b[N],c[N];
void inv(int *x,int *y,int n){
    if(n==1)return y[0]=ksm(x[0],mod-2),void();
    inv(x,y,n+1>>1);
    for(lim=1,len=0;lim<n*2;lim<<=1,len++);
    fo(i,0,lim-1)af[i]=(af[i>>1]>>1)|((i&1)<<(len-1));
    w[0]=1;w[1]=ksm(3,(mod-1)/lim);
    fo(i,2,lim-1)w[i]=w[i-1]*w[1]%mod;
    fo(i,0,n-1)a[i]=x[i];
    ntt(a,lim);ntt(y,lim);
    w[0]=1;w[1]=ksm(w[1],mod-2);
    fo(i,2,lim-1)w[i]=w[i-1]*w[1]%mod;
    fo(i,0,lim-1)y[i]=(2ll-a[i]*y[i]%mod+mod)*y[i]%mod;
    ntt(y,lim);int iv=ksm(lim,mod-2);
    fo(i,0,lim-1)y[i]=y[i]*iv%mod,a[i]=0;
    fo(i,n,lim-1)y[i]=0;
}
void mul(int *x,int *y,int n,int m,int *z,int v){
    for(lim=1,len=0;lim<n+m;lim<<=1,len++);
    fo(i,0,lim-1)af[i]=(af[i>>1]>>1)|((i&1)<<(len-1));
    w[0]=1;w[1]=ksm(3,(mod-1)/lim);
    fo(i,2,lim-1)w[i]=w[i-1]*w[1]%mod;
    fo(i,0,n-1)a[i]=x[i];
    fo(i,0,m-1)b[i]=y[i];
    ntt(a,lim);ntt(b,lim);
    w[0]=1;w[1]=ksm(w[1],mod-2);
    fo(i,2,lim-1)w[i]=w[i-1]*w[1]%mod;
    fo(i,0,lim-1)z[i]=a[i]*b[i]%mod;
    ntt(z,lim);int iv=ksm(lim,mod-2);
    fo(i,0,lim-1)z[i]=z[i]*iv%mod,a[i]=b[i]=0;
    fo(i,v,lim-1)z[i]=0;
}
int n,m,f[N],g[N],ng[N],q[N],r[N];
signed main(){
    n=read();m=read();
    fo(i,0,n)f[i]=read();
    fo(i,0,m)g[i]=read();
    fo(i,0,n>>1)swap(f[i],f[n-i]);
    fo(i,0,m>>1)swap(g[i],g[m-i]);
    inv(g,ng,n-m+1);
    mul(f,ng,n+1,n-m+1,q,n-m+1);
    fo(i,0,n-m>>1)swap(q[i],q[n-m-i]);
    fo(i,0,n>>1)swap(f[i],f[n-i]);
    fo(i,0,m>>1)swap(g[i],g[m-i]);
    mul(q,g,n-m+1,m+1,ng,m);
    fo(i,0,m-1)r[i]=(f[i]-ng[i]+mod)%mod;
    fo(i,0,n-m)printf("%lld ",q[i]);printf("\n");
    fo(i,0,m-1)printf("%lld ",r[i]);
}

技巧

1、生成函数时,有关于对称的问题,如果对称的话,那么两个点的位置相加是对称中心的两倍,意思就是多项式卷积之后对称轴是同一条的都在一起

2、对于生成函数的利用,我们只需要系数,而并不是真正的把\(x\)带进去找值,我只需要系数就行了

比如:\(\sum_{i=0}^{n}\sum_{j=0}^{i}j!*2^{i-j}\),这个东西阶乘和幂可以分别看作是两个多项式的系数

这样我们可以通过卷积快速求出后面的和,也就是卷积后多项式的第\(i\)项系数,符合多项式乘法的规则,然后带入求和

posted @ 2022-01-25 19:39  fengwu2005  阅读(165)  评论(1编辑  收藏  举报