【PR #3】抽卡(DP)
题面:
你在玩一个抽卡游戏。
这个游戏有 \(n+1\) 种级别的抽卡方式,编号为 \(0,1,\cdots,n\) 。抽出来的每张卡的等级是 \([0,m]\) 中的一个整数。
一次 0 级抽卡就是只抽一次卡,而一次 \(i\) 级抽卡 \((1\le i\le n)\) 会包含 \(b_i\) 次 \(i-1\) 级抽卡,并且这次 \(i\) 级抽卡合法当且仅当它包含的所有 \(i-1\) 级抽卡合法,且抽出来的卡中至少有一张的等级大于等于 \(i\) 。
对于每次 0 级抽卡,抽出一张等级为 \(j\) 的卡的概率是 \(\dfrac{a_j}{\sum_{k=0}^m a_k}\) 。
设 \(p_j\) 表示在一次合法的 \(n\) 级抽卡中抽出等级为 \(j\) 的卡的期望次数,\(q\) 表示一次 \(n\) 级抽卡合法的概率。你需要对于 \(0\le j\le m\) 求出 \((p_j\cdot q)\bmod {998244353}\) 。
\(n,m\leq 4000\)。
题解:
设 \(f_{i,j}\) 表示 \(i\) 级抽卡抽出来合法且最大值 \(=j\) 的概率。那么:
\[f_{i,j}=\left(\sum_{k\leq j}f_{i-1,k}\right)^{b_i}-\left(\sum_{k<j}f_{i-1,k}\right)^{b_i}
\]
然后每次将 \(f_{i,i-1}\) 设 \(0\)。用前缀和优化即可 \(O(n^2\log b)\) 解决第二问。
现在考虑抽出来等级为 \(j\) 的卡的期望数量。这里比较巧妙的思路是,我们考虑最后一层中,某张等级为 \(j\) 的卡能一直顺利到达顶层的概率。
设 \(g_{i,j}\) 表示 \(i\) 级抽卡抽出来最大值为 \(j\),接下来能顺利到达顶层的概率。考虑 \(i+1\) 级的最大值,有:
\[g_{i,j}=g_{i+1,j}\left(\sum_{k\leq j}f_{i,k}\right)^{b_{i+1}-1}+\sum_{v>j}g_{i+1,v}\left(\left(\sum_{k\leq v}f_{i,k}\right)^{b_{i+1}-1}-\left(\sum_{k<v}f_{i,k}\right)^{b_{i+1}-1}\right)
\]
同样使用前缀/后缀和优化也可做到 \(O(n^2\log b)\)。
#include<bits/stdc++.h>
#define N 4010
using namespace std;
namespace modular
{
const int mod=998244353;
inline int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
inline int dec(int x,int y){return x-y<0?x-y+mod:x-y;}
inline int mul(int x,int y){return 1ll*x*y%mod;}
inline void Add(int &x,int y){x=x+y>=mod?x+y-mod:x+y;}
inline void Dec(int &x,int y){x=x-y<0?x-y+mod:x-y;}
inline void Mul(int &x,int y){x=1ll*x*y%mod;}
inline int poww(int a,int b){int ans=1;for(;b;Mul(a,a),b>>=1)if(b&1)Mul(ans,a);return ans;}
}using namespace modular;
inline int read()
{
int x=0,f=1;
char ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-') f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=(x<<1)+(x<<3)+(ch^'0');
ch=getchar();
}
return x*f;
}
int m,n,a[N],b[N];
int f[N][N],sf[N][N],g[N][N];
int main()
{
m=read(),n=read();
int sa=0;
for(int i=0;i<=m;i++) a[i]=read(),Add(sa,a[i]);
sa=poww(sa,mod-2);
for(int i=0;i<=m;i++) Mul(a[i],sa);
for(int i=1;i<=n;i++) b[i]=read();
for(int j=0;j<=m;j++)
f[0][j]=a[j],sf[0][j]=add(j?sf[0][j-1]:0,f[0][j]);
for(int i=1;i<=n;i++)
{
for(int j=i;j<=m;j++)
{
f[i][j]=dec(poww(sf[i-1][j],b[i]),poww(sf[i-1][j-1],b[i]));
sf[i][j]=add(sf[i][j-1],f[i][j]);
}
}
for(int j=n;j<=m;j++) g[n][j]=1;
for(int i=n-1;i>=0;i--)
{
int s=0;
for(int j=m;j>=i;j--)
{
g[i][j]=add(mul(g[i+1][j],poww(sf[i][j],b[i+1]-1)),s);
if(j) Add(s,mul(g[i+1][j],dec(poww(sf[i][j],b[i+1]-1),poww(sf[i][j-1],b[i+1]-1))));
}
}
int prod=1;
for(int i=1;i<=n;i++) Mul(prod,b[i]);
for(int j=0;j<=m;j++)
printf("%d\n",mul(mul(prod,a[j]),g[0][j]));
return 0;
}