UOJ #449. 【集训队作业2018】喂鸽子

UOJ #449. 【集训队作业2018】喂鸽子

小Z是养鸽子的人。一天,小Z给鸽子们喂玉米吃。一共有n只鸽子,小Z每秒会等概率选择一只鸽子并给他一粒玉米。一只鸽子饱了当且仅当它吃了的玉米粒数量\(≥k\)。 小Z想要你告诉他,期望多少秒之后所有的鸽子都饱了。

假设答案的最简分数形式为\(\frac{a}{b}\),你需要求出\(w\),满足\(a≡b⋅w \pmod{998244353}(0≤w<998244353).\)

\(n\leq 50,k\leq 1000\)

Orz

首先可以用\(\min-\max\)反演来解决:

因为\(k\)是固定的,所以每个集合中至少有一个鸽子被喂饱的期望只与集合大小有关。

\[ans=\sum_{i=1}^n(-1)^{i+1}\binom{n}{i}g_i \]

其中\(g_c\)就是至少喂饱\(c\)只鸽子中的一只的期望步数。

我们将期望转成概率:

\[\begin{align} g_c&=\sum_{i\geq 1}i*P(x=i)\\ &=\sum_{i\geq 1}P(x\geq i)\\ \end{align} \]

\(f_{c,s}\)表示给\(c\)只鸽子喂食,喂了\(s\)次还没有将任意一只鸽子喂饱的概率。

所以:

\[\begin{align} g_c&=\sum_{i\geq 1}\sum_{s=0}^{i-1}\binom{i-1}{s}f_{c,s}(\frac{n-c}{n})^{i-1-s}\\ &=\sum_{s=0}^{c(k-1)}f_{c,s}\sum_{t\geq 0}\binom{s+t}{s}(\frac{n-c}{n})^t \end{align} \]

我们知道:

\[(\frac{1}{1-x})^k=\sum_{i\geq 0}\binom{i+k-1}{k-1} x^i \]

所以:

\[\begin{align} \sum_{t\geq 0}\binom{s+t}{t}(\frac{n-c}{n})^t&=(\frac{1}{1-\frac{n-c}{n}})^{s+1}\\ &=(\frac{n}{c})^{s+1} \end{align} \]

所以

\[g_c=\sum_{s=0}^{c(k-1)}f_{c,s}(\frac{n}{c})^{s+1} \]

接着考虑求\(f\)数组。

方法就是新加进来一只鸽子就枚举给这只鸽子喂了多少次食物。

\[f_{c,s}=\sum_{i=0}^{\min(s,k-1)}\binom{s}{i}\frac{1}{n^i}f_{c-1,s-i}\\ \frac{f_{c,s}}{s!}= \sum_{i=0}^{\min(s,k-1)} \frac{1}{n^ii!} \frac{ f_{c-1,s-i}}{(s-i!)} \\ \]

于是就可以用\(NTT\)算出\(\frac{f_{c,s}}{s!}\)的值了。

复杂度\(O(n^2klog(k))\)

还有个\(O(n^2k)\)的算法就先咕着吧。

代码:

#include<bits/stdc++.h>
#define ll long long
#define N 55
#define K 1005

using namespace std;
inline int Get() {int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9') {if(ch=='-') f=-1;ch=getchar();}while('0'<=ch&&ch<='9') {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}return x*f;}

const ll mod=998244353;

ll ksm(ll t,ll x) {
	ll ans=1;
	for(;x;x>>=1,t=t*t%mod)
		if(x&1) ans=ans*t%mod;
	return ans;
}

int n,k,m;
int f[N][N*K];
ll fac[N*K],ifac[N*K];
ll C(int n,int m) {return fac[n]*ifac[m]%mod*ifac[n-m]%mod;}

void NTT(ll *a,int d,int flag) {
	static int rev[N*K<<2];
	static ll G=3;
	int n=1<<d;
	for(int i=0;i<n;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<d-1);
	for(int i=0;i<n;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
	for(int s=1;s<=d;s++) {
		int len=1<<s,mid=len>>1;
		ll w=flag==1?ksm(G,(mod-1)/len):ksm(G,mod-1-(mod-1)/len);
		for(int i=0;i<n;i+=len) {
			ll t=1;
			for(int j=0;j<mid;j++,t=t*w%mod) {
				ll u=a[i+j],v=a[i+j+mid]*t%mod;
				a[i+j]=(u+v)%mod;
				a[i+j+mid]=(u-v+mod)%mod;
			}
		}
	}
	if(flag==-1) {
		ll inv=ksm(n,mod-2);
		for(int i=0;i<n;i++) a[i]=a[i]*inv%mod;
	}
}

ll A[N*K<<2];
ll B[N*K<<2];
ll g[N];

int main() {
	n=Get(),k=Get();
	m=n*k;
	int d=ceil(log2(m+1));
	fac[0]=1;
	for(int i=1;i<=m;i++) fac[i]=fac[i-1]*i%mod;
	ifac[m]=ksm(fac[m],mod-2);
	for(int i=m-1;i>=0;i--) ifac[i]=ifac[i+1]*(i+1)%mod;
	ll invn=ksm(n,mod-2);
	for(int i=0;i<k;i++) {
		A[i]=ksm(invn,i)*ifac[i]%mod;
	}
	
	NTT(A,d,1);
	for(int i=0;i<k;i++) f[1][i]=1;
	
	for(int i=0;i<k;i++) f[1][i]=ksm(invn,i);
	for(int i=2;i<=n;i++) {
		for(int j=0;j<1<<d;j++) B[j]=0;
		for(int j=0;j<=i*(k-1);j++) B[j]=f[i-1][j]*ifac[j];
		NTT(B,d,1);
		for(int j=0;j<1<<d;j++) B[j]=B[j]*A[j]%mod;
		NTT(B,d,-1);
		for(int j=0;j<=i*(k-1);j++) f[i][j]=B[j]*fac[j]%mod;
	}
	
	for(int i=1;i<=n;i++) {
		ll w=ksm(i,mod-2)*n%mod;
		ll t=w;
		for(int s=0;s<=i*(k-1);s++) {
			(g[i]+=f[i][s]*t)%=mod;
			t=t*w%mod;
		}
	}
	ll ans=0;
	ll flag=1;
	for(int c=1;c<=n;c++,flag=flag*(mod-1)%mod) {
		(ans+=flag*C(n,c)%mod*g[c])%=mod;
	}
	cout<<ans;
	return 0;
}
posted @ 2019-04-21 15:11  hec0411  阅读(522)  评论(0编辑  收藏  举报