UOJ #449. 【集训队作业2018】喂鸽子
题解
warning:式子全都抄的题解。
我们可以先套一层\(\min-\max\)反演。
\[ans=\sum_{i=1}^n (-1)^{i-1}\binom{n}{i}g_i
\]
那么\(g_i\)就表示喂饱\(i\)只鸽子中至少一只的期望步数。
\[g_i=\sum_{i\geq 1}i*P(x=i)
\]
\[=\sum_{i\geq 1}P(x\geq i)
\]
然后考虑设计一个\(dp\),设\(f(sum,cnt)\)表示喂\(sum\)只鸽子,喂了\(cnt\)次,都没有喂饱的概率。
\[g_i=\sum_{j\geq 1}\sum_{s=0}^{i-1}\binom{i-1}{s}f(i,s)(\frac{n-i}{n}) ^{i-1-s}
\]
考虑枚举有一次喂食喂到了\(i\)只鸽子中,根据鸽巢原理,
\[g_i=\sum_{s=0}^{i(k-1)}f(i,s)\sum_{j \geq 0}\binom{s+j}{s}(\frac{n-i}{n})^j
\]
有一个不知道为什么的东西:
\[(\frac{1}{1-x})^k=\sum_{i\geq 0}\binom{i+k-1}{k-1}x^i
\]
那么:
\[\sum_{j\geq 0}\binom{s+t}{t}(\frac{n-c}{n})^t=(\frac{1}{1-\frac{n-c}{n}})^{s+1}=(\frac{n}{c})^{s+1}
\]
\[g_i=\sum_{s=0}^{i(k-1)}f(i,s)(\frac{n}{c})^{s+1}
\]
\[f(c,s)=\sum_{i=0}^{min(s,k-1)}\binom{s}{i}\frac{1}{n^i}f(c-1,s-i)
\]
\[\frac{f(c,s)}{s!}=\sum_{i=0}^{min(s,k-1)}\frac{1}{n^ii!}\frac{f(c-1,s-i)}{(s-i)!}
\]
然后就可以\(NTT\)算了。
代码
#include<bits/stdc++.h>
#define N 52
#define K 1002
#define M 68002
using namespace std;
typedef long long ll;
int n,k,rev[M];
ll dp[N][M],inv[M],jie[M],ni[M],ans,g[N];
const int G=3;
const int Gi=332748118;
const int mod=998244353;
inline ll rd(){
ll x=0;char c=getchar();bool f=0;
while(!isdigit(c)){if(c=='-')f=1;c=getchar();}
while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
return f?-x:x;
}
inline ll power(ll x,ll y){
ll ans=1;
while(y){
if(y&1)ans=ans*x%mod;
x=x*x%mod;
y>>=1;
}
return ans;
}
inline void MOD(ll &x){x=x>=mod?x-mod:x;}
inline ll C(int n,int m){return jie[n]*ni[m]%mod*ni[n-m]%mod;}
inline void NTT(ll *a,int l,int tag){
for(int i=1;i<l;++i)if(i>rev[i])swap(a[i],a[rev[i]]);
for(int i=1;i<l;i<<=1){
ll wn=power(tag?G:Gi,(mod-1)/(i<<1));
for(int j=0;j<l;j+=(i<<1)){
ll w=1;
for(int k=0;k<i;++k,w=w*wn%mod){
ll x=a[j+k],y=a[i+j+k]*w%mod;
MOD(a[j+k]=x+y);MOD(a[i+j+k]=x-y+mod);
}
}
}
if(!tag){
ll ny=power(l,mod-2);
for(int i=0;i<l;++i)a[i]=a[i]*ny%mod;
}
}
inline void prework(int n){
jie[0]=1;
for(int i=1;i<=n;++i)jie[i]=jie[i-1]*i%mod;
ni[n]=power(jie[n],mod-2);
for(int i=n-1;i>=0;--i)ni[i]=ni[i+1]*(i+1)%mod;
}
int main(){
n=rd();k=rd();
prework(n*k);
for(int i=0;i<k;++i)inv[i]=power(power(n,i),mod-2)*ni[i]%mod;
int maxn=n*(k-1);
dp[0][0]=1;
int l=1,L=0;
while(l<=maxn)l<<=1,L++;
for(int i=1;i<l;++i)rev[i]=rev[i>>1]>>1|((i&1)<<(L-1));
NTT(dp[0],l,1);NTT(inv,l,1);
for(int i=1;i<=n;++i){
for(int j=0;j<l;++j)dp[i][j]=dp[i-1][j]*inv[j]%mod;
}
for(int i=1;i<=n;++i){
NTT(dp[i],l,0);
int x=i*(k-1);
ll nii=1ll*n*power(i,mod-2)%mod,num=1;
for(int j=0;j<=x;++j){
dp[i][j]=dp[i][j]*jie[j]%mod;
num=num*nii%mod;
MOD(g[i]+=dp[i][j]*num%mod);
}
}
for(int i=1;i<=n;++i){
if(i&1)MOD(ans+=C(n,i)*g[i]%mod);
else MOD(ans=ans-C(n,i)*g[i]%mod+mod);
}
cout<<ans;
return 0;
}