【题解】2021牛客暑期多校第四场 G.Product
题意
给定\(n,k,D\),求下式对\(998244353\)取模的结果:
\[\sum_{a_i\ge0,\sum_{i=1}^n a_i=D}\frac{D!}{\prod_{i=1}^n(a_i+k)!}
\]
\(1\le n\le 50,0\le k\le 50,0\le D\le 10^8\)
题解
由于\(e^x=\sum_{j=0}^{\infty}\frac{x^j}{j!}\)(这里用\(j\)作求和下标,以免与题目中的\(k\)混淆),所以有
\[\sum_{a_i\ge0,\sum_{i=1}^n a_i=D}\frac{1}{\prod_{i=1}^na_i!}=[x^D](e^x)^n=[x^D]e^{nx}
\]
设\(b_i=a_i+k\),则原问题化为求
\[\sum_{b_i\ge k,\sum_{i=1}^n b_i=D+nk}\frac{D!}{\prod_{i=1}^nb_i!}
\]
考虑\(b_i\ge k\)的限制,只需将每个因子\(e^x\)展开式中次数\(<k\)的项去掉,所以
\[\sum_{b_i\ge k,\sum_{i=1}^n b_i=D+nk}\frac{1}{\prod_{i=1}^nb_i!}=[x^{D+nk}](e^x-\sum_{i=0}^{k-1}\frac{x^i}{i!})^n
\]
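由于\(n,k\)很小,上式可以直接暴力计算。记\(A(x)=-\sum_{i=0}^{k-1}\frac{x^i}{i!}\),预处理\(A(x),A^2(x),\dots,A^n(x)\)(其中\(A^i(x)\)的次数为\(i(k-1)\)),再根据\((e^x+A(x))^n=\sum_{i=0}^{n}\binom{n}{i}A^i(x)e^{(n-i)x}\)以及\([x^{D+nk}]A^i(x)e^{(n-i)x}=\sum_{j=0}^{i(k-1)}[x^j]A^i(x)\cdot\frac{(n-i)^{D+nk-j}}{(D+nk-j)!}\)即可算出答案。计算过程中出现的\(\frac{1}{(D+nk-j)!},j\in[0,i(k-1)]\)无法直接快速计算(\(D\)可达\(10^8\)),但由于整体乘上了\(D!\),且\(D+nk-j\ge D\),\(\frac{D!}{(D+nk-j)!}\)就是至多\(nk\)个连续整数之积的逆元,可以快速计算。最终复杂度\(O(n^2k^2)\)。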
#include <bits/stdc++.h>
#define pb(x) emplace_back(x)  // not used in the code shown
using namespace std;
// Coefficient capacity of a polynomial: with n, k <= 50, A(x)^n has degree
// n*(k-1) <= 2450.
const int N=2510;
using ll=long long;
const ll M=998244353;  // prime modulus
int n;
// s[i] = i! and sv[i] = 1/i! (mod M), filled in f1; D is the input value and
// D2 = D + n*k.
ll s[N],sv[N],D,D2;

// Reduces a non-negative x into [0, M) in place (no-op when already reduced).
inline void MOD(ll&x){if(x>=M)x%=M;}

// x^b mod M by binary exponentiation; expects x >= 0 and b >= 0.
// The base is reduced first so x*x stays below 2^63 even for bases >= 2^32
// (the unreduced square would be signed overflow).
ll pm(ll x,ll b){
    x%=M;
    ll res=1;
    while(b){
        if(b&1)res=res*x%M;
        x=x*x%M;
        b>>=1;
    }
    return res;
}

// Modular inverse by Fermat's little theorem (M is prime).
// Note: inv(0) evaluates to 0 rather than signalling an error.
ll inv(ll x){return pm(x,M-2);}
// Dense polynomial mod M: a[i] is the coefficient of x^i for 0 <= i <= len.
// len may be negative: f1 sets ps[1].len = k-1, which is -1 when k == 0, and
// such a polynomial has no stored coefficients (it is the zero polynomial).
// Storage past len is not cleared by this struct; the global ps array below
// starts zero-initialised.
struct pol{
    int len;
    ll a[N];
    ll& operator[](size_t x){return a[x];}
    // Read-only access so a const pol can be indexed as well.
    const ll& operator[](size_t x)const{return a[x];}
}ps[52];  // ps[0] = 1 and ps[i] = ps[1]^i, filled in f1
void mul(pol& a,pol& b,pol& c){
c.len=a.len+b.len;
for(int k=0;k<=c.len;k++){
c[k]=0;
for(int i=0;i<=k;i++){
c[k]+=a[i]*b[k-i]%M;
}
MOD(c[k]);
}
}
// Binomial coefficient C(n, m) mod M from the factorial tables s (i!) and
// sv (1/i!). Returns 0 when m lies outside [0, n] instead of indexing the
// tables out of bounds. n must be within the filled table range (0..50 in f1).
ll C(ll n,ll m){
    if(m<0||m>n)return 0;
    return s[n]*sv[n-m]%M*sv[m]%M;
}
// Computes x! / y! (mod M) for x, y >= 0 without forming either factorial.
// The quotient is the product of the integers in (min(x,y), max(x,y)]. When
// x < y that product is the denominator, so it is inverted with inv().
// Cost: |x - y| modular multiplications, plus one inversion when x < y.
// NOTE: if that range contains a multiple of M, the x < y case yields inv(0) = 0.
ll cal1(ll x,ll y){
    const bool numeratorLarger=(x>=y);
    const ll lo=numeratorLarger?y:x;
    const ll hi=numeratorLarger?x:y;
    ll prod=1;
    for(ll t=hi;t>lo;--t)prod=prod*t%M;
    return numeratorLarger?prod:inv(prod);
}
// Returns x^y * D! / y! (mod M), where D is the global input value.
// 0^0 is taken as 1, so cal2(x, 0) = D! (the old y == 0 shortcut returned 1,
// which only agreed when D == 0). For x == 0 and y > 0 the power is 0, so the
// |D - y|-step factorial quotient is skipped.
ll cal2(ll x,ll y){
    if(x==0&&y>0)return 0;
    return pm(x,y)*cal1(D,y)%M;
}
void f1(){
int k;
scanf("%d%d%lld",&n,&k,&D);
D2=D+n*k;
s[0]=1;sv[0]=1;
for(int i=1;i<=50;i++){s[i]=s[i-1]*i%M;}
sv[50]=inv(s[50]);
for(int i=49;i>=1;i--){sv[i]=sv[i+1]*(i+1)%M;}
ps[0][0]=1;ps[0].len=0;
ps[1].len=k-1;
for(int i=0;i<k;i++){ps[1][i]=M-sv[i];}
for(int i=2;i<=n;i++){mul(ps[i-1],ps[1],ps[i]);}
ll ans=0,tmp=0;
for(int i=0;i<=n;i++){
tmp=0;
for(int j=0,l=min<ll>(ps[i].len,D2);j<=l;j++){
tmp+=cal2(n-i,D2-j)*ps[i][j]%M;
if(tmp>=M)tmp-=M;
}
ans+=tmp*C(n,i)%M;
if(ans>=M)ans-=M;
}
printf("%lld",ans);
}
// Program entry: the whole solution (input, computation, output) lives in f1().
int main()
{
    f1();
    return 0;
}