[清华集训2017]生成树计数

ln(A(x)) requires a[0] = 1;

exp(A(x)) requires a[0] = 0;

11.26 换了一个更快的 DFT（NTT）板子 (11.26: switched to a faster DFT/NTT template.)

#include <bits/stdc++.h>
using namespace std;
// Contest shorthand macros used throughout this program.
// rep: ascending loop over the inclusive range [h, t].
#define rep(i,h,t) for (int i=h;i<=t;i++)
// dep: descending loop from t down to h (inclusive).
#define dep(i,t,h) for (int i=t;i>=h;i--)
#define ll long long
// me: zero-fill x; sizeof(x) is the full size only when x is an array, not a pointer.
#define me(x) memset(x,0,sizeof(x))
#define IL inline
// NOTE(review): `register` was removed in C++17; rint is not used in the visible code.
#define rint register int
// Reads one signed decimal integer from stdin via getchar().
// Non-digit characters are skipped (a '-' among them makes the value
// negative), then consecutive digits are accumulated.
// The character is held in an int so EOF stays distinguishable from byte
// 0xFF. If input ends before any digit appears, 0 is returned instead of
// spinning forever at EOF.
inline long long rd(){
    long long x=0;
    bool neg=false;
    int c=getchar();
    while(!isdigit(c)){
        if(c==EOF) return 0;
        if(c=='-') neg=true;
        c=getchar();
    }
    while(isdigit(c)){
        x=x*10+(c-'0');
        c=getchar();
    }
    return neg?-x:x;
}
// Buffered character reader over stdin. ss is a 16 MiB buffer, A is the
// read cursor and B marks the end of valid data. When they meet, the buffer
// is refilled with fread. If no more data arrives, EOF (converted to char)
// is returned.
char ss[1<<24],*A=ss,*B=ss;
IL char gc()
{
    if (A==B)
    {
        A=ss;
        B=ss+fread(ss,1,1<<24,stdin);
        if (A==B) return EOF;
    }
    return *A++;
}
// maxa(x, y): in-place maximum. x becomes y when y > x.
template<class T>void maxa(T &x,T y)
{
    if (!(y>x)) return;
    x=y;
}
// mina(x, y): in-place minimum. x becomes y when y < x.
template<class T>void mina(T &x,T y)
{
    if (!(y<x)) return;
    x=y;
}
// Parses one signed decimal integer from the gc() stream into x.
// Non-digit characters are skipped; a '-' among them makes the result
// negative. If the stream ends before any digit appears, x is set to 0 and
// the function returns instead of looping forever on EOF.
template<class T>void read(T &x)
{
    int f=1,c;
    while (c=gc(),c<48||c>57)
    {
        // gc() returns char, so compare against EOF as it looks after that conversion.
        if (c==static_cast<char>(EOF)) { x=0; return; }
        if (c=='-') f=-1;
    }
    x=(c^48);
    while (c=gc(),c>47&&c<58) x=x*10+(c^48);
    x*=f;
}
const int mo=998244353;
// Modular exponentiation: returns x^y mod mo by iterative square-and-multiply.
// x may be any int; it is first reduced into [0, mo). y <= 0 yields 1.
// The previous recursive form only stopped at y == 1, so y == 0 recursed
// forever (reachable via fsp(i+1, 2*m) with m == 0).
long long fsp(int x,int y)
{
    long long base=((long long)x%mo+mo)%mo;
    long long ans=1;
    while (y>0)
    {
        if (y&1) ans=ans*base%mo;
        base=base*base%mo;
        y>>=1;
    }
    return ans;
}
// Integer 2-D vector. + and - are component-wise, * is the cross product,
// and half() tags the half-plane (presumably a polar-sort helper; cp is not
// referenced elsewhere in this program).
struct cp {
    ll x,y;
    // Component-wise sum.
    cp operator +(cp rhs)
    {
        cp out;
        out.x=x+rhs.x;
        out.y=y+rhs.y;
        return out;
    }
    // Component-wise difference.
    cp operator -(cp rhs)
    {
        cp out;
        out.x=x-rhs.x;
        out.y=y-rhs.y;
        return out;
    }
    // z-component of the cross product: x*rhs.y - y*rhs.x.
    ll operator *(cp rhs)
    {
        const ll lhsTerm=x*rhs.y;
        const ll rhsTerm=y*rhs.x;
        return lhsTerm-rhsTerm;
    }
    // 1 when the point is below the x-axis or on its negative half, else 0.
    int half()
    {
        if (y<0) return 1;
        if (y==0&&x<0) return 1;
        return 0;
    }
};
// Plain triple of ints; not referenced elsewhere in this program.
struct re{
    int a;
    int b;
    int c;
};
// Capacity of every polynomial buffer. Transform sizes are powers of two and
// ppr[n] is written, so the largest usable size is 2^19 = 524288 (< N).
const int N=6e5;
// 3 is the usual primitive root of 998244353; not referenced in this version
// (the transform below derives its twiddles from 31 instead).
const int G=3;
// Scratch globals. f and g appear only in the commented-out solve(); main()
// and the fft struct declare their own n.
int f[N],g[N],n;
// Polynomial toolkit over Z/mo (mo = 998244353). It is built on an in-place
// DIF (forward) / DIT (inverse) number-theoretic transform.
//
// Length convention in this version: a `len` argument is the HIGHEST
// coefficient index, so routines read/write indices 0..len (len+1 terms).
// The work buffers a/b must be zero outside the range being filled; every
// routine restores that with clear() before returning.
//
// Fixes relative to the previous revision:
//  * getinv/getsqrt/getexp: the base case was len==1 but wrote only B[0].
//    As a result getinv(A,B,1) never produced B[1], and the (len+1)>>1
//    recursion left the top coefficient wrong for some lengths
//    (e.g. len = 2, 4, 8). The base is now len==0 and the recursion uses
//    len/2, which always yields 2*(len/2+1) >= len+1 correct terms.
//    Coefficients above the recursive prefix are zeroed before each Newton
//    step, so stale caller data cannot reach the product.
//  * getDao zeroed b[len-1] (discarding len*a[len]) instead of b[len].
//  * getsqrt/getexp used int[N] stack arrays (getexp at every recursion
//    level). They now use heap-backed std::vector scratch.
struct fft{
  int l,n,m;
  int a[N],b[N],inv[N];
  int C[N],D[N];
  // inv[i] = i^{-1} mod mo for 1 <= i < N.
  fft()
  {
    inv[0]=inv[1]=1;
    rep(i,2,N-1) inv[i]=(1ll*inv[mo%i]*(mo-(mo/i)))%mo;
  }
  // Zero the work buffers a and b on indices 0..n.
  IL void clear()
  {
      rep(i,0,n) a[i]=b[i]=0;
  }
  int ppr[N];
  // Set n to the smallest power of two greater than m (lg = log2 n) and fill
  // the twiddle table ppr used by DIF/DIT. ppr[n] is written, so n must be
  // < N, and 1<<(21-lg) needs lg <= 21. Presumably 31 has order 2^21 mod mo,
  // which would make ppr[n] a primitive n-th root of unity (TODO confirm).
  inline void ntt_init(){
    int lg=0,x;
    for (x=1;x<=m;x*=2) lg++;
    n=x;
    ppr[0]=1;ppr[x]=fsp(31,1<<(21-lg));
    for(int i=x>>1;i;i>>=1) ppr[i]=1ll*ppr[i<<1]*ppr[i<<1]%mo;
    for(int i=1;i<x;i++) ppr[i]=1ll*ppr[i&(i-1)]*ppr[i&-i]%mo;
  }
  // Reduce x from [0, 2*mo) into [0, mo).
  inline int del(const int x){
    return x>=mo?x-mo:x;
  }
  // Forward transform of f[0..x-1] in place (x a power of two). No
  // bit-reversal pass is applied; callers only multiply spectra pointwise
  // and then apply DIT.
  inline void DIF(int *f,const int x){
    int len,hl,uni,*s,*i,*w;
    for(len=x,hl=x>>1;hl;len=hl,hl>>=1){
        for(s=f,w=ppr;s!=f+x;s+=len,w++){
            for(i=s;i<s+hl;i++){
                uni=1ll**(i+hl)**w%mo;
                *(i+hl)=del(*i+mo-uni);
                *i=del(*i+uni);
            }
        }
    }
  }
  // Inverse of DIF on f[0..x-1], then scaling by invx = mo-(mo-1)/x,
  // which equals x^{-1} mod mo because x divides mo-1.
  inline void DIT(int *f,const int x){
    int len,hl,uni,*s,*i,*w;
    for(len=2,hl=1;len<=x;hl=len,len<<=1){
        for(s=f,w=ppr;s!=f+x;s+=len,w++){
            for(i=s;i!=s+hl;i++){
                uni=*i;
                *i=del(uni+*(i+hl));
                *(i+hl)=1ll*(uni+mo-*(i+hl))**w%mo;
            }
        }
    }
    reverse(f+1,f+x);int invx=mo-(mo-1)/x;
    for(i=f;i!=f+x;i++) *i=1ll**i*invx%mo;
  }
  // B <- (A * B) mod x^{len+1}. Inputs are reduced into [0, mo) first.
  IL void getcj(int *A,int *B,int len)
  {
    m=len*2; ntt_init();
    for (int i=0;i<=len;i++) a[i]=(A[i]%mo+mo)%mo,b[i]=(B[i]%mo+mo)%mo;
    DIF(a,n); DIF(b,n);
    for(int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mo;
    DIT(a,n);
    for (int i=0;i<=len;i++) B[i]=a[i];
    clear();
  }
  // Multiply the caller-filled buffers a and b and copy all n outputs into C.
  // With m = len the transform size n only exceeds len, so the product's
  // degree must stay below n to avoid wrap-around.
  IL void getcj(int *C,int len)
  {
      m=len; ntt_init();
      rep(i,0,n) a[i]=(a[i]+mo)%mo,b[i]=(b[i]+mo)%mo;
      DIF(a,n); DIF(b,n);
      for (int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mo;
      DIT(a,n);
      for (int i=0;i<n;i++) C[i]=a[i];
      clear();
  }
  // B <- A^{-1} mod x^{len+1} by Newton iteration B <- 2B - A*B^2.
  // Requires A[0] != 0 (mod mo).
  IL void getinv(int *A,int *B,int len)
  {
    if (len<=0) { B[0]=fsp(A[0],mo-2); return; }
    const int half=len/2;
    getinv(A,B,half);                 // B[0..half] now correct
    rep(i,half+1,len) B[i]=0;         // drop stale caller data above the prefix
    m=len*2; ntt_init();
    for (int i=0;i<=len;i++) a[i]=(A[i]%mo+mo)%mo,b[i]=B[i];
    DIF(a,n); DIF(b,n);
    for (int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mo*b[i]%mo;
    DIT(a,n);
    for (int i=0;i<=len;i++) B[i]=((2ll*B[i]-a[i])%mo+mo)%mo;
    clear();
  }
  // B <- sqrt(A) mod x^{len+1} by Newton iteration B <- (B + A/B) / 2.
  // NOTE(review): the base case takes a real sqrt of A[0], which is only a
  // modular square root when A[0] is a perfect square (e.g. A[0] == 1).
  IL void getsqrt(int *A,int *B,int len)
  {
    if (len<=0) {B[0]=sqrt(A[0]); return;}
    const int half=len/2;
    getsqrt(A,B,half);
    rep(i,half+1,len) B[i]=0;
    const int inv2=fsp(2,mo-2);
    std::vector<int> invB(len+1,0);   // heap scratch (was int[N] on the stack)
    getinv(B,invB.data(),len);
    getcj(A,invB.data(),len);         // invB <- A / B
    for (int i=0;i<=len;i++) B[i]=1ll*(invB[i]+B[i])%mo*inv2%mo;
  }
  // b <- a': b[i-1] = i*a[i] for 1 <= i <= len, then b[len] = 0.
  // Works in place (a == b): a[i] is read before b[i] is overwritten.
  IL void getDao(int *a,int *b,int len)
  {
    for (int i=1;i<=len;i++) b[i-1]=1ll*i*a[i]%mo;
    b[len]=0;
  }
  // b <- integral of a: b[i+1] = a[i]/(i+1) for 0 <= i <= len, b[0] = 0.
  // Writes index len+1. Not safe in place.
  IL void getjf(int *a,int *b,int len)
  {
    for (int i=0;i<=len;i++) b[i+1]=1ll*a[i]*inv[i+1]%mo;
    b[0]=0;
  }
  // B <- ln(A) mod x^{len+1}, computed as the integral of A'/A (writes
  // B[0..len+1]). Requires A[0] == 1. A and B may alias. Uses members C and
  // D as scratch and zeroes C[0..len], D[0..len] afterwards.
  IL void getln(int *A,int *B,int len)
  {
    getDao(A,C,len);
    getinv(A,D,len);
    getcj(C,D,len);
    getjf(D,B,len);
    rep(i,0,len) C[i]=0,D[i]=0;
  }
  // B <- exp(A) mod x^{len+1} by Newton iteration B <- B*(1 - ln B + A).
  // Requires A[0] == 0.
  IL void getexp(int *A,int *B,int len)
  {
    if (len<=0) {B[0]=1; return;}
    const int half=len/2;
    getexp(A,B,half);
    rep(i,half+1,len) B[i]=0;
    std::vector<int> lnB(len+2,0);    // getln writes indices 0..len+1
    getln(B,lnB.data(),len);
    for(int i=0;i<=len;i++) lnB[i]=((-lnB[i]+A[i])%mo+mo)%mo;
    lnB[0]=(lnB[0]+1)%mo;
    getcj(lnB.data(),B,len);
  }
}F;
/*

f[i]=\sum f[j]*g[i-j]; 

*/
/*
int now[N];
void solve(int h,int t)
{
  if (h>=t) return; 
  if (t-h<=32)
  {
      rep(i,h,t)
        rep(j,h,i)
          f[i]=(f[i]+1ll*f[j]*g[i-j])%mo;
    return;
  }
  int mid=(h+t)/2;
  solve(h,mid);
  rep(i,h,mid) F.a[i-h]=f[i];
  rep(i,1,t-h) F.b[i]=g[i];
  F.getcj(now,(mid-h+1)+(t-h+1));
  rep(i,mid+1,t) f[i]=(f[i]+now[i-h])%mo;
  solve(mid+1,t);
}
*/
// Working polynomials for main(); `now` receives products inside solve().
int sum[N],now[N],a[N],b[N],c[N],d[N],e[N];
// jc[i] = i! mod mo and jc2[i] = (i!)^{-1} mod mo, filled in main().
ll jc[N],jc2[N];
/*
  Divide and conquer for prod_{i=h}^{t} (1 + a[i] x).
  On entry a[h..t] holds the a[i]. On return a[h+k-1] holds the coefficient
  of x^k (k = 1..t-h+1) of that product; the constant term 1 is implicit.
*/
void solve(int h,int t,int *a)
{
    if (h==t) return;                 // one factor: a[h] already is its x^1 coefficient
    const int mid=(h+t)/2;
    solve(h,mid,a);
    solve(mid+1,t,a);
    const int leftLen=mid-h+1, rightLen=t-mid;
    // Stage (1 + left coefficients) in F.a and (1 + right coefficients) in F.b.
    F.a[0]=1;
    for (int k=1;k<=leftLen;k++) F.a[k]=a[h+k-1];
    F.b[0]=1;
    for (int k=1;k<=rightLen;k++) F.b[k]=a[mid+k];
    F.getcj(now,leftLen+rightLen+1);
    for (int k=1;k<=leftLen+rightLen;k++) a[h+k-1]=now[k];
}
// Not referenced anywhere in this program.
int sum3[N],sum4[N];
// Reads n, m and weights a[1..n] from stdin and prints one value mod mo.
// Outline, as computed by the calls below:
//   ans    = prod_i a[i]
//   sum[k] = sum_i a[i]^k for k = 1..n, and sum[0] = n (power sums)
//   a(x)   = sum_{i<n} (i+1)^{2m} x^i / i!
//   b(x)   = c(x) = sum_{i<n} (i+1)^m x^i / i!
//   e      = ln(c) with coefficient k scaled by sum[k]; c = exp(e)
//   d      = a/b with coefficient k scaled by sum[k];   d = c*d
//   output = ans * (n-2)! * d[n-2]
// NOTE(review): n == 1 would index d[-1] and jc[-1]; confirm the input
// guarantees n >= 2. If m can be 0, this relies on fsp(i+1, 0) returning 1.
int main()
{
   ios::sync_with_stdio(false);
   int n,m;
   cin>>n>>m;
   ll ans=1;
   rep(i,1,n) cin>>a[i],ans=ans*a[i]%mo;
   // sum[1..n] = -a[i]; solve() replaces them with coefficients 1..n of
   // prod_i (1 - a[i] x), whose constant term 1 is stored in sum[0].
   rep(i,1,n) sum[i]=(-a[i]+mo)%mo;
   solve(1,n,sum);
   sum[0]=1;
   // (ln prod_i (1 - a[i] x))' = -sum_{k>=1} (sum_i a[i]^k) x^{k-1}; the
   // shift-and-negate loop turns this into sum[k] = sum_i a[i]^k.
   F.getln(sum,sum,n+2);
   F.getDao(sum,sum,n+2);
   dep(i,n,1) sum[i]=((mo-sum[i-1])%mo+mo)%mo;
   sum[0]=n;
   // jc[i] = i! and jc2[i] = (i!)^{-1} (mod mo).
   jc[0]=jc2[0]=1;
   rep(i,1,n) jc[i]=jc[i-1]*i%mo;
   jc2[n]=fsp(jc[n],mo-2);
   dep(i,n-1,1) jc2[i]=jc2[i+1]*(i+1)%mo;
   rep(i,0,n) a[i]=b[i]=c[i]=0;
   rep(i,0,n-1) a[i]=1ll*fsp(i+1,2*m)*jc2[i]%mo;
   rep(i,0,n-1) c[i]=b[i]=1ll*fsp(i+1,m)*jc2[i]%mo;
   // e = ln(c) scaled by the power sums, then c = exp(e).
   F.getln(c,e,n+1);
   rep(i,0,n) e[i]=1ll*e[i]*sum[i]%mo;
   F.getexp(e,c,n+1);
   // d = a/b scaled by the power sums, then d = c*d.
   F.getinv(b,d,n+1);
   F.getcj(a,d,n+1);
   rep(i,0,n) d[i]=1ll*d[i]*sum[i]%mo;
   F.getcj(c,d,n+1);
   ans=ans*d[n-2]%mo*jc[n-2]%mo;
   cout<<ans<<endl; 
   return 0;
}
View Code

 

 分治fft (divide-and-conquer FFT)

#pragma GCC optimize("Ofast,no-stack-protector,unroll-loops,fast-math")
#pragma GCC target("sse,sse2,sse3,ssse3,sse4.1,sse4.2,avx,avx2,popcnt,tune=native")
#include <bits/stdc++.h>
using namespace std;
// Contest shorthand macros used throughout this program.
// rep: ascending loop over the inclusive range [h, t].
#define rep(i,h,t) for (int i=h;i<=t;i++)
// dep: descending loop from t down to h (inclusive).
#define dep(i,t,h) for (int i=t;i>=h;i--)
#define ll long long
// me: zero-fill x; sizeof(x) is the full size only when x is an array, not a pointer.
#define me(x) memset(x,0,sizeof(x))
#define IL inline
// NOTE(review): `register` was removed in C++17; rint is not used in the visible code.
#define rint register int
// Reads one signed decimal integer from stdin via getchar().
// Non-digit characters are skipped (a '-' among them makes the value
// negative), then consecutive digits are accumulated.
// The character is held in an int so EOF stays distinguishable from byte
// 0xFF. If input ends before any digit appears, 0 is returned instead of
// spinning forever at EOF.
inline long long rd(){
    long long x=0;
    bool neg=false;
    int c=getchar();
    while(!isdigit(c)){
        if(c==EOF) return 0;
        if(c=='-') neg=true;
        c=getchar();
    }
    while(isdigit(c)){
        x=x*10+(c-'0');
        c=getchar();
    }
    return neg?-x:x;
}
// Buffered character reader over stdin. ss is a 16 MiB buffer, A is the
// read cursor and B marks the end of valid data. When they meet, the buffer
// is refilled with fread. If no more data arrives, EOF (converted to char)
// is returned.
char ss[1<<24],*A=ss,*B=ss;
IL char gc()
{
    if (A==B)
    {
        A=ss;
        B=ss+fread(ss,1,1<<24,stdin);
        if (A==B) return EOF;
    }
    return *A++;
}
// maxa(x, y): in-place maximum. x becomes y when y > x.
template<class T>void maxa(T &x,T y)
{
    if (!(y>x)) return;
    x=y;
}
// mina(x, y): in-place minimum. x becomes y when y < x.
template<class T>void mina(T &x,T y)
{
    if (!(y<x)) return;
    x=y;
}
// Parses one signed decimal integer from the gc() stream into x.
// Non-digit characters are skipped; a '-' among them makes the result
// negative. If the stream ends before any digit appears, x is set to 0 and
// the function returns instead of looping forever on EOF.
template<class T>void read(T &x)
{
    int f=1,c;
    while (c=gc(),c<48||c>57)
    {
        // gc() returns char, so compare against EOF as it looks after that conversion.
        if (c==static_cast<char>(EOF)) { x=0; return; }
        if (c=='-') f=-1;
    }
    x=(c^48);
    while (c=gc(),c>47&&c<58) x=x*10+(c^48);
    x*=f;
}
const int mo=998244353;
// Modular exponentiation: returns x^y mod mo by iterative square-and-multiply.
// x may be any int; it is first reduced into [0, mo). y <= 0 yields 1.
// The previous recursive form only stopped at y == 1, so y == 0 recursed
// forever.
long long fsp(int x,int y)
{
    long long base=((long long)x%mo+mo)%mo;
    long long ans=1;
    while (y>0)
    {
        if (y&1) ans=ans*base%mo;
        base=base*base%mo;
        y>>=1;
    }
    return ans;
}
// Integer 2-D vector. + and - are component-wise, * is the cross product,
// and half() tags the half-plane (presumably a polar-sort helper; cp is not
// referenced elsewhere in this program).
struct cp {
    ll x,y;
    // Component-wise sum.
    cp operator +(cp rhs)
    {
        cp out;
        out.x=x+rhs.x;
        out.y=y+rhs.y;
        return out;
    }
    // Component-wise difference.
    cp operator -(cp rhs)
    {
        cp out;
        out.x=x-rhs.x;
        out.y=y-rhs.y;
        return out;
    }
    // z-component of the cross product: x*rhs.y - y*rhs.x.
    ll operator *(cp rhs)
    {
        const ll lhsTerm=x*rhs.y;
        const ll rhsTerm=y*rhs.x;
        return lhsTerm-rhsTerm;
    }
    // 1 when the point is below the x-axis or on its negative half, else 0.
    int half()
    {
        if (y<0) return 1;
        if (y==0&&x<0) return 1;
        return 0;
    }
};
// Plain triple of ints; not referenced elsewhere in this program.
struct re{
    int a;
    int b;
    int c;
};
// Capacity of every buffer. Transform sizes are powers of two and ppr[n] is
// written, so the largest usable transform size is 2^19 (< N).
const int N=6e5;
// 3 is the usual primitive root of 998244353; not referenced in this program.
const int G=3;
// NTT multiplier mod mo (998244353): in-place DIF forward transform, DIT
// inverse transform, and a twiddle table ppr[] rebuilt by init() for each size.
struct fft{
    // l is not referenced; n = transform size, m = requested bound (n > m).
    int l,n,m;
    int a[N],b[N];
    // Zero the work buffers a and b on indices 0..n.
    IL void clear()
    {
        rep(i,0,n) a[i]=b[i]=0;
    }
    int ppr[N];
    // Set n to the smallest power of two greater than m (lg = log2 n) and
    // fill ppr[] for DIF/DIT. ppr[n] is written, so n must be < N, and
    // 1<<(21-lg) needs lg <= 21. Presumably 31 has order 2^21 mod mo, which
    // would make ppr[n] a primitive n-th root of unity (TODO confirm).
    inline void init(){
        int lg=0,x;
        for (x=1;x<=m;x*=2) lg++;
        n=x;
        ppr[0]=1;ppr[x]=fsp(31,1<<(21-lg));
        for(int i=x>>1;i;i>>=1) ppr[i]=1ll*ppr[i<<1]*ppr[i<<1]%mo;
        for(int i=1;i<x;i++) ppr[i]=1ll*ppr[i&(i-1)]*ppr[i&-i]%mo;
    }
    // Reduce x from [0, 2*mo) into [0, mo).
    inline int del(const int x){
        return x>=mo?x-mo:x;
    }
    // Forward transform of f[0..x-1] in place (x a power of two). No
    // bit-reversal pass is applied; getcj only multiplies spectra pointwise
    // before calling DIT.
    inline void DIF(int *f,const int x){
        int len,hl,uni,*s,*i,*w;
        for(len=x,hl=x>>1;hl;len=hl,hl>>=1){
            for(s=f,w=ppr;s!=f+x;s+=len,w++){
                for(i=s;i<s+hl;i++){
                    uni=1ll**(i+hl)**w%mo;
                    *(i+hl)=del(*i+mo-uni);
                    *i=del(*i+uni);
                }
            }
        }
    }
    // Inverse of DIF on f[0..x-1], then scaling by invx = mo-(mo-1)/x,
    // which equals x^{-1} mod mo because x divides mo-1.
    inline void DIT(int *f,const int x){
        int len,hl,uni,*s,*i,*w;
        for(len=2,hl=1;len<=x;hl=len,len<<=1){
            for(s=f,w=ppr;s!=f+x;s+=len,w++){
                for(i=s;i!=s+hl;i++){
                    uni=*i;
                    *i=del(uni+*(i+hl));
                    *(i+hl)=1ll*(uni+mo-*(i+hl))**w%mo;
                }
            }
        } std::reverse(f+1,f+x);int invx=mo-(mo-1)/x;
        for(i=f;i!=f+x;i++) *i=1ll**i*invx%mo;
    }
    // B <- A*B: reads A[0..len] and B[0..len] (reduced into [0, mo)) and
    // writes the whole product, indices 0..2*len, back into B. B therefore
    // needs room for 2*len+1 entries.
    IL void getcj(int *A,int *B,int len)
    {  
        m=len*2; init();
        for (int i=0;i<=len;i++) a[i]=(A[i]%mo+mo)%mo,b[i]=(B[i]%mo+mo)%mo;
        DIF(a,n); DIF(b,n);
        for(int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mo;
        DIT(a,n);
        for (int i=0;i<=2*len;i++) B[i]=a[i];
        clear();
    }
}F;
// pos[i]: segment-tree node index of leaf i (filled by build()).
int pos[N];
// ve[x]: coefficient list of the polynomial stored at segment-tree node x.
// NOTE(review): node indices can reach about 4*n, so n must stay well below N/4.
vector<int> ve[N];
// Midpoint of the current [h, t] segment; expands using the caller's h and t.
#define mid ((h+t)/2)
// Number the segment tree over [h, t] (node x has children 2x and 2x+1) and
// record in pos[i] the node that leaf i ends up at.
void build(int x,int h,int t)
{
    if (h!=t)
    {
        const int split=(h+t)/2;
        build(2*x,h,split);
        build(2*x+1,split+1,t);
        return;
    }
    pos[h]=x;
}
void gg(int x,int h,int t)
{
    if (h==t) return;
    gg(x*2,h,mid); gg(x*2+1,mid+1,t);
    int nn=max(ve[x*2].size(),ve[x*2+1].size());
    int i=0,j=0;
    rep(i,0,nn) f[i]=g[i]=0;
    for (auto v:ve[x*2])
    {
        f[i]=v; i++;
    }
    for (auto v:ve[x*2+1])
    {
        g[j]=v; j++;
    }
    F.getcj(f,g,nn);
    bool tt=0;
    dep(i,2*nn,0)
      if (g[i]!=0||tt)
      {
          tt=1;
          ve[x].push_back(g[i]);
      }
    reverse(ve[x].begin(),ve[x].end()); 
}
// Driver for the segment-tree product above.
// NOTE(review): this program does not compile as written. No global `n` is
// declared here (only fft's member n), and `push_back()` is called without a
// value to append. It also never reads input and never prints ve[1], the
// product computed by gg(). Presumably each leaf should receive its factor's
// coefficients; confirm the intended input format before completing it.
int main()
{
    ios::sync_with_stdio(false);
    build(1,1,n);
    rep(i,1,n)
    {
        ve[pos[i]].push_back();
    }
    gg(1,1,n); 
    return 0;
}

 

代码:

#include <bits/stdc++.h>
using namespace std;
// Contest shorthand macros used throughout this program.
// rep: ascending loop over the inclusive range [h, t].
#define rep(i,h,t) for (int i=h;i<=t;i++)
// dep: descending loop from t down to h (inclusive).
#define dep(i,t,h) for (int i=t;i>=h;i--)
#define ll long long
// me: zero-fill x; sizeof(x) is the full size only when x is an array, not a pointer.
#define me(x) memset(x,0,sizeof(x))
#define IL inline
// NOTE(review): `register` was removed in C++17; rint is not used in the visible code.
#define rint register int
// Reads one signed decimal integer from stdin via getchar().
// Non-digit characters are skipped (a '-' among them makes the value
// negative), then consecutive digits are accumulated.
// The character is held in an int so EOF stays distinguishable from byte
// 0xFF. If input ends before any digit appears, 0 is returned instead of
// spinning forever at EOF.
inline long long rd(){
    long long x=0;
    bool neg=false;
    int c=getchar();
    while(!isdigit(c)){
        if(c==EOF) return 0;
        if(c=='-') neg=true;
        c=getchar();
    }
    while(isdigit(c)){
        x=x*10+(c-'0');
        c=getchar();
    }
    return neg?-x:x;
}
// Buffered character reader over stdin. ss is a 16 MiB buffer, A is the
// read cursor and B marks the end of valid data. When they meet, the buffer
// is refilled with fread. If no more data arrives, EOF (converted to char)
// is returned.
char ss[1<<24],*A=ss,*B=ss;
IL char gc()
{
    if (A==B)
    {
        A=ss;
        B=ss+fread(ss,1,1<<24,stdin);
        if (A==B) return EOF;
    }
    return *A++;
}
// maxa(x, y): in-place maximum. x becomes y when y > x.
template<class T>void maxa(T &x,T y)
{
    if (!(y>x)) return;
    x=y;
}
// mina(x, y): in-place minimum. x becomes y when y < x.
template<class T>void mina(T &x,T y)
{
    if (!(y<x)) return;
    x=y;
}
// Parses one signed decimal integer from the gc() stream into x.
// Non-digit characters are skipped; a '-' among them makes the result
// negative. If the stream ends before any digit appears, x is set to 0 and
// the function returns instead of looping forever on EOF.
template<class T>void read(T &x)
{
    int f=1,c;
    while (c=gc(),c<48||c>57)
    {
        // gc() returns char, so compare against EOF as it looks after that conversion.
        if (c==static_cast<char>(EOF)) { x=0; return; }
        if (c=='-') f=-1;
    }
    x=(c^48);
    while (c=gc(),c>47&&c<58) x=x*10+(c^48);
    x*=f;
}
const int mo=998244353;
// Modular exponentiation: returns x^y mod mo by iterative square-and-multiply.
// x may be any int; it is first reduced into [0, mo). y <= 0 yields 1.
// The previous recursive form only stopped at y == 1, so y == 0 recursed
// forever (reachable via fsp(i+1, 2*m) with m == 0).
long long fsp(int x,int y)
{
    long long base=((long long)x%mo+mo)%mo;
    long long ans=1;
    while (y>0)
    {
        if (y&1) ans=ans*base%mo;
        base=base*base%mo;
        y>>=1;
    }
    return ans;
}
// Integer 2-D vector. + and - are component-wise, * is the cross product,
// and half() tags the half-plane (presumably a polar-sort helper; cp is not
// referenced elsewhere in this program).
struct cp {
    ll x,y;
    // Component-wise sum.
    cp operator +(cp rhs)
    {
        cp out;
        out.x=x+rhs.x;
        out.y=y+rhs.y;
        return out;
    }
    // Component-wise difference.
    cp operator -(cp rhs)
    {
        cp out;
        out.x=x-rhs.x;
        out.y=y-rhs.y;
        return out;
    }
    // z-component of the cross product: x*rhs.y - y*rhs.x.
    ll operator *(cp rhs)
    {
        const ll lhsTerm=x*rhs.y;
        const ll rhsTerm=y*rhs.x;
        return lhsTerm-rhsTerm;
    }
    // 1 when the point is below the x-axis or on its negative half, else 0.
    int half()
    {
        if (y<0) return 1;
        if (y==0&&x<0) return 1;
        return 0;
    }
};
// Plain triple of ints; not referenced elsewhere in this program.
struct re{
    int a;
    int b;
    int c;
};
// Capacity of every polynomial buffer; NTT sizes (powers of two) must stay below N.
const int N=6e5;
// Primitive root of mo used by fft::ntt to derive each level's twiddle factor.
const int G=3;
// Scratch globals. f and g appear only in the commented-out solve(); main()
// and the fft struct declare their own n.
int f[N],g[N],n;
// Polynomial toolkit over Z/mo (mo = 998244353). It is built on a classic
// iterative radix-2 NTT: a bit-reversal permutation r[], then butterflies
// with twiddles computed per level from the primitive root G.
//
// Length convention in this version: `len` is the NUMBER of coefficients,
// so routines read/write indices 0..len-1 (getjf/getln also write index len).
// The work buffers a/b must be zero outside the range being filled; every
// routine restores that with clear() before returning.
//
// Fixes relative to the previous revision:
//  * ntt(): the butterfly subtracted mo only when x+y > mo, so x+y == mo was
//    stored as mo instead of 0. It now tests >=, keeping values in [0, mo).
//  * getinv/getsqrt/getexp zero the coefficients above the recursively
//    computed prefix before each Newton step. Stale data in the caller's
//    output buffer therefore cannot reach the product or wrap around the
//    cyclic convolution. len <= 0 now returns instead of recursing forever.
//  * getsqrt/getexp used int[N] stack arrays (getexp at every recursion
//    level). They now use heap-backed std::vector scratch.
struct fft{
  int l,n,m;
  int r[N],a[N],b[N],w[N],inv[N];
  int C[N],D[N];
  // inv[i] = i^{-1} mod mo for 1 <= i < N.
  fft()
  {
    inv[0]=inv[1]=1;
    rep(i,2,N-1) inv[i]=(1ll*inv[mo%i]*(mo-(mo/i)))%mo;
  }
  // n = smallest power of two greater than m, l = log2 n,
  // r = bit-reversal permutation of 0..n-1.
  IL void ntt_init()
  {
    l=0; for (n=1;n<=m;n<<=1) l++;
    for (int i=0;i<n;i++) r[i]=(r[i/2]/2)|((i&1)<<(l-1));
  }
  // Zero the work buffers a and b on indices 0..n.
  IL void clear()
  {
      rep(i,0,n) a[i]=b[i]=0;
  }
  // In-place transform of a[0..n-1]: o == 1 forward, o == -1 inverse
  // (forward pass, then index reversal and scaling by n^{-1}).
  void ntt(int *a,int o)
  {
      for (int i=0;i<n;i++) if (i>r[i]) swap(a[i],a[r[i]]);
      for (int i=1;i<n;i<<=1)
      {
          int wn=fsp(G,(mo-1)/(i*2)); w[0]=1;
          rep(j,1,i-1) w[j]=(1ll*w[j-1]*wn)%mo;
          for (int j=0;j<n;j+=(i*2))
            for (int k=0;k<i;k++)
            {
                int x=a[j+k],y=1ll*a[i+j+k]*w[k]%mo;
                a[j+k]=x+y>=mo?x+y-mo:x+y;
                a[i+j+k]=x-y>=0?x-y:x-y+mo;
            }
      }
      if (o==-1)
      {
          reverse(&a[1],&a[n]);
          for (int i=0,inv=fsp(n,mo-2);i<n;i++)
             a[i]=1ll*a[i]*inv%mo;
      }
  }
  // Multiply the caller-filled buffers a and b and copy all n outputs into C.
  // n is the smallest power of two above 2*len, so the product's degree must
  // stay below n to avoid wrap-around.
  IL void getcj(int *C,int len)
  {
      m=len*2; ntt_init();
      rep(i,0,n) a[i]=(a[i]+mo)%mo,b[i]=(b[i]+mo)%mo;
      ntt(a,1); ntt(b,1);
      for (int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mo;
      ntt(a,-1);
      for (int i=0;i<n;i++) C[i]=a[i];
      clear();
  }
  // B <- (A * B) mod x^len. Inputs are reduced into [0, mo) first.
  IL void getcj(int *A,int *B,int len)
  {
    m=len*2; ntt_init();
    for (int i=0;i<len;i++) a[i]=(A[i]%mo+mo)%mo,b[i]=(B[i]%mo+mo)%mo;
    ntt(a,1); ntt(b,1);
    for(int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mo;
    ntt(a,-1);
    for (int i=0;i<len;i++) B[i]=a[i];
    clear();
  }
  // B <- A^{-1} mod x^len by Newton iteration B <- 2B - A*B^2.
  // Requires A[0] != 0 (mod mo).
  IL void getinv(int *A,int *B,int len)
  {
    if (len<=0) return;
    if (len==1) { B[0]=fsp(A[0],mo-2); return; }
    const int half=(len+1)>>1;
    getinv(A,B,half);                       // B[0..half-1] now correct
    for (int i=half;i<len;i++) B[i]=0;      // drop stale caller data above the prefix
    m=len*2; ntt_init();
    for (int i=0;i<len;i++) a[i]=(A[i]%mo+mo)%mo,b[i]=B[i];
    ntt(a,1); ntt(b,1);
    for (int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mo*b[i]%mo;
    ntt(a,-1);
    for (int i=0;i<len;i++) B[i]=((2ll*B[i]-a[i])%mo+mo)%mo;
    clear();
  }
  // B <- sqrt(A) mod x^len by Newton iteration B <- (B + A/B) / 2.
  // NOTE(review): the base case takes a real sqrt of A[0], which is only a
  // modular square root when A[0] is a perfect square (e.g. A[0] == 1).
  IL void getsqrt(int *A,int *B,int len)
  {
    if (len<=0) return;
    if (len==1) {B[0]=sqrt(A[0]); return;}
    const int half=(len+1)>>1;
    getsqrt(A,B,half);
    for (int i=half;i<len;i++) B[i]=0;
    const int inv2=fsp(2,mo-2);
    std::vector<int> invB(len,0);           // heap scratch (was int[N] on the stack)
    getinv(B,invB.data(),len);
    getcj(A,invB.data(),len);               // invB <- A / B
    for (int i=0;i<len;i++) B[i]=1ll*(invB[i]+B[i])%mo*inv2%mo;
  }
  // b <- a' truncated to len terms: b[i-1] = i*a[i] for 1 <= i < len, and
  // b[len-1] = 0. Works in place (a == b).
  IL void getDao(int *a,int *b,int len)
  {
    if (len<=0) return;
    for (int i=1;i<len;i++) b[i-1]=1ll*i*a[i]%mo;
    b[len-1]=0;
  }
  // b <- integral of a: b[i+1] = a[i]/(i+1) for 0 <= i < len, b[0] = 0.
  // Writes index len. Not safe in place.
  IL void getjf(int *a,int *b,int len)
  {
    for (int i=0;i<len;i++) b[i+1]=1ll*a[i]*inv[i+1]%mo;
    b[0]=0;
  }
  // B <- ln(A) mod x^len, computed as the integral of A'/A (also writes
  // B[len]). Requires A[0] == 1. A and B may alias. Uses members C and D as
  // scratch and zeroes C[0..len], D[0..len] afterwards.
  IL void getln(int *A,int *B,int len)
  {
    getDao(A,C,len);
    getinv(A,D,len);
    getcj(C,D,len);
    getjf(D,B,len);
    rep(i,0,len) C[i]=0,D[i]=0;
  }
  // B <- exp(A) mod x^len by Newton iteration B <- B*(1 - ln B + A).
  // Requires A[0] == 0.
  IL void getexp(int *A,int *B,int len)
  {
    if (len<=0) return;
    if (len==1) {B[0]=1; return;}
    const int half=(len+1)>>1;
    getexp(A,B,half);
    for (int i=half;i<len;i++) B[i]=0;
    std::vector<int> lnB(len+1,0);          // getln also writes index len
    getln(B,lnB.data(),len);
    for(int i=0;i<len;i++) lnB[i]=((-lnB[i]+A[i])%mo+mo)%mo;
    lnB[0]=(lnB[0]+1)%mo;
    getcj(lnB.data(),B,len);
  }
}F;
/*

f[i]=\sum f[j]*g[i-j]; 

*/
/*
int now[N];
void solve(int h,int t)
{
  if (h>=t) return; 
  if (t-h<=32)
  {
      rep(i,h,t)
        rep(j,h,i)
          f[i]=(f[i]+1ll*f[j]*g[i-j])%mo;
    return;
  }
  int mid=(h+t)/2;
  solve(h,mid);
  rep(i,h,mid) F.a[i-h]=f[i];
  rep(i,1,t-h) F.b[i]=g[i];
  F.getcj(now,(t-h+1)+(mid-h+1));
  rep(i,mid+1,t) f[i]=(f[i]+now[i-h])%mo;
  solve(mid+1,t);
}
*/
// Working polynomials for main(); `now` receives products inside solve().
int sum[N],now[N],a[N],b[N],c[N],d[N],e[N];
// jc[i] = i! mod mo and jc2[i] = (i!)^{-1} mod mo, filled in main().
ll jc[N],jc2[N];
/*
  Divide and conquer for prod_{i=h}^{t} (1 + a[i] x).
  On entry a[h..t] holds the a[i]. On return a[h+k-1] holds the coefficient
  of x^k (k = 1..t-h+1) of that product; the constant term 1 is implicit.
  F.getcj transforms at a size above 2*(leftLen+1). That is enough for the
  degree leftLen+rightLen product, since rightLen <= leftLen.
*/
void solve(int h,int t,int *a)
{
    if (h==t) return;                 // one factor: a[h] already is its x^1 coefficient
    const int mid=(h+t)/2;
    solve(h,mid,a);
    solve(mid+1,t,a);
    const int leftLen=mid-h+1, rightLen=t-mid;
    // Stage (1 + left coefficients) in F.a and (1 + right coefficients) in F.b.
    F.a[0]=1;
    for (int k=1;k<=leftLen;k++) F.a[k]=a[h+k-1];
    F.b[0]=1;
    for (int k=1;k<=rightLen;k++) F.b[k]=a[mid+k];
    F.getcj(now,leftLen+1);
    for (int k=1;k<=leftLen+rightLen;k++) a[h+k-1]=now[k];
}
// Not referenced anywhere in this program.
int sum3[N],sum4[N];
// Reads n, m and weights a[1..n] from stdin and prints one value mod mo.
// Outline, as computed by the calls below:
//   ans    = prod_i a[i]
//   sum[k] = sum_i a[i]^k for k = 1..n, and sum[0] = n (power sums)
//   a(x)   = sum_{i<n} (i+1)^{2m} x^i / i!
//   b(x)   = c(x) = sum_{i<n} (i+1)^m x^i / i!
//   e      = ln(c) with coefficient k scaled by sum[k]; c = exp(e)
//   d      = a/b with coefficient k scaled by sum[k];   d = c*d
//   output = ans * (n-2)! * d[n-2]
// NOTE(review): n == 1 would index d[-1] and jc[-1]; confirm the input
// guarantees n >= 2. If m can be 0, this relies on fsp(i+1, 0) returning 1.
int main()
{
   ios::sync_with_stdio(false);
   int n,m;
   cin>>n>>m;
   ll ans=1;
   rep(i,1,n) cin>>a[i],ans=ans*a[i]%mo;
   // sum[1..n] = -a[i]; solve() replaces them with coefficients 1..n of
   // prod_i (1 - a[i] x), whose constant term 1 is stored in sum[0].
   rep(i,1,n) sum[i]=(-a[i]+mo)%mo;
   solve(1,n,sum);
   sum[0]=1;
   // (ln prod_i (1 - a[i] x))' = -sum_{k>=1} (sum_i a[i]^k) x^{k-1}; the
   // shift-and-negate loop turns this into sum[k] = sum_i a[i]^k.
   F.getln(sum,sum,n+2);
   F.getDao(sum,sum,n+2);
   dep(i,n,1) sum[i]=((mo-sum[i-1])%mo+mo)%mo;
   sum[0]=n;
   // jc[i] = i! and jc2[i] = (i!)^{-1} (mod mo).
   jc[0]=jc2[0]=1;
   rep(i,1,n) jc[i]=jc[i-1]*i%mo;
   jc2[n]=fsp(jc[n],mo-2);
   dep(i,n-1,1) jc2[i]=jc2[i+1]*(i+1)%mo;
   rep(i,0,n) a[i]=b[i]=c[i]=0;
   rep(i,0,n-1) a[i]=1ll*fsp(i+1,2*m)*jc2[i]%mo;
   rep(i,0,n-1) c[i]=b[i]=1ll*fsp(i+1,m)*jc2[i]%mo;
   // e = ln(c) scaled by the power sums, then c = exp(e).
   F.getln(c,e,n+1);
   rep(i,0,n) e[i]=1ll*e[i]*sum[i]%mo;
   F.getexp(e,c,n+1);
   // d = a/b scaled by the power sums, then d = c*d.
   F.getinv(b,d,n+1);
   F.getcj(a,d,n+1);
   rep(i,0,n) d[i]=1ll*d[i]*sum[i]%mo;
   F.getcj(c,d,n+1);
   ans=ans*d[n-2]%mo*jc[n-2]%mo;
   cout<<ans<<endl; 
   return 0;
}
View Code

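 下面这个分治fft长度短会快一点 不懂原理 (The divide-and-conquer FFT below runs a bit faster with the shorter transform length; I did not understand why. Presumably this refers to the commented-out online solve() in the code below, which passes t-h+1 instead of (t-h+1)+(mid-h+1) as the length. Only coefficients mid-h+1..t-h of the product are read back. With a cyclic convolution of length at least t-h+1, every wrapped-around term lands on an index below mid-h+1, and those indices are discarded. So the shorter length is still correct and lets the NTT be smaller, often half the size.)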

#include <bits/stdc++.h>
using namespace std;
// Contest shorthand macros used throughout this program.
// rep: ascending loop over the inclusive range [h, t].
#define rep(i,h,t) for (int i=h;i<=t;i++)
// dep: descending loop from t down to h (inclusive).
#define dep(i,t,h) for (int i=t;i>=h;i--)
#define ll long long
// me: zero-fill x; sizeof(x) is the full size only when x is an array, not a pointer.
#define me(x) memset(x,0,sizeof(x))
#define IL inline
// NOTE(review): `register` was removed in C++17; rint is not used in the visible code.
#define rint register int
// Reads one signed decimal integer from stdin via getchar().
// Non-digit characters are skipped (a '-' among them makes the value
// negative), then consecutive digits are accumulated.
// The character is held in an int so EOF stays distinguishable from byte
// 0xFF. If input ends before any digit appears, 0 is returned instead of
// spinning forever at EOF.
inline long long rd(){
    long long x=0;
    bool neg=false;
    int c=getchar();
    while(!isdigit(c)){
        if(c==EOF) return 0;
        if(c=='-') neg=true;
        c=getchar();
    }
    while(isdigit(c)){
        x=x*10+(c-'0');
        c=getchar();
    }
    return neg?-x:x;
}
// Buffered character reader over stdin. ss is a 16 MiB buffer, A is the
// read cursor and B marks the end of valid data. When they meet, the buffer
// is refilled with fread. If no more data arrives, EOF (converted to char)
// is returned.
char ss[1<<24],*A=ss,*B=ss;
IL char gc()
{
    if (A==B)
    {
        A=ss;
        B=ss+fread(ss,1,1<<24,stdin);
        if (A==B) return EOF;
    }
    return *A++;
}
// maxa(x, y): in-place maximum. x becomes y when y > x.
template<class T>void maxa(T &x,T y)
{
    if (!(y>x)) return;
    x=y;
}
// mina(x, y): in-place minimum. x becomes y when y < x.
template<class T>void mina(T &x,T y)
{
    if (!(y<x)) return;
    x=y;
}
// Parses one signed decimal integer from the gc() stream into x.
// Non-digit characters are skipped; a '-' among them makes the result
// negative. If the stream ends before any digit appears, x is set to 0 and
// the function returns instead of looping forever on EOF.
template<class T>void read(T &x)
{
    int f=1,c;
    while (c=gc(),c<48||c>57)
    {
        // gc() returns char, so compare against EOF as it looks after that conversion.
        if (c==static_cast<char>(EOF)) { x=0; return; }
        if (c=='-') f=-1;
    }
    x=(c^48);
    while (c=gc(),c>47&&c<58) x=x*10+(c^48);
    x*=f;
}
const int mo=998244353;
// Modular exponentiation: returns x^y mod mo by iterative square-and-multiply.
// x may be any int; it is first reduced into [0, mo). y <= 0 yields 1.
// The previous recursive form only stopped at y == 1, so y == 0 recursed
// forever (reachable via fsp(i+1, 2*m) with m == 0).
long long fsp(int x,int y)
{
    long long base=((long long)x%mo+mo)%mo;
    long long ans=1;
    while (y>0)
    {
        if (y&1) ans=ans*base%mo;
        base=base*base%mo;
        y>>=1;
    }
    return ans;
}
// Integer 2-D vector. + and - are component-wise, * is the cross product,
// and half() tags the half-plane (presumably a polar-sort helper; cp is not
// referenced elsewhere in this program).
struct cp {
    ll x,y;
    // Component-wise sum.
    cp operator +(cp rhs)
    {
        cp out;
        out.x=x+rhs.x;
        out.y=y+rhs.y;
        return out;
    }
    // Component-wise difference.
    cp operator -(cp rhs)
    {
        cp out;
        out.x=x-rhs.x;
        out.y=y-rhs.y;
        return out;
    }
    // z-component of the cross product: x*rhs.y - y*rhs.x.
    ll operator *(cp rhs)
    {
        const ll lhsTerm=x*rhs.y;
        const ll rhsTerm=y*rhs.x;
        return lhsTerm-rhsTerm;
    }
    // 1 when the point is below the x-axis or on its negative half, else 0.
    int half()
    {
        if (y<0) return 1;
        if (y==0&&x<0) return 1;
        return 0;
    }
};
// Plain triple of ints; not referenced elsewhere in this program.
struct re{
    int a;
    int b;
    int c;
};
// Capacity of every polynomial buffer; NTT sizes (powers of two) must stay below N.
const int N=6e5;
// Primitive root of mo used by fft::ntt to derive each level's twiddle factor.
const int G=3;
// Scratch globals. f and g appear only in the commented-out solve(); main()
// and the fft struct declare their own n.
int f[N],g[N],n;
// Polynomial toolkit over Z/mo (mo = 998244353). It is built on a classic
// iterative radix-2 NTT: a bit-reversal permutation r[], then butterflies
// with twiddles computed per level from the primitive root G.
//
// Length convention in this version: `len` is the NUMBER of coefficients,
// so routines read/write indices 0..len-1 (getjf/getln also write index len).
// The work buffers a/b must be zero outside the range being filled; every
// routine restores that with clear() before returning.
//
// Fixes relative to the previous revision:
//  * ntt(): the butterfly subtracted mo only when x+y > mo, so x+y == mo was
//    stored as mo instead of 0. It now tests >=, keeping values in [0, mo).
//  * getinv/getsqrt/getexp zero the coefficients above the recursively
//    computed prefix before each Newton step. Stale data in the caller's
//    output buffer therefore cannot reach the product or wrap around the
//    cyclic convolution. len <= 0 now returns instead of recursing forever.
//  * getsqrt/getexp used int[N] stack arrays (getexp at every recursion
//    level). They now use heap-backed std::vector scratch.
struct fft{
  int l,n,m;
  int r[N],a[N],b[N],w[N],inv[N];
  int C[N],D[N];
  // inv[i] = i^{-1} mod mo for 1 <= i < N.
  fft()
  {
    inv[0]=inv[1]=1;
    rep(i,2,N-1) inv[i]=(1ll*inv[mo%i]*(mo-(mo/i)))%mo;
  }
  // n = smallest power of two greater than m, l = log2 n,
  // r = bit-reversal permutation of 0..n-1.
  IL void ntt_init()
  {
    l=0; for (n=1;n<=m;n<<=1) l++;
    for (int i=0;i<n;i++) r[i]=(r[i/2]/2)|((i&1)<<(l-1));
  }
  // Zero the work buffers a and b on indices 0..n.
  IL void clear()
  {
      rep(i,0,n) a[i]=b[i]=0;
  }
  // In-place transform of a[0..n-1]: o == 1 forward, o == -1 inverse
  // (forward pass, then index reversal and scaling by n^{-1}).
  void ntt(int *a,int o)
  {
      for (int i=0;i<n;i++) if (i>r[i]) swap(a[i],a[r[i]]);
      for (int i=1;i<n;i<<=1)
      {
          int wn=fsp(G,(mo-1)/(i*2)); w[0]=1;
          rep(j,1,i-1) w[j]=(1ll*w[j-1]*wn)%mo;
          for (int j=0;j<n;j+=(i*2))
            for (int k=0;k<i;k++)
            {
                int x=a[j+k],y=1ll*a[i+j+k]*w[k]%mo;
                a[j+k]=x+y>=mo?x+y-mo:x+y;
                a[i+j+k]=x-y>=0?x-y:x-y+mo;
            }
      }
      if (o==-1)
      {
          reverse(&a[1],&a[n]);
          for (int i=0,inv=fsp(n,mo-2);i<n;i++)
             a[i]=1ll*a[i]*inv%mo;
      }
  }
  // Multiply the caller-filled buffers a and b and copy all n outputs into C.
  // n is the smallest power of two above 2*len, so the product's degree must
  // stay below n to avoid wrap-around.
  IL void getcj(int *C,int len)
  {
      m=len*2; ntt_init();
      rep(i,0,n) a[i]=(a[i]+mo)%mo,b[i]=(b[i]+mo)%mo;
      ntt(a,1); ntt(b,1);
      for (int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mo;
      ntt(a,-1);
      for (int i=0;i<n;i++) C[i]=a[i];
      clear();
  }
  // B <- (A * B) mod x^len. Inputs are reduced into [0, mo) first.
  IL void getcj(int *A,int *B,int len)
  {
    m=len*2; ntt_init();
    for (int i=0;i<len;i++) a[i]=(A[i]%mo+mo)%mo,b[i]=(B[i]%mo+mo)%mo;
    ntt(a,1); ntt(b,1);
    for(int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mo;
    ntt(a,-1);
    for (int i=0;i<len;i++) B[i]=a[i];
    clear();
  }
  // B <- A^{-1} mod x^len by Newton iteration B <- 2B - A*B^2.
  // Requires A[0] != 0 (mod mo).
  IL void getinv(int *A,int *B,int len)
  {
    if (len<=0) return;
    if (len==1) { B[0]=fsp(A[0],mo-2); return; }
    const int half=(len+1)>>1;
    getinv(A,B,half);                       // B[0..half-1] now correct
    for (int i=half;i<len;i++) B[i]=0;      // drop stale caller data above the prefix
    m=len*2; ntt_init();
    for (int i=0;i<len;i++) a[i]=(A[i]%mo+mo)%mo,b[i]=B[i];
    ntt(a,1); ntt(b,1);
    for (int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mo*b[i]%mo;
    ntt(a,-1);
    for (int i=0;i<len;i++) B[i]=((2ll*B[i]-a[i])%mo+mo)%mo;
    clear();
  }
  // B <- sqrt(A) mod x^len by Newton iteration B <- (B + A/B) / 2.
  // NOTE(review): the base case takes a real sqrt of A[0], which is only a
  // modular square root when A[0] is a perfect square (e.g. A[0] == 1).
  IL void getsqrt(int *A,int *B,int len)
  {
    if (len<=0) return;
    if (len==1) {B[0]=sqrt(A[0]); return;}
    const int half=(len+1)>>1;
    getsqrt(A,B,half);
    for (int i=half;i<len;i++) B[i]=0;
    const int inv2=fsp(2,mo-2);
    std::vector<int> invB(len,0);           // heap scratch (was int[N] on the stack)
    getinv(B,invB.data(),len);
    getcj(A,invB.data(),len);               // invB <- A / B
    for (int i=0;i<len;i++) B[i]=1ll*(invB[i]+B[i])%mo*inv2%mo;
  }
  // b <- a' truncated to len terms: b[i-1] = i*a[i] for 1 <= i < len, and
  // b[len-1] = 0. Works in place (a == b).
  IL void getDao(int *a,int *b,int len)
  {
    if (len<=0) return;
    for (int i=1;i<len;i++) b[i-1]=1ll*i*a[i]%mo;
    b[len-1]=0;
  }
  // b <- integral of a: b[i+1] = a[i]/(i+1) for 0 <= i < len, b[0] = 0.
  // Writes index len. Not safe in place.
  IL void getjf(int *a,int *b,int len)
  {
    for (int i=0;i<len;i++) b[i+1]=1ll*a[i]*inv[i+1]%mo;
    b[0]=0;
  }
  // B <- ln(A) mod x^len, computed as the integral of A'/A (also writes
  // B[len]). Requires A[0] == 1. A and B may alias. Uses members C and D as
  // scratch and zeroes C[0..len], D[0..len] afterwards.
  IL void getln(int *A,int *B,int len)
  {
    getDao(A,C,len);
    getinv(A,D,len);
    getcj(C,D,len);
    getjf(D,B,len);
    rep(i,0,len) C[i]=0,D[i]=0;
  }
  // B <- exp(A) mod x^len by Newton iteration B <- B*(1 - ln B + A).
  // Requires A[0] == 0.
  IL void getexp(int *A,int *B,int len)
  {
    if (len<=0) return;
    if (len==1) {B[0]=1; return;}
    const int half=(len+1)>>1;
    getexp(A,B,half);
    for (int i=half;i<len;i++) B[i]=0;
    std::vector<int> lnB(len+1,0);          // getln also writes index len
    getln(B,lnB.data(),len);
    for(int i=0;i<len;i++) lnB[i]=((-lnB[i]+A[i])%mo+mo)%mo;
    lnB[0]=(lnB[0]+1)%mo;
    getcj(lnB.data(),B,len);
  }
}F;
/*

f[i]=\sum f[j]*g[i-j]; 

*/
/*
int now[N];
void solve(int h,int t)
{
  if (h>=t) return; 
  if (t-h<=32)
  {
      rep(i,h,t)
        rep(j,h,i)
          f[i]=(f[i]+1ll*f[j]*g[i-j])%mo;
    return;
  }
  int mid=(h+t)/2;
  solve(h,mid);
  rep(i,h,mid) F.a[i-h]=f[i];
  rep(i,1,t-h) F.b[i]=g[i];
  F.getcj(now,(t-h+1));
  rep(i,mid+1,t) f[i]=(f[i]+now[i-h])%mo;
  solve(mid+1,t);
}
*/
// Working polynomials for main(); `now` receives products inside solve().
int sum[N],now[N],a[N],b[N],c[N],d[N],e[N];
// jc[i] = i! mod mo and jc2[i] = (i!)^{-1} mod mo, filled in main().
ll jc[N],jc2[N];
/*
  Divide and conquer for prod_{i=h}^{t} (1 + a[i] x).
  On entry a[h..t] holds the a[i]. On return a[h+k-1] holds the coefficient
  of x^k (k = 1..t-h+1) of that product; the constant term 1 is implicit.
  F.getcj transforms at a size above 2*(leftLen+1). That is enough for the
  degree leftLen+rightLen product, since rightLen <= leftLen.
*/
void solve(int h,int t,int *a)
{
    if (h==t) return;                 // one factor: a[h] already is its x^1 coefficient
    const int mid=(h+t)/2;
    solve(h,mid,a);
    solve(mid+1,t,a);
    const int leftLen=mid-h+1, rightLen=t-mid;
    // Stage (1 + left coefficients) in F.a and (1 + right coefficients) in F.b.
    F.a[0]=1;
    for (int k=1;k<=leftLen;k++) F.a[k]=a[h+k-1];
    F.b[0]=1;
    for (int k=1;k<=rightLen;k++) F.b[k]=a[mid+k];
    F.getcj(now,leftLen+1);
    for (int k=1;k<=leftLen+rightLen;k++) a[h+k-1]=now[k];
}
// Not referenced anywhere in this program.
int sum3[N],sum4[N];
// Reads n, m and weights a[1..n] from stdin and prints one value mod mo.
// Outline, as computed by the calls below:
//   ans    = prod_i a[i]
//   sum[k] = sum_i a[i]^k for k = 1..n, and sum[0] = n (power sums)
//   a(x)   = sum_{i<n} (i+1)^{2m} x^i / i!
//   b(x)   = c(x) = sum_{i<n} (i+1)^m x^i / i!
//   e      = ln(c) with coefficient k scaled by sum[k]; c = exp(e)
//   d      = a/b with coefficient k scaled by sum[k];   d = c*d
//   output = ans * (n-2)! * d[n-2]
// NOTE(review): n == 1 would index d[-1] and jc[-1]; confirm the input
// guarantees n >= 2. If m can be 0, this relies on fsp(i+1, 0) returning 1.
int main()
{
   ios::sync_with_stdio(false);
   int n,m;
   cin>>n>>m;
   ll ans=1;
   rep(i,1,n) cin>>a[i],ans=ans*a[i]%mo;
   // sum[1..n] = -a[i]; solve() replaces them with coefficients 1..n of
   // prod_i (1 - a[i] x), whose constant term 1 is stored in sum[0].
   rep(i,1,n) sum[i]=(-a[i]+mo)%mo;
   solve(1,n,sum);
   sum[0]=1;
   // (ln prod_i (1 - a[i] x))' = -sum_{k>=1} (sum_i a[i]^k) x^{k-1}; the
   // shift-and-negate loop turns this into sum[k] = sum_i a[i]^k.
   F.getln(sum,sum,n+2);
   F.getDao(sum,sum,n+2);
   dep(i,n,1) sum[i]=((mo-sum[i-1])%mo+mo)%mo;
   sum[0]=n;
   // jc[i] = i! and jc2[i] = (i!)^{-1} (mod mo).
   jc[0]=jc2[0]=1;
   rep(i,1,n) jc[i]=jc[i-1]*i%mo;
   jc2[n]=fsp(jc[n],mo-2);
   dep(i,n-1,1) jc2[i]=jc2[i+1]*(i+1)%mo;
   rep(i,0,n) a[i]=b[i]=c[i]=0;
   rep(i,0,n-1) a[i]=1ll*fsp(i+1,2*m)*jc2[i]%mo;
   rep(i,0,n-1) c[i]=b[i]=1ll*fsp(i+1,m)*jc2[i]%mo;
   // e = ln(c) scaled by the power sums, then c = exp(e).
   F.getln(c,e,n+1);
   rep(i,0,n) e[i]=1ll*e[i]*sum[i]%mo;
   F.getexp(e,c,n+1);
   // d = a/b scaled by the power sums, then d = c*d.
   F.getinv(b,d,n+1);
   F.getcj(a,d,n+1);
   rep(i,0,n) d[i]=1ll*d[i]*sum[i]%mo;
   F.getcj(c,d,n+1);
   ans=ans*d[n-2]%mo*jc[n-2]%mo;
   cout<<ans<<endl; 
   return 0;
}

 

posted @ 2021-08-15 16:03  尹吴潇  阅读(58)  评论(0编辑  收藏  举报