P5405 [CTS2019]氪金手游 【数学概率+树形dp】

P5405 [CTS2019]氪金手游 【数学概率+树形dp】

先考虑外向树的情况:

这个的关键是要把求满足拓扑序的概率转化为求 每个点都比它的子树中的所有节点先取到的概率 。单个节点 x x x 的概率是独立的,为 w x ∑ y ∈ s u b t r e e ( x ) w y \frac{w_x}{\sum_{y\in subtree(x)}w_y} ysubtree(x)wywx ,答案就是所有情况下的 节点概率之积 的和。
f [ x ] f[x] f[x] 表示 x x x 子树内满足拓扑序的概率, s z [ x ] sz[x] sz[x] 表示 ∑ y ∈ s u b t r e e ( x ) w y \sum_{y\in subtree(x)}w_y ysubtree(x)wy
可以发现每个 f [ x ] f[x] f[x] 是只与 x x x 的子树有关, f [ x ] = w x s z [ x ] ∏ x → y f [ y ] f[x]=\frac{w_x}{sz[x]}\prod_{x \rightarrow y} f[y] f[x]=sz[x]wxxyf[y]
然后就加一维 f [ x ] [ s z [ x ] ] f[x][sz[x]] f[x][sz[x]] 就可以直接dp了。

考虑有内向边的情况:

可以用 这条边可以外向也可以内向 的方案数,减去 这条边一定外向 的方案数。
也可以容斥,记 g [ i ] g[i] g[i] 表示至少有 i i i 条边不满足条件的方案数。
i i i 条内向边变外向边,剩下的内向边就是 可以外向也可以内向 。
答案就是:

至少零个条件不满足 − - 至少一个条件不满足 + + + 至少两个条件不满足 − ⋯ -\cdots

时间复杂度 O ( n 2 ) O(n^2) O(n2)

#include <bits/stdc++.h>
#define N 1003
using namespace std;
typedef long long ll;
const int mod=998244353;
int head[N],nxt[N<<1],to[N<<1],tag[N<<1];
int sz[N],lst[N];
int a1[N],a2[N],a3[N];
ll F[N][N*3],arr[N*3];
ll inv[N*3],n,_;
ll ksm(ll x,ll y){
	ll res=1;
	while(y){ 
		if(y&1) res=res*x%mod;
		x=x*x%mod; y>>=1;
	}
	return res;
} 
void init(){
	inv[0]=inv[1]=1;
	for(int i=2;i<=n*3;i++) inv[i]=(mod-mod/i)*inv[mod%i]%mod;
}
void add(int x,int y,int z){ nxt[++_]=head[x],head[x]=_,to[_]=y,tag[_]=z; }
void dfs(int x){
	ll *f=F[x]; f[0]=1;
	int y; 
	for(int __=head[x],y=to[__];__;__=nxt[__],y=to[__]){
		if(y==lst[x]) continue;
		lst[y]=x; dfs(y); ll *g=F[y];
		for(int i=0;i<=(sz[x]+sz[y])*3;i++) arr[i]=0;
		if(tag[__]){
			for(int i=0;i<=sz[x]*3;i++)
				for(int j=0;j<=sz[y]*3;j++)
					arr[i+j]=(arr[i+j]+f[i]*g[j]%mod)%mod;
			for(int i=0;i<=(sz[x]+sz[y])*3;i++) f[i]=arr[i];
		}
		else{
			ll now=0;
			for(int i=0;i<=sz[y]*3;i++) now=(now+g[i])%mod;
			for(int i=0;i<=sz[x]*3;i++) arr[i]=(f[i]*now)%mod;
			for(int i=0;i<=sz[x]*3;i++)
				for(int j=0;j<=sz[y]*3;j++)
					arr[i+j]=(arr[i+j]-f[i]*g[j]%mod+mod)%mod;
			for(int i=0;i<=(sz[x]+sz[y])*3;i++) f[i]=arr[i];
		}
		sz[x]+=sz[y];
	}
	for(int i=0;i<=sz[x]*3+3;i++) arr[i]=0;
	for(int i=0;i<=sz[x]*3;i++){
		arr[i+1]=(arr[i+1]+f[i]*a1[x]%mod*inv[i+1]%mod)%mod;
		arr[i+2]=(arr[i+2]+f[i]*a2[x]%mod*inv[i+2]*2ll%mod)%mod;
		arr[i+3]=(arr[i+3]+f[i]*a3[x]%mod*inv[i+3]*3ll%mod)%mod;
	}
	sz[x]++;
	for(int i=0;i<=sz[x]*3;i++) f[i]=arr[i];
//	cout<<'\n'; 
}
int main(){
//	freopen("fgo9.in","r",stdin);
	cin>>n;
	init();
	for(int i=1;i<=n;i++){
		cin>>a1[i]>>a2[i]>>a3[i];
		ll now=ksm(a1[i]+a2[i]+a3[i],mod-2);
		a1[i]=a1[i]*now%mod;
		a2[i]=a2[i]*now%mod;
		a3[i]=a3[i]*now%mod;
//		cout<<a1[i]<<' '<<a2[i]<<' '<<a3[i]<<'\n';
	}
	int u,v;
	for(int i=1;i<n;i++) cin>>u>>v,add(u,v,1),add(v,u,0);
	dfs(1);
	ll ans=0;
	for(int i=0;i<=sz[1]*3;i++) (ans+=F[1][i])%=mod;
	cout<<ans;
}
posted @ 2022-10-10 20:18  缙云山车神  阅读(19)  评论(0编辑  收藏  举报