题解 百鸽笼

传送门

又是神仙DP,参考了这里

先把题意转化为有 \(n+1\) 个管理员要放,求每个鸽笼最后一个放满的概率
发现直接做的话每列放满之后放到其它列的概率会发生变化,很难做

  • 处理这种形如「向一些格子里放数,每个格子放满后不能再放(即一个格子放满后下一个球放到每个格子的概率会变化)」的情况:
    对于每个格子,容斥计算有其它格子比它先/后放满的概率
    具体地,可以状压枚举有哪些格子比它先放满
    对于一个固定的集合,考虑每个球的放入序列
    先假设每个格子可以放 \(a_i\) 个球
    这个序列应当包含 \(a_i\)\(i\)(第 \(i\) 个格子是被计算概率的格子),少于 \(a_j\)\(j\)(第 \(j\) 个格子在选定先放满的 \(m\) 个格子中),且以 \(i\) 结尾,且若其长度为 \(L\) ,则产生 \((\frac{1}{m+1})^L\) 的概率贡献
    于是可以背包计算
    然后这个东西还可以优化
    发现容斥的时候一种情况的系数仅与选了几个格子及序列长度有关
    于是背包的时候可以加一维记录用了几个格子
    “加一维状态表示当前选定的有关格子数量,枚举这个格子要不要被选定与如果要选定的话用掉多少个”
    于是计算一个格子复杂度是 \(O(n^3m^2)\)
    如果要计算所有格子的答案就还有一个优化
    发现复杂度瓶颈在于对剩下的 \(n-1\) 个做背包
    于是可以用退背包来优化
    对所有格子计算答案是 \(O(n^3m^2)\)
Code:
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 1010
#define ll long long
//#define int long long

char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline int read() {
	int ans=0, f=1; char c=getchar();
	while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
	return ans*f;
}

int n, m;
int a[N];
ll fac[N], inv[N], inv2[N];
const ll mod=998244353;
inline void md(ll& a, ll b) {a+=b; a=a>=mod?a-mod:a;}
inline ll qpow(ll a, ll b) {ll ans=1; for (; b; a=a*a%mod,b>>=1) if (b&1) ans=ans*a%mod; return ans;}

namespace force{
	ll inv[3629000], ans[N];
	void dfs(int u, ll pre) {
		if (u>=m) {
			for (int i=1; i<=n; ++i) if (a[i]) md(ans[i], pre);
			return ;
		}
		int cnt=0;
		for (int i=1; i<=n; ++i) if (a[i]) ++cnt;
		for (int i=1; i<=n; ++i) if (a[i]) {
			--a[i];
			dfs(u+1, pre*inv[cnt]%mod);
			++a[i];
		}
	}
	void solve() {
		inv[0]=inv[1]=1;
		for (int i=2; i<3629000; ++i) inv[i]=(mod-mod/i)*inv[mod%i]%mod;
		dfs(1, 1);
		for (int i=1; i<=n; ++i) printf("%lld%c", ans[i], " \n"[i==n]);
		exit(0);
	}
}

namespace task1{
	ll ans[50], dp[N];
	int siz[1<<11];
	void solve() {
		int lim=1<<n;
		for (int s=0; s<lim; ++s) {for (int i=0; i<n; ++i) if (s&(1<<i)) siz[s]+=a[i+1]-1; ++siz[s];}
		for (int i=1; i<=n; ++i) {
			// cout<<"i: "<<i<<endl;
			for (int s=0,s2,cnt; s<lim; ++s) if (s&(1<<(i-1))) {
				// cout<<"s: "<<bitset<5>(s)<<endl;
				s2=s; cnt=0;
				if (s) do {s2&=s2-1; ++cnt;} while (s2);
				memset(dp, 0, sizeof(dp));
				dp[a[i]]=inv2[a[i]-1];
				for (int j=0; j<n; ++j) if (s&(1<<j) && j+1!=i) {
					for (int k=siz[s]; k>=a[i]+1; --k) {
						for (int h=1; h<a[j+1]; ++h) {
							dp[k]=(dp[k]+dp[k-h]*inv2[h])%mod;
						}
					}
				}
				// cout<<"siz: "<<siz[s]<<endl;
				// cout<<"dp: "; for (int i=1; i<=siz[s]; ++i) cout<<dp[i]<<' '; cout<<endl;
				ll tem=0, p=1;
				for (int j=1; j<=siz[s]; ++j) p=p*inv[cnt]%mod, tem=(tem+dp[j]*fac[j-1]%mod*p%mod)%mod;
				// tem=(tem+dp[siz[s]]*fac[siz[s]-1]%mod*qpow(inv[cnt], siz[s])%mod)%mod;
				// cout<<"dp2: "<<dp[2]*fac[2]%mod*inv[cnt]%mod*inv[cnt]%mod<<endl;
				// cout<<"tem: "<<tem<<endl;
				ans[i]=(ans[i]+((cnt&1)?1:-1)*tem)%mod;
			}
			printf("%lld%c", (ans[i]%mod+mod)%mod, " \n"[i==n]);
		}
		exit(0);
	}
}

namespace task2{
	ll ans[50], dp[35][N];
	void solve() {
		int lim=1;
		for (int j=1; j<=n; ++j) lim+=a[j]-1;
		for (int i=1; i<=n; ++i) {
			// cout<<"i: "<<i<<endl;
			memset(dp, 0, sizeof(dp));
			dp[1][a[i]]=inv2[a[i]-1];
			for (int j=1; j<=n; ++j) if (i!=j) {
				for (int k=n; k>1; --k) {
					for (int h=lim; h>=a[i]; --h) {
						for (int t=0; t<a[j]&&t<=h; ++t) {
							dp[k][h]=(dp[k][h]+dp[k-1][h-t]*inv2[t])%mod;
						}
					}
				}
			}
			// cout<<"---dp---"<<endl;
			// for (int j=1; j<=n; ++j) for (int k=1; k<=lim; ++k) printf("dp[%d][%d]=%lld%c", j, k, dp[j][k], " \n"[k==lim]);
			for (int j=1; j<=n; ++j) {
				ll tem=0, p=1;
				for (int k=1; k<=lim; ++k) p=p*inv[j]%mod, tem=(tem+dp[j][k]*fac[k-1]%mod*p%mod)%mod;
				ans[i]=(ans[i]+((j&1)?1:-1)*tem)%mod;
			}
			printf("%lld%c", (ans[i]%mod+mod)%mod, " \n"[i==n]);
		}
		exit(0);
	}
}

namespace task3{
	ll ans[50], f[35][N], g[35][N];
	void solve() {
		int lim=1;
		for (int j=1; j<=n; ++j) lim+=a[j]-1;
		for (int i=1; i<=n; ++i) {
			// cout<<"i: "<<i<<endl;
			memset(f, 0, sizeof(f));
			memset(g, 0, sizeof(g));
			f[0][0]=1;
			for (int j=1; j<=n; ++j) if (i!=j) {
				for (int k=n; k>=1; --k) {
					for (int h=lim; ~h; --h) {
						for (int t=0; t<a[j]&&t<=h; ++t) {
							f[k][h]=(f[k][h]+f[k-1][h-t]*inv2[t])%mod;
						}
					}
				}
			}
			for (int k=n; k>=1; --k) {
				for (int h=lim; h>=a[i]; --h) {
					g[k][h]=(g[k][h]+f[k-1][h-a[i]]*inv2[a[i]-1])%mod;
				}
			}
			// cout<<"---f---"<<endl;
			// for (int j=1; j<=n; ++j) for (int k=1; k<=lim; ++k) printf("f[%d][%d]=%lld%c", j, k, f[j][k], " \n"[k==lim]);
			for (int j=1; j<=n; ++j) {
				ll tem=0, p=1;
				for (int k=1; k<=lim; ++k) p=p*inv[j]%mod, tem=(tem+g[j][k]*fac[k-1]%mod*p%mod)%mod;
				ans[i]=(ans[i]+((j&1)?1:-1)*tem)%mod;
			}
			printf("%lld%c", (ans[i]%mod+mod)%mod, " \n"[i==n]);
		}
		exit(0);
	}
}

namespace task{
	ll ans[50], f[35][N], g[35][N], s[35][N];
	void solve() {
		int lim=1;
		for (int j=1; j<=n; ++j) lim+=a[j]-1;
		s[0][0]=1;
		for (int j=1; j<=n; ++j) {
			for (int k=n; k>=1; --k) {
				for (int h=lim; ~h; --h) {
					for (int t=0; t<a[j]&&t<=h; ++t) {
						s[k][h]=(s[k][h]+s[k-1][h-t]*inv2[t])%mod;
					}
				}
			}
		}
		for (int i=1; i<=n; ++i) {
			// cout<<"i: "<<i<<endl;
			for (int j=0; j<=n; ++j) for (int k=0; k<=lim; ++k) f[j][k]=s[j][k];
			memset(g, 0, sizeof(g));
			for (int k=1; k<=n; ++k) {
				for (int h=0; h<=lim; ++h) {
					for (int t=0; t<a[i]&&t<=h; ++t) {
						f[k][h]=(f[k][h]-f[k-1][h-t]*inv2[t])%mod;
					}
				}
			}
			for (int k=n; k>=1; --k) {
				for (int h=lim; h>=a[i]; --h) {
					g[k][h]=(g[k][h]+f[k-1][h-a[i]]*inv2[a[i]-1])%mod;
				}
			}
			// cout<<"---f---"<<endl;
			// for (int j=1; j<=n; ++j) for (int k=1; k<=lim; ++k) printf("f[%d][%d]=%lld%c", j, k, f[j][k], " \n"[k==lim]);
			for (int j=1; j<=n; ++j) {
				ll tem=0, p=1;
				for (int k=1; k<=lim; ++k) p=p*inv[j]%mod, tem=(tem+g[j][k]*fac[k-1]%mod*p%mod)%mod;
				ans[i]=(ans[i]+((j&1)?1:-1)*tem)%mod;
			}
			printf("%lld%c", (ans[i]%mod+mod)%mod, " \n"[i==n]);
		}
		exit(0);
	}
}

signed main()
{
	freopen("c.in", "r", stdin);
	freopen("c.out", "w", stdout);

	n=read();
	for (int i=1; i<=n; ++i) a[i]=read(), m+=a[i];
	fac[0]=fac[1]=1; inv[0]=inv[1]=1; inv2[0]=inv2[1]=1;
	for (int i=2; i<N; ++i) fac[i]=fac[i-1]*i%mod;
	for (int i=2; i<N; ++i) inv[i]=(mod-mod/i)*inv[mod%i]%mod;
	for (int i=2; i<N; ++i) inv2[i]=inv2[i-1]*inv[i]%mod;
	// cout<<"inv: "; for (int i=1; i<=10; ++i) cout<<inv[i]<<' '; cout<<endl;
	// force::solve();
	task::solve();

	return 0;
}
posted @ 2021-10-28 09:33  Administrator-09  阅读(2)  评论(0编辑  收藏  举报