[LOJ#3120][Luogu5401][CTS2019]珍珠(容斥+生成函数)
https://www.luogu.org/blog/user50971/solution-p5401
// CTS2019 "Pearl" - inclusion-exclusion + exponential generating functions.
//
// Reads D, n, m and prints (mod 998244353) how many of the D^n colour
// sequences of length n have at most n-2m colours occurring an odd number of
// times (the condition the problem reduces "m same-colour pairs" to).
//
//   f[k] = C(D,k) * n! [x^n] ((e^x - e^-x)/2)^k * e^((D-k)x)
//        = C(D,k) * 2^-k * sum_{i=0..k} C(k,i) * (-1)^i * (D-2i)^n
//          -> each sequence counted once per k-subset of its odd colours.
//   g[k] = sum_{i>=k} (-1)^(i-k) * C(i,k) * f[i]   (binomial inversion)
//          -> sequences with exactly k odd colours.
//   answer = g[0] + g[1] + ... + g[n-2m].
// Both inner sums are evaluated as polynomial products with an NTT.
#include <cstdio>
#include <algorithm>

const int N = 400010, mod = 998244353, i2 = 499122177;  // i2 = 2^-1 mod p
// fac: factorials, inv: INVERSE FACTORIALS, ip2: powers of 1/2 (all mod p);
// a, b: NTT work buffers; rev: bit-reversal permutation for the NTT length.
int D, n, m, ans, fac[N], inv[N], ip2[N], f[N], g[N], rev[N], a[N], b[N];

// Returns base^e mod p for e >= 0. base may be any non-negative int; it is
// reduced by the first multiplication (products are taken in 64 bits).
int ksm(int base, int e) {
    int res = 1;
    for (; e; base = 1ll * base * base % mod, e >>= 1)
        if (e & 1) res = 1ll * res * base % mod;
    return res;
}

// Binomial coefficient C(n, m) mod p using fac[]/inv[]; 0 when n < m.
// Requires fac[]/inv[] to be filled up to index n.
int C(int n, int m) { return n < m ? 0 : 1ll * fac[n] * inv[m] % mod * inv[n - m] % mod; }

// In-place iterative NTT of poly[0..len) with primitive root 3.
// fwd == true: forward transform; fwd == false: inverse transform, including
// the division by len. len must be a power of two, and rev[] must already
// hold the bit-reversal permutation for len.
void NTT(int *poly, int len, bool fwd) {
    for (int i = 0; i < len; i++)
        if (i < rev[i]) std::swap(poly[i], poly[rev[i]]);
    for (int half = 1; half < len; half <<= 1) {
        const int step = (mod - 1) / (half << 1);
        const int wn = ksm(3, fwd ? step : (mod - 1) - step);
        for (int j = 0; j < len; j += half << 1) {
            int w = 1;
            for (int k = 0; k < half; k++, w = 1ll * w * wn % mod) {
                const int x = poly[j + k];
                const int y = 1ll * w * poly[half + j + k] % mod;
                poly[j + k] = (x + y) % mod;
                poly[half + j + k] = (x - y + mod) % mod;
            }
        }
    }
    if (fwd) return;
    const int inv_len = ksm(len, mod - 2);  // named so it does not hide the global inv[]
    for (int i = 0; i < len; i++) poly[i] = 1ll * poly[i] * inv_len % mod;
}

// Cyclic convolution of the global buffers a[0..len) and b[0..len).
// The result is left in a[]; b[] is overwritten with its transform.
// len must exceed the degree of the product to avoid wrap-around.
static void convolve(int len) {
    NTT(a, len, true);
    NTT(b, len, true);
    for (int i = 0; i < len; i++) a[i] = 1ll * a[i] * b[i] % mod;
    NTT(a, len, false);
}

int main() {
    freopen("pearl.in", "r", stdin);
    freopen("pearl.out", "w", stdout);
    if (scanf("%d%d%d", &D, &n, &m) != 3) return 0;

    // Maximum allowed number of odd-count colours. Computed in 64 bits: with
    // n, m up to 1e9, the int expression 2*m overflows (undefined behaviour).
    const long long slack = n - 2LL * m;
    if (slack < 0) { puts("0"); return 0; }                      // not enough pearls
    if (slack >= D) { printf("%d\n", ksm(D, n)); return 0; }     // every sequence works

    // From here 0 <= slack < D, so D >= 1.
    fac[0] = ip2[0] = 1;
    for (int i = 1; i <= D; i++) {
        fac[i] = 1ll * fac[i - 1] * i % mod;
        ip2[i] = 1ll * ip2[i - 1] * i2 % mod;
    }
    inv[D] = ksm(fac[D], mod - 2);
    for (int i = D; i; i--) inv[i - 1] = 1ll * inv[i] * i % mod;

    // Smallest power of two strictly greater than 2D (the product degree).
    int len = 1, bits = 0;
    while (len <= 2 * D) len <<= 1, bits++;
    for (int i = 0; i < len; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bits - 1));

    // Step 1: a[i] = (-1)^i (D-2i)^n / i!, b[i] = 1/i!; the product's k-th
    // coefficient is sum_i C(k,i) (-1)^i (D-2i)^n / k!.
    for (int i = 0, sign = 1; i <= D; i++, sign = mod - sign) {
        const int base = ((D - 2 * i) % mod + mod) % mod;  // D - 2i may be negative
        a[i] = 1ll * sign * ksm(base, n) % mod * inv[i] % mod;
        b[i] = inv[i];
    }
    convolve(len);
    for (int k = 0; k <= D; k++)
        f[k] = 1ll * ip2[k] * fac[k] % mod * C(D, k) % mod * a[k] % mod;
    std::fill(a, a + len, 0);
    std::fill(b, b + len, 0);

    // Step 2 (binomial inversion): a[i] = f[i] * i!, b[j] = (-1)^(D-j) / (D-j)!.
    // Coefficient D+k of the product is k! * g[k].
    for (int i = 0, sign = (D & 1) ? mod - 1 : 1; i <= D; i++, sign = mod - sign) {
        a[i] = 1ll * f[i] * fac[i] % mod;
        b[i] = 1ll * sign * inv[D - i] % mod;
    }
    convolve(len);
    for (int k = 0; k <= D; k++) g[k] = 1ll * a[D + k] * inv[k] % mod;

    for (int k = 0; k <= slack; k++) ans = (ans + g[k]) % mod;
    printf("%d\n", ans);
    return 0;
}