【PR #3】抽卡(DP)

题面:

你在玩一个抽卡游戏。

这个游戏有 \(n+1\) 种级别的抽卡方式,编号为 \(0,1,\cdots,n\) 。抽出来的每张卡的等级是 \([0,m]\) 中的一个整数。

一次 0 级抽卡就是只抽一次卡,而一次 \(i\) 级抽卡 \((1\le i\le n)\) 会包含 \(b_i\)\(i-1\) 级抽卡,并且这次 \(i\) 级抽卡合法当且仅当它包含的所有 \(i-1\) 级抽卡合法,且抽出来的卡中至少有一张的等级大于等于 \(i\)

对于每次 0 级抽卡,抽出一张等级为 \(j\) 的卡的概率是 \(\dfrac{a_j}{\sum_{k=0}^m a_k}\)

\(p_j\) 表示在一次合法\(n\) 级抽卡中抽出等级为 \(j\) 的卡的期望次数,\(q\) 表示一次 \(n\) 级抽卡合法的概率。你需要对于 \(0\le j\le m\) 求出 \((p_j\cdot q)\bmod {998244353}\)

\(n,m\leq 4000\)

题解:

\(f_{i,j}\) 表示 \(i\) 级抽卡抽出来合法且最大值 \(=j\) 的概率。那么:

\[f_{i,j}=\left(\sum_{k\leq j}f_{i-1,k}\right)^{b_i}-\left(\sum_{k<j}f_{i-1,k}\right)^{b_i} \]

然后每次将 \(f_{i,i-1}\)\(0\)。用前缀和优化即可 \(O(n^2\log b)\) 解决第二问。

现在考虑抽出来等级为 \(j\) 的卡的期望数量。这里比较巧妙的思路是,我们考虑最后一层中,某张等级为 \(j\) 的卡能一直顺利到达顶层的概率。

\(g_{i,j}\) 表示 \(i\) 级抽卡抽出来最大值为 \(j\),接下来能顺利到达顶层的概率。考虑 \(i+1\) 级的最大值,有:

\[g_{i,j}=g_{i+1,j}\left(\sum_{k\leq j}f_{i,k}\right)^{b_{i+1}-1}+\sum_{v>j}g_{i+1,v}\left(\left(\sum_{k\leq v}f_{i,k}\right)^{b_{i+1}-1}-\left(\sum_{k<v}f_{i,k}\right)^{b_{i+1}-1}\right) \]

同样使用前缀/后缀和优化也可做到 \(O(n^2\log b)\)

#include<bits/stdc++.h>

#define N 4010

using namespace std;

namespace modular
{
	const int mod=998244353;
	inline int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
	inline int dec(int x,int y){return x-y<0?x-y+mod:x-y;}
	inline int mul(int x,int y){return 1ll*x*y%mod;}
	inline void Add(int &x,int y){x=x+y>=mod?x+y-mod:x+y;}
	inline void Dec(int &x,int y){x=x-y<0?x-y+mod:x-y;}
	inline void Mul(int &x,int y){x=1ll*x*y%mod;}
	inline int poww(int a,int b){int ans=1;for(;b;Mul(a,a),b>>=1)if(b&1)Mul(ans,a);return ans;}
}using namespace modular;

inline int read()
{
	int x=0,f=1;
	char ch=getchar();
	while(ch<'0'||ch>'9')
	{
		if(ch=='-') f=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9')
	{
		x=(x<<1)+(x<<3)+(ch^'0');
		ch=getchar();
	}
	return x*f;
}

int m,n,a[N],b[N];
int f[N][N],sf[N][N],g[N][N];

int main()
{
	m=read(),n=read();
	int sa=0;
	for(int i=0;i<=m;i++) a[i]=read(),Add(sa,a[i]);
	sa=poww(sa,mod-2);
	for(int i=0;i<=m;i++) Mul(a[i],sa);
	for(int i=1;i<=n;i++) b[i]=read();
	for(int j=0;j<=m;j++)
		f[0][j]=a[j],sf[0][j]=add(j?sf[0][j-1]:0,f[0][j]);
	for(int i=1;i<=n;i++)
	{
		for(int j=i;j<=m;j++)
		{
			f[i][j]=dec(poww(sf[i-1][j],b[i]),poww(sf[i-1][j-1],b[i]));
			sf[i][j]=add(sf[i][j-1],f[i][j]);
		}
	}
	for(int j=n;j<=m;j++) g[n][j]=1;
	for(int i=n-1;i>=0;i--)
	{
		int s=0;
		for(int j=m;j>=i;j--)
		{
			g[i][j]=add(mul(g[i+1][j],poww(sf[i][j],b[i+1]-1)),s);
			if(j) Add(s,mul(g[i+1][j],dec(poww(sf[i][j],b[i+1]-1),poww(sf[i][j-1],b[i+1]-1))));
		}
	}
	int prod=1;
	for(int i=1;i<=n;i++) Mul(prod,b[i]);
	for(int j=0;j<=m;j++)
		printf("%d\n",mul(mul(prod,a[j]),g[0][j]));
	return 0;
}
posted @ 2022-11-04 11:11  ez_lcw  阅读(145)  评论(0编辑  收藏  举报