[清华集训2017]生成树计数
Ln(x) a[0]=1;
exp(x) a[0]=0;
11.26 换了一个更快的dft的板子
#include <bits/stdc++.h> using namespace std; #define rep(i,h,t) for (int i=h;i<=t;i++) #define dep(i,t,h) for (int i=t;i>=h;i--) #define ll long long #define me(x) memset(x,0,sizeof(x)) #define IL inline #define rint register int inline ll rd(){ ll x=0;char c=getchar();bool f=0; while(!isdigit(c)){if(c=='-')f=1;c=getchar();} while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();} return f?-x:x; } char ss[1<<24],*A=ss,*B=ss; IL char gc() { return A==B&&(B=(A=ss)+fread(ss,1,1<<24,stdin),A==B)?EOF:*A++; } template<class T>void maxa(T &x,T y) { if (y>x) x=y; } template<class T>void mina(T &x,T y) { if (y<x) x=y; } template<class T>void read(T &x) { int f=1,c; while (c=gc(),c<48||c>57) if (c=='-') f=-1; x=(c^48); while(c=gc(),c>47&&c<58) x=x*10+(c^48); x*=f; } const int mo=998244353; ll fsp(int x,int y) { if (y==1) return x; ll ans=fsp(x,y/2); ans=ans*ans%mo; if (y%2==1) ans=ans*x%mo; return ans; } struct cp { ll x,y; cp operator +(cp B) { return (cp){x+B.x,y+B.y}; } cp operator -(cp B) { return (cp){x-B.x,y-B.y}; } ll operator *(cp B) { return x*B.y-y*B.x; } int half() { return y < 0 || (y == 0 && x < 0); } }; struct re{ int a,b,c; }; const int N=6e5; const int G=3; int f[N],g[N],n; struct fft{ int l,n,m; int a[N],b[N],inv[N]; int C[N],D[N]; fft() { inv[0]=inv[1]=1; rep(i,2,N-1) inv[i]=(1ll*inv[mo%i]*(mo-(mo/i)))%mo; } IL void clear() { rep(i,0,n) a[i]=b[i]=0; } int ppr[N]; inline void ntt_init(){ int lg=0,x; for (x=1;x<=m;x*=2) lg++; n=x; ppr[0]=1;ppr[x]=fsp(31,1<<(21-lg)); for(int i=x>>1;i;i>>=1) ppr[i]=1ll*ppr[i<<1]*ppr[i<<1]%mo; for(int i=1;i<x;i++) ppr[i]=1ll*ppr[i&(i-1)]*ppr[i&-i]%mo; } inline int del(const int x){ return x>=mo?x-mo:x; } inline void DIF(int *f,const int x){ int len,hl,uni,*s,*i,*w; for(len=x,hl=x>>1;hl;len=hl,hl>>=1){ for(s=f,w=ppr;s!=f+x;s+=len,w++){ for(i=s;i<s+hl;i++){ uni=1ll**(i+hl)**w%mo; *(i+hl)=del(*i+mo-uni); *i=del(*i+uni); } } } } inline void DIT(int *f,const int x){ int len,hl,uni,*s,*i,*w; for(len=2,hl=1;len<=x;hl=len,len<<=1){ for(s=f,w=ppr;s!=f+x;s+=len,w++){ for(i=s;i!=s+hl;i++){ uni=*i; *i=del(uni+*(i+hl)); *(i+hl)=1ll*(uni+mo-*(i+hl))**w%mo; } } } reverse(f+1,f+x);int invx=mo-(mo-1)/x; for(i=f;i!=f+x;i++) *i=1ll**i*invx%mo; } IL void getcj(int *A,int *B,int len) { m=len*2; ntt_init(); for (int i=0;i<=len;i++) a[i]=(A[i]%mo+mo)%mo,b[i]=(B[i]%mo+mo)%mo; DIF(a,n); DIF(b,n); for(int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mo; DIT(a,n); for (int i=0;i<=len;i++) B[i]=a[i]; clear(); } IL void getcj(int *C,int len) { // for (int i=0;i<=len;i++) a[i]=(A[i]%mo+mo)%mo,b[i]=(B[i]%mo+mo)%mo; //m=len*2; m=len; ntt_init(); rep(i,0,n) a[i]=(a[i]+mo)%mo,b[i]=(b[i]+mo)%mo; DIF(a,n); DIF(b,n); for (int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mo; DIT(a,n); for (int i=0;i<n;i++) C[i]=a[i]; clear(); } IL void getinv(int *A,int *B,int len) { if (len==1) { B[0]=fsp(A[0],mo-2); return; } getinv(A,B,(len+1)>>1); m=len*2; ntt_init(); for (int i=0;i<=len;i++) a[i]=A[i],b[i]=B[i]; DIF(a,n); DIF(b,n); for (int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mo*b[i]%mo; DIT(a,n); for (int i=0;i<=len;i++) B[i]=((2*B[i]-a[i])%mo+mo)%mo; clear(); } IL void getsqrt(int *A,int *B,int len) { int inv2=fsp(2,mo-2); if (len==1) {B[0]=sqrt(A[0]); return;} getsqrt(A,B,(len+1)>>1); int C[N]={}; getinv(B,C,len); getcj(A,C,len); for (int i=0;i<=len;i++) B[i]=1ll*(C[i]+B[i])%mo*inv2%mo; } IL void getDao(int *a,int *b,int len) { for (int i=1;i<=len;i++) b[i-1]=1ll*i*a[i]%mo; b[len-1]=0; } IL void getjf(int *a,int *b,int len) { for (int i=0;i<=len;i++) b[i+1]=1ll*a[i]*inv[i+1]%mo; b[0]=0; } IL void getln(int *A,int *B,int len) { // me(C); me(D); getDao(A,C,len); getinv(A,D,len); getcj(C,D,len); getjf(D,B,len); rep(i,0,len) C[i]=0,D[i]=0; } IL void getexp(int *A,int *B,int len) { if (len==1) {B[0]=1; return;} getexp(A,B,(len+1)>>1); int C[N]; getln(B,C,len); for(int i=0;i<=len;i++) C[i]=((-C[i]+A[i])%mo+mo)%mo; C[0]=(C[0]+1)%mo; getcj(C,B,len); } }F; /* f[i]=\sum f[j]*g[i-j]; */ /* int now[N]; void solve(int h,int t) { if (h>=t) return; if (t-h<=32) { rep(i,h,t) rep(j,h,i) f[i]=(f[i]+1ll*f[j]*g[i-j])%mo; return; } int mid=(h+t)/2; solve(h,mid); rep(i,h,mid) F.a[i-h]=f[i]; rep(i,1,t-h) F.b[i]=g[i]; F.getcj(now,(mid-h+1)+(t-h+1)); rep(i,mid+1,t) f[i]=(f[i]+now[i-h])%mo; solve(mid+1,t); } */ int sum[N],now[N],a[N],b[N],c[N],d[N],e[N]; ll jc[N],jc2[N]; /* \prod (1+a[i]x) */ void solve(int h,int t,int *a) { if (h==t) return; int mid=(h+t)/2; solve(h,mid,a); solve(mid+1,t,a); rep(i,h,mid) F.a[i-h+1]=a[i]; rep(i,mid+1,t) F.b[i-mid]=a[i]; F.a[0]=F.b[0]=1; F.getcj(now,(mid-h+1)+(t-mid+1)); rep(i,h,t) a[i]=now[i-h+1]; } int sum3[N],sum4[N]; int main() { ios::sync_with_stdio(false); int n,m; cin>>n>>m; ll ans=1; rep(i,1,n) cin>>a[i],ans=ans*a[i]%mo; rep(i,1,n) sum[i]=(-a[i]+mo)%mo; solve(1,n,sum); sum[0]=1; F.getln(sum,sum,n+2); F.getDao(sum,sum,n+2); dep(i,n,1) sum[i]=((mo-sum[i-1])%mo+mo)%mo; sum[0]=n; jc[0]=jc2[0]=1; rep(i,1,n) jc[i]=jc[i-1]*i%mo; jc2[n]=fsp(jc[n],mo-2); dep(i,n-1,1) jc2[i]=jc2[i+1]*(i+1)%mo; rep(i,0,n) a[i]=b[i]=c[i]=0; rep(i,0,n-1) a[i]=1ll*fsp(i+1,2*m)*jc2[i]%mo; rep(i,0,n-1) c[i]=b[i]=1ll*fsp(i+1,m)*jc2[i]%mo; F.getln(c,e,n+1); rep(i,0,n) e[i]=1ll*e[i]*sum[i]%mo; F.getexp(e,c,n+1); F.getinv(b,d,n+1); F.getcj(a,d,n+1); rep(i,0,n) d[i]=1ll*d[i]*sum[i]%mo; F.getcj(c,d,n+1); ans=ans*d[n-2]%mo*jc[n-2]%mo; cout<<ans<<endl; return 0; }
分治fft
#pragma GCC optimize("Ofast,no-stack-protector,unroll-loops,fast-math") #pragma GCC target("sse,sse2,sse3,ssse3,sse4.1,sse4.2,avx,avx2,popcnt,tune=native") #include <bits/stdc++.h> using namespace std; #define rep(i,h,t) for (int i=h;i<=t;i++) #define dep(i,t,h) for (int i=t;i>=h;i--) #define ll long long #define me(x) memset(x,0,sizeof(x)) #define IL inline #define rint register int inline ll rd(){ ll x=0;char c=getchar();bool f=0; while(!isdigit(c)){if(c=='-')f=1;c=getchar();} while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();} return f?-x:x; } char ss[1<<24],*A=ss,*B=ss; IL char gc() { return A==B&&(B=(A=ss)+fread(ss,1,1<<24,stdin),A==B)?EOF:*A++; } template<class T>void maxa(T &x,T y) { if (y>x) x=y; } template<class T>void mina(T &x,T y) { if (y<x) x=y; } template<class T>void read(T &x) { int f=1,c; while (c=gc(),c<48||c>57) if (c=='-') f=-1; x=(c^48); while(c=gc(),c>47&&c<58) x=x*10+(c^48); x*=f; } const int mo=998244353; ll fsp(int x,int y) { if (y==1) return x; ll ans=fsp(x,y/2); ans=ans*ans%mo; if (y%2==1) ans=ans*x%mo; return ans; } struct cp { ll x,y; cp operator +(cp B) { return (cp){x+B.x,y+B.y}; } cp operator -(cp B) { return (cp){x-B.x,y-B.y}; } ll operator *(cp B) { return x*B.y-y*B.x; } int half() { return y < 0 || (y == 0 && x < 0); } }; struct re{ int a,b,c; }; const int N=6e5; const int G=3; struct fft{ int l,n,m; int a[N],b[N]; IL void clear() { rep(i,0,n) a[i]=b[i]=0; } int ppr[N]; inline void init(){ int lg=0,x; for (x=1;x<=m;x*=2) lg++; n=x; ppr[0]=1;ppr[x]=fsp(31,1<<(21-lg)); for(int i=x>>1;i;i>>=1) ppr[i]=1ll*ppr[i<<1]*ppr[i<<1]%mo; for(int i=1;i<x;i++) ppr[i]=1ll*ppr[i&(i-1)]*ppr[i&-i]%mo; } inline int del(const int x){ return x>=mo?x-mo:x; } inline void DIF(int *f,const int x){ int len,hl,uni,*s,*i,*w; for(len=x,hl=x>>1;hl;len=hl,hl>>=1){ for(s=f,w=ppr;s!=f+x;s+=len,w++){ for(i=s;i<s+hl;i++){ uni=1ll**(i+hl)**w%mo; *(i+hl)=del(*i+mo-uni); *i=del(*i+uni); } } } } inline void DIT(int *f,const int x){ int len,hl,uni,*s,*i,*w; for(len=2,hl=1;len<=x;hl=len,len<<=1){ for(s=f,w=ppr;s!=f+x;s+=len,w++){ for(i=s;i!=s+hl;i++){ uni=*i; *i=del(uni+*(i+hl)); *(i+hl)=1ll*(uni+mo-*(i+hl))**w%mo; } } } std::reverse(f+1,f+x);int invx=mo-(mo-1)/x; for(i=f;i!=f+x;i++) *i=1ll**i*invx%mo; } IL void getcj(int *A,int *B,int len) { m=len*2; init(); for (int i=0;i<=len;i++) a[i]=(A[i]%mo+mo)%mo,b[i]=(B[i]%mo+mo)%mo; DIF(a,n); DIF(b,n); for(int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mo; DIT(a,n); for (int i=0;i<=2*len;i++) B[i]=a[i]; clear(); } }F; int pos[N]; vector<int> ve[N]; #define mid ((h+t)/2) void build(int x,int h,int t) { if (h==t){ pos[h]=x; return;} build(x*2,h,mid); build(x*2+1,mid+1,t); } void gg(int x,int h,int t) { if (h==t) return; gg(x*2,h,mid); gg(x*2+1,mid+1,t); int nn=max(ve[x*2].size(),ve[x*2+1].size()); int i=0,j=0; rep(i,0,nn) f[i]=g[i]=0; for (auto v:ve[x*2]) { f[i]=v; i++; } for (auto v:ve[x*2+1]) { g[j]=v; j++; } F.getcj(f,g,nn); bool tt=0; dep(i,2*nn,0) if (g[i]!=0||tt) { tt=1; ve[x].push_back(g[i]); } reverse(ve[x].begin(),ve[x].end()); } int main() { ios::sync_with_stdio(false); build(1,1,n); rep(i,1,n) { ve[pos[i]].push_back(); } gg(1,1,n); return 0; }
代码:
#include <bits/stdc++.h> using namespace std; #define rep(i,h,t) for (int i=h;i<=t;i++) #define dep(i,t,h) for (int i=t;i>=h;i--) #define ll long long #define me(x) memset(x,0,sizeof(x)) #define IL inline #define rint register int inline ll rd(){ ll x=0;char c=getchar();bool f=0; while(!isdigit(c)){if(c=='-')f=1;c=getchar();} while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();} return f?-x:x; } char ss[1<<24],*A=ss,*B=ss; IL char gc() { return A==B&&(B=(A=ss)+fread(ss,1,1<<24,stdin),A==B)?EOF:*A++; } template<class T>void maxa(T &x,T y) { if (y>x) x=y; } template<class T>void mina(T &x,T y) { if (y<x) x=y; } template<class T>void read(T &x) { int f=1,c; while (c=gc(),c<48||c>57) if (c=='-') f=-1; x=(c^48); while(c=gc(),c>47&&c<58) x=x*10+(c^48); x*=f; } const int mo=998244353; ll fsp(int x,int y) { if (y==1) return x; ll ans=fsp(x,y/2); ans=ans*ans%mo; if (y%2==1) ans=ans*x%mo; return ans; } struct cp { ll x,y; cp operator +(cp B) { return (cp){x+B.x,y+B.y}; } cp operator -(cp B) { return (cp){x-B.x,y-B.y}; } ll operator *(cp B) { return x*B.y-y*B.x; } int half() { return y < 0 || (y == 0 && x < 0); } }; struct re{ int a,b,c; }; const int N=6e5; const int G=3; int f[N],g[N],n; struct fft{ int l,n,m; int r[N],a[N],b[N],w[N],inv[N]; int C[N],D[N]; fft() { inv[0]=inv[1]=1; rep(i,2,N-1) inv[i]=(1ll*inv[mo%i]*(mo-(mo/i)))%mo; } IL void ntt_init() { l=0; for (n=1;n<=m;n<<=1) l++; for (int i=0;i<n;i++) r[i]=(r[i/2]/2)|((i&1)<<(l-1)); } IL void clear() { rep(i,0,n) a[i]=b[i]=0; } void ntt(int *a,int o) { for (int i=0;i<n;i++) if (i>r[i]) swap(a[i],a[r[i]]); for (int i=1;i<n;i<<=1) { int wn=fsp(G,(mo-1)/(i*2)); w[0]=1; rep(j,1,i-1) w[j]=(1ll*w[j-1]*wn)%mo; for (int j=0;j<n;j+=(i*2)) for (int k=0;k<i;k++) { int x=a[j+k],y=1ll*a[i+j+k]*w[k]%mo; // if (x<0||y<0) cerr<<x<<" "<<y<<endl; a[j+k]=x+y>mo?x+y-mo:x+y; a[i+j+k]=x-y>=0?x-y:x-y+mo; // a[j+k]=(x+y)%mo; // a[i+j+k]=(x-y)%mo; } } if (o==-1) { reverse(&a[1],&a[n]); for (int i=0,inv=fsp(n,mo-2);i<n;i++) a[i]=1ll*a[i]*inv%mo; } } IL void getcj(int *C,int len) { // for (int i=0;i<len;i++) a[i]=(A[i]%mo+mo)%mo,b[i]=(B[i]%mo+mo)%mo; m=len*2; ntt_init(); rep(i,0,n) a[i]=(a[i]+mo)%mo,b[i]=(b[i]+mo)%mo; ntt(a,1); ntt(b,1); for (int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mo; ntt(a,-1); for (int i=0;i<n;i++) C[i]=a[i]; clear(); } IL void getcj(int *A,int *B,int len) { m=len*2; ntt_init(); for (int i=0;i<len;i++) a[i]=(A[i]%mo+mo)%mo,b[i]=(B[i]%mo+mo)%mo; ntt(a,1); ntt(b,1); for(int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mo; ntt(a,-1); for (int i=0;i<len;i++) B[i]=a[i]; clear(); } IL void getinv(int *A,int *B,int len) { if (len==1) { B[0]=fsp(A[0],mo-2); return; } getinv(A,B,(len+1)>>1); m=len*2; ntt_init(); for (int i=0;i<len;i++) a[i]=A[i],b[i]=B[i]; ntt(a,1); ntt(b,1); for (int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mo*b[i]%mo; ntt(a,-1); for (int i=0;i<len;i++) B[i]=((2*B[i]-a[i])%mo+mo)%mo; clear(); } IL void getsqrt(int *A,int *B,int len) { int inv2=fsp(2,mo-2); if (len==1) {B[0]=sqrt(A[0]); return;} getsqrt(A,B,(len+1)>>1); int C[N]={}; getinv(B,C,len); getcj(A,C,len); for (int i=0;i<len;i++) B[i]=1ll*(C[i]+B[i])%mo*inv2%mo; } IL void getDao(int *a,int *b,int len) { for (int i=1;i<len;i++) b[i-1]=1ll*i*a[i]%mo; b[len-1]=0; } IL void getjf(int *a,int *b,int len) { for (int i=0;i<len;i++) b[i+1]=1ll*a[i]*inv[i+1]%mo; b[0]=0; } IL void getln(int *A,int *B,int len) { // me(C); me(D); getDao(A,C,len); getinv(A,D,len); getcj(C,D,len); getjf(D,B,len); rep(i,0,len) C[i]=0,D[i]=0; } IL void getexp(int *A,int *B,int len) { if (len==1) {B[0]=1; return;} getexp(A,B,(len+1)>>1); int C[N]; getln(B,C,len); for(int i=0;i<len;i++) C[i]=((-C[i]+A[i])%mo+mo)%mo; C[0]=(C[0]+1)%mo; getcj(C,B,len); } }F; /* f[i]=\sum f[j]*g[i-j]; */ /* int now[N]; void solve(int h,int t) { if (h>=t) return; if (t-h<=32) { rep(i,h,t) rep(j,h,i) f[i]=(f[i]+1ll*f[j]*g[i-j])%mo; return; } int mid=(h+t)/2; solve(h,mid); rep(i,h,mid) F.a[i-h]=f[i]; rep(i,1,t-h) F.b[i]=g[i]; F.getcj(now,(t-h+1)+(mid-h+1)); rep(i,mid+1,t) f[i]=(f[i]+now[i-h])%mo; solve(mid+1,t); } */ int sum[N],now[N],a[N],b[N],c[N],d[N],e[N]; ll jc[N],jc2[N]; /* \prod (1+a[i]x) */ void solve(int h,int t,int *a) { if (h==t) return; int mid=(h+t)/2; solve(h,mid,a); solve(mid+1,t,a); rep(i,h,mid) F.a[i-h+1]=a[i]; rep(i,mid+1,t) F.b[i-mid]=a[i]; F.a[0]=F.b[0]=1; F.getcj(now,(mid-h+2)); rep(i,h,t) a[i]=now[i-h+1]; } int sum3[N],sum4[N]; int main() { ios::sync_with_stdio(false); int n,m; cin>>n>>m; ll ans=1; rep(i,1,n) cin>>a[i],ans=ans*a[i]%mo; rep(i,1,n) sum[i]=(-a[i]+mo)%mo; solve(1,n,sum); sum[0]=1; F.getln(sum,sum,n+2); F.getDao(sum,sum,n+2); dep(i,n,1) sum[i]=((mo-sum[i-1])%mo+mo)%mo; sum[0]=n; jc[0]=jc2[0]=1; rep(i,1,n) jc[i]=jc[i-1]*i%mo; jc2[n]=fsp(jc[n],mo-2); dep(i,n-1,1) jc2[i]=jc2[i+1]*(i+1)%mo; rep(i,0,n) a[i]=b[i]=c[i]=0; rep(i,0,n-1) a[i]=1ll*fsp(i+1,2*m)*jc2[i]%mo; rep(i,0,n-1) c[i]=b[i]=1ll*fsp(i+1,m)*jc2[i]%mo; F.getln(c,e,n+1); rep(i,0,n) e[i]=1ll*e[i]*sum[i]%mo; F.getexp(e,c,n+1); F.getinv(b,d,n+1); F.getcj(a,d,n+1); rep(i,0,n) d[i]=1ll*d[i]*sum[i]%mo; F.getcj(c,d,n+1); ans=ans*d[n-2]%mo*jc[n-2]%mo; cout<<ans<<endl; return 0; }
下面这个分治fft长度短会快一点 不懂原理
#include <bits/stdc++.h> using namespace std; #define rep(i,h,t) for (int i=h;i<=t;i++) #define dep(i,t,h) for (int i=t;i>=h;i--) #define ll long long #define me(x) memset(x,0,sizeof(x)) #define IL inline #define rint register int inline ll rd(){ ll x=0;char c=getchar();bool f=0; while(!isdigit(c)){if(c=='-')f=1;c=getchar();} while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();} return f?-x:x; } char ss[1<<24],*A=ss,*B=ss; IL char gc() { return A==B&&(B=(A=ss)+fread(ss,1,1<<24,stdin),A==B)?EOF:*A++; } template<class T>void maxa(T &x,T y) { if (y>x) x=y; } template<class T>void mina(T &x,T y) { if (y<x) x=y; } template<class T>void read(T &x) { int f=1,c; while (c=gc(),c<48||c>57) if (c=='-') f=-1; x=(c^48); while(c=gc(),c>47&&c<58) x=x*10+(c^48); x*=f; } const int mo=998244353; ll fsp(int x,int y) { if (y==1) return x; ll ans=fsp(x,y/2); ans=ans*ans%mo; if (y%2==1) ans=ans*x%mo; return ans; } struct cp { ll x,y; cp operator +(cp B) { return (cp){x+B.x,y+B.y}; } cp operator -(cp B) { return (cp){x-B.x,y-B.y}; } ll operator *(cp B) { return x*B.y-y*B.x; } int half() { return y < 0 || (y == 0 && x < 0); } }; struct re{ int a,b,c; }; const int N=6e5; const int G=3; int f[N],g[N],n; struct fft{ int l,n,m; int r[N],a[N],b[N],w[N],inv[N]; int C[N],D[N]; fft() { inv[0]=inv[1]=1; rep(i,2,N-1) inv[i]=(1ll*inv[mo%i]*(mo-(mo/i)))%mo; } IL void ntt_init() { l=0; for (n=1;n<=m;n<<=1) l++; for (int i=0;i<n;i++) r[i]=(r[i/2]/2)|((i&1)<<(l-1)); } IL void clear() { rep(i,0,n) a[i]=b[i]=0; } void ntt(int *a,int o) { for (int i=0;i<n;i++) if (i>r[i]) swap(a[i],a[r[i]]); for (int i=1;i<n;i<<=1) { int wn=fsp(G,(mo-1)/(i*2)); w[0]=1; rep(j,1,i-1) w[j]=(1ll*w[j-1]*wn)%mo; for (int j=0;j<n;j+=(i*2)) for (int k=0;k<i;k++) { int x=a[j+k],y=1ll*a[i+j+k]*w[k]%mo; // if (x<0||y<0) cerr<<x<<" "<<y<<endl; a[j+k]=x+y>mo?x+y-mo:x+y; a[i+j+k]=x-y>=0?x-y:x-y+mo; // a[j+k]=(x+y)%mo; // a[i+j+k]=(x-y)%mo; } } if (o==-1) { reverse(&a[1],&a[n]); for (int i=0,inv=fsp(n,mo-2);i<n;i++) a[i]=1ll*a[i]*inv%mo; } } IL void getcj(int *C,int len) { // for (int i=0;i<len;i++) a[i]=(A[i]%mo+mo)%mo,b[i]=(B[i]%mo+mo)%mo; m=len*2; ntt_init(); rep(i,0,n) a[i]=(a[i]+mo)%mo,b[i]=(b[i]+mo)%mo; ntt(a,1); ntt(b,1); for (int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mo; ntt(a,-1); for (int i=0;i<n;i++) C[i]=a[i]; clear(); } IL void getcj(int *A,int *B,int len) { m=len*2; ntt_init(); for (int i=0;i<len;i++) a[i]=(A[i]%mo+mo)%mo,b[i]=(B[i]%mo+mo)%mo; ntt(a,1); ntt(b,1); for(int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mo; ntt(a,-1); for (int i=0;i<len;i++) B[i]=a[i]; clear(); } IL void getinv(int *A,int *B,int len) { if (len==1) { B[0]=fsp(A[0],mo-2); return; } getinv(A,B,(len+1)>>1); m=len*2; ntt_init(); for (int i=0;i<len;i++) a[i]=A[i],b[i]=B[i]; ntt(a,1); ntt(b,1); for (int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mo*b[i]%mo; ntt(a,-1); for (int i=0;i<len;i++) B[i]=((2*B[i]-a[i])%mo+mo)%mo; clear(); } IL void getsqrt(int *A,int *B,int len) { int inv2=fsp(2,mo-2); if (len==1) {B[0]=sqrt(A[0]); return;} getsqrt(A,B,(len+1)>>1); int C[N]={}; getinv(B,C,len); getcj(A,C,len); for (int i=0;i<len;i++) B[i]=1ll*(C[i]+B[i])%mo*inv2%mo; } IL void getDao(int *a,int *b,int len) { for (int i=1;i<len;i++) b[i-1]=1ll*i*a[i]%mo; b[len-1]=0; } IL void getjf(int *a,int *b,int len) { for (int i=0;i<len;i++) b[i+1]=1ll*a[i]*inv[i+1]%mo; b[0]=0; } IL void getln(int *A,int *B,int len) { // me(C); me(D); getDao(A,C,len); getinv(A,D,len); getcj(C,D,len); getjf(D,B,len); rep(i,0,len) C[i]=0,D[i]=0; } IL void getexp(int *A,int *B,int len) { if (len==1) {B[0]=1; return;} getexp(A,B,(len+1)>>1); int C[N]; getln(B,C,len); for(int i=0;i<len;i++) C[i]=((-C[i]+A[i])%mo+mo)%mo; C[0]=(C[0]+1)%mo; getcj(C,B,len); } }F; /* f[i]=\sum f[j]*g[i-j]; */ /* int now[N]; void solve(int h,int t) { if (h>=t) return; if (t-h<=32) { rep(i,h,t) rep(j,h,i) f[i]=(f[i]+1ll*f[j]*g[i-j])%mo; return; } int mid=(h+t)/2; solve(h,mid); rep(i,h,mid) F.a[i-h]=f[i]; rep(i,1,t-h) F.b[i]=g[i]; F.getcj(now,(t-h+1)); rep(i,mid+1,t) f[i]=(f[i]+now[i-h])%mo; solve(mid+1,t); } */ int sum[N],now[N],a[N],b[N],c[N],d[N],e[N]; ll jc[N],jc2[N]; /* \prod (1+a[i]x) */ void solve(int h,int t,int *a) { if (h==t) return; int mid=(h+t)/2; solve(h,mid,a); solve(mid+1,t,a); rep(i,h,mid) F.a[i-h+1]=a[i]; rep(i,mid+1,t) F.b[i-mid]=a[i]; F.a[0]=F.b[0]=1; F.getcj(now,(mid-h+2)); rep(i,h,t) a[i]=now[i-h+1]; } int sum3[N],sum4[N]; int main() { ios::sync_with_stdio(false); int n,m; cin>>n>>m; ll ans=1; rep(i,1,n) cin>>a[i],ans=ans*a[i]%mo; rep(i,1,n) sum[i]=(-a[i]+mo)%mo; solve(1,n,sum); sum[0]=1; F.getln(sum,sum,n+2); F.getDao(sum,sum,n+2); dep(i,n,1) sum[i]=((mo-sum[i-1])%mo+mo)%mo; sum[0]=n; jc[0]=jc2[0]=1; rep(i,1,n) jc[i]=jc[i-1]*i%mo; jc2[n]=fsp(jc[n],mo-2); dep(i,n-1,1) jc2[i]=jc2[i+1]*(i+1)%mo; rep(i,0,n) a[i]=b[i]=c[i]=0; rep(i,0,n-1) a[i]=1ll*fsp(i+1,2*m)*jc2[i]%mo; rep(i,0,n-1) c[i]=b[i]=1ll*fsp(i+1,m)*jc2[i]%mo; F.getln(c,e,n+1); rep(i,0,n) e[i]=1ll*e[i]*sum[i]%mo; F.getexp(e,c,n+1); F.getinv(b,d,n+1); F.getcj(a,d,n+1); rep(i,0,n) d[i]=1ll*d[i]*sum[i]%mo; F.getcj(c,d,n+1); ans=ans*d[n-2]%mo*jc[n-2]%mo; cout<<ans<<endl; return 0; }