P5405 [CTS2019]氪金手游

没想到最后一步。还是太菜了。

简要题意

  • \(n\) 种卡牌,每一轮抽到第 \(i\) 种卡牌的概率为 \(\dfrac{w_i}{\sum w_j}\),其中 \(\forall j\in\{1,2,3\}\)\(w_i\)\(p_{i,j}\) 的概率为 \(j\)

  • 设第一次抽到 \(i\) 的时间为 \(T_i\),有 \(n-1\) 对限制构成一棵有向树,满足对于所有树边 \(u_i\to v_i\)\(T_{u_i} \lt T_{v_i}\)。求满足所有限制的概率。

  • \(n \le 1000\)

题解

若钦定树根,这棵树有着纷繁错杂的连边关系:根向、叶向可以交替出现,这就为解题带来很大不便。为此,先考虑所有边方向一致(不妨假定为根向,叶向同理)。

需要注意的是,每种卡牌出现的概率并不独立,这就意味着不能用简单的期望代替 \(w_i\)。由于边是根向的,题目中的限制等价于:任意结点先于子树中的结点出现(记为条件 A)。而当子树中的 \(w_i\) 确定时,条件 A 成立的概率是确定的:\(P(A) = \dfrac{w_u}{\sum\limits_{v\in \text{subtree}\ u}w_v}\)。因此只需记录子树的 \(w_v\) 之和即可 dp:\(f_{u,i}\) 表示以 \(u\) 为根的子树中,\(w_v\) 之和为 \(i\) 且在该子树中满足所有限制的概率。则: \(f^{'}_{u,i+j} = \dfrac{1}{i+j}\sum_{i,j}f_{u,i}\cdot f_{v,j}\),初值 \(f_{u,j} = p_{u,j}\cdot j(j \le 3)\)

当加入反向边之后,情况就不再这么简单——直接做会导致系数乱套。考虑一个容斥:\(P(反向) = P(不存在)-P(正向)\) 。于是在树边上记一个容斥系数即可 dp。

Code

#include <cstdio>
#include <cstring>
#include <algorithm>

using namespace std;

const int MAXN = 1010;
const int MOD = 998244353;

inline int add(int a, int b) {return (a+=b)>=MOD?a-MOD:a;}
inline void inc(int& a, int b) {a = add(a, b);}
inline int sub(int a, int b) {return (a-=b)<0?a+MOD:a;}
inline void dec(int& a, int b) {a = sub(a, b);}
inline int mul(int a, int b) {return 1ll*a*b%MOD;}
inline void mlt(int& a, int b) {a = mul(a, b);}
inline int pw(int x, int p = MOD-2)
{
	int res = 1;
	for(;p;p>>=1,mlt(x,x)) if(p&1) mlt(res, x);
	return res;
}
int inv[MAXN*3];
inline void initInv(int n)
{
	inv[1] = 1;
	for(int i=2;i<=n;i++) inv[i] = mul(MOD-MOD/i, inv[MOD%i]);
}

int n, p[MAXN][4];
struct edge{
	int ne, to, w;
	edge(int N=0,int T=0,int W=0):ne(N),to(T),w(W){}
}e[MAXN<<1];
int fir[MAXN], num = 0;
inline void join(int a, int b, int c)
{
	e[++num] = edge(fir[a], b, c);
	fir[a] = num;
}
int f[MAXN][MAXN*3], siz[MAXN], tmp[MAXN*3];
void dfs(int u, int fa)
{
	for(int i=1;i<=3;i++) f[u][i] = mul(p[u][i], i);
	siz[u] = 3;
	for(int i=fir[u];i;i=e[i].ne)
	{
		int v = e[i].to;
		if(v == fa) continue;
		dfs(v, u);
		memset(tmp, 0, (siz[u]+siz[v]+2)<<2);
		for(int j=1;j<=siz[u];j++)
			for(int k=1;k<=siz[v];k++)
			{
				inc(tmp[j+k], mul(mul(f[u][j], f[v][k]), e[i].w));
				if(e[i].w != 1) inc(tmp[j], mul(f[u][j], f[v][k])); 
			}
		siz[u] += siz[v];
		for(int j=1;j<=siz[u];j++) f[u][j] = tmp[j];
	}
	for(int i=1;i<=siz[u];i++) mlt(f[u][i], inv[i]);
}

inline void work()
{
	scanf("%d",&n);
	initInv(n*3);
	for(int i=1;i<=n;i++)
	{
		int s = 0;
		for(int j=1;j<=3;j++) scanf("%d",p[i]+j), inc(s, p[i][j]);
		s = pw(s);
		for(int j=1;j<=3;j++) mlt(p[i][j], s);
	}
	for(int i=1,u,v;i<n;i++)
	{
		scanf("%d%d",&u,&v);
		join(u, v, 1);
		join(v, u, MOD-1);
	}
	dfs(1, 0);
	int ans = 0;
	for(int i=1;i<=n*3;i++) inc(ans, f[1][i]);
	printf("%d\n",ans);
}

int main()
{
	int T = 1;
//	scanf("%d%d",&T,&MOD);
//	prework(N);
	while(T--) work();
	return 0;
}
posted @ 2021-10-10 18:21  Aphrosia  阅读(33)  评论(0编辑  收藏  举报