某个概率题的一、拓展

题目链接:[https://codeforces.com/gym/104053/problem/I]

有很简单的背包做法,但是本人赛后想了很久一些关于 \(\times 0\) 怎么求逆之类的(无聊问题),本文主要讨论了一下基于生成函数的卷积

以下设初始得病的概率是 \(\alpha_u\),被相邻的点感染的概率是 \(p_u\)
考虑假设初始的时候哪个点被感染是确定的,设 \(u\) 的生成函数为 \(F_u(x)\),\([x^k]\) 表示 \(u\) 的子树中感染了 \(k\) 个点的概率

我们考虑转移 \(F_u(x) = (p_ux \times \prod_{v \in son(u)} F_v(x)) + (1 - p_u)\)

显然结果是一个关于 \(n\) 的多项式

但如果我们依次枚举哪个点初始时被感染,时间复杂度将达到难以接受的 \(O(n^3)\),一个比较显然的想法是考虑换根 dp ,然而这些多项式并不一定有逆(可能常数项为 0 )

不过当然有一些神奇的维护前缀后缀的方法,此处暂时不说

我们考虑多一个占位符 \(y\),来表示是否存在初始时被感染的点,那么 \([y]\) 的结果(一个 \(n\)次多项式即我们想要的答案)

考虑构造一个二元的多项式环 \((\Z(x,y)\bmod y^2)\)

那么转移变成了 \(F_u(x,y) = ((p_u x + \alpha_u xy) \times \prod_{v \in son(u)}F_v(x)) + (1 - p_u)\)

考虑最后的结果为 \(y * P_1(x) + P_2(x)\),并注意到我们只关心 \(P_1(x)\),并且 \(P_1(x)\)是一个至多 n 次多项式

考虑枚举 \(x = \{1,2,3,4,....n,n+1\}\),转移的时候就只相当于在一个 \((\bmod y^2)\) 的一元多项式环上做乘法,这一步是 \(O(n^2)\)

然后,我们可以得到一个 \(P_1(x)\) 的点值表示,插值把系数插出来即可,这一步也是 \(O(n^2)\)

\(ans_k\)\(P_1(x)[x^k] = F_{rt}(x,y)[x^k][y^1] + \sum_{u != rt}F_{u}(x,y)[x^k][y^1] * (1 - p_{fa_u})\)

#include<bits/stdc++.h>
using namespace std;
int read() {
	char c = getchar();
	int x = 0;
	while (c < '0' || c > '9')		c = getchar();
	while (c >= '0' && c <= '9')	x = x * 10 + c - 48,c = getchar();
	return x;
}
const int MaxN = 2e3 + 7;
const int mod = 1e9 + 7;
vector<int>E[MaxN];
int a[MaxN],Sum = 0;
int qpow(int x,int y) {
	int ans = 1;
	while (y) {
		if (y & 1)	ans = 1ll * ans * x % mod;
		x = 1ll * x * x % mod;
		y >>= 1;
	}
	return ans;
}
void add(int &x,int y) {
	x += y - mod;
	x += (x >> 31) & mod;
}
void sub(int &x,int y) {
	y = (mod - y);
	add(x,y);
}
int p[MaxN],b[MaxN],c[MaxN];
int f[MaxN];
int siz[MaxN];
int F[MaxN],ans[MaxN];
int n;
#define pb push_back
struct Z {
	int x,y;
}dp[MaxN];
Z operator +(Z A,Z B) {
	return Z{(A.x + B.x) % mod,(A.y + B.y) % mod};
}
Z operator *(Z A,Z B) {
	return Z{(int)(1ll * A.x * B.x % mod),(int)((1ll * A.x * B.y + 1ll * A.y * B.x) % mod)};
}
Z operator -(Z A,Z B) {
	B = Z{mod - B.x,mod - B.y};
	return A + B;
}
Z operator *(int A,Z B) {
	return Z{A,0} * B;
}
Z operator *(Z A,int B) {
	return B * A;
}
void treedp(int u,int fa,const int x) {
	dp[u] = Z{1,0};
	Z G = Z{(int)(1ll * p[u] * x % mod),(int)(1ll * a[u] * x % mod)};
	for (auto v:E[u]) {
		if (v ^ fa) {
			treedp(v,u,x);
			dp[u] = dp[u] * dp[v];		
		}
	}		
	dp[u] = dp[u] * G + Z{(1 - p[u] + mod) % mod,0};	
	if (fa == 0) {
		add(F[x-1],dp[u].y);
	}
	else {
		add(F[x-1],1ll * dp[u].y * (1 - p[fa] + mod) % mod);
	}
}
void largrange(int *y,int *A,int *x,int n) {
	int lambda = 1;
	static int P[MaxN],Pr[MaxN],G[MaxN];
	memset(P,0,sizeof(P));
	memset(Pr,0,sizeof(Pr));
	memset(G,0,sizeof(G));
	G[0] = 1;
	for (int i = 0; i <= n; ++i) {
		for (int j = 0; j <= i + 1; ++j) {
			P[j] = G[j];
			G[j] = 0;
		}
		for (int j = 0; j <= i; ++j) {
			add(G[j+1],P[j]);
			sub(G[j],1ll * P[j] * x[i] % mod);
		}
	}
	for (int i = 0; i <= n; ++i) {
		lambda = 1;
		for (int j = 0; j <= n; ++j)
			if (i != j)	lambda = (1ll * lambda * (x[i] - x[j] + mod) % mod) % mod;
		lambda = 1ll * qpow(lambda,mod - 2) * y[i] % mod;
		for (int j = 0; j <= n + 1; ++j)	P[j] = G[j];/*deg(G[j]) = n + 1*/
		for (int j = n + 1; j >= 1; --j) {
			Pr[j-1] = P[j];
			add(P[j-1],1ll * Pr[j-1] * x[i] % mod);
		}
		for (int j = 0; j <= n; ++j)	add(A[j],1ll * lambda * Pr[j] % mod);
	}
}
int main() {
	n = read();
	for (int i = 1; i < n; ++i) {
		int u = read(),v = read();
		E[u].pb(v);E[v].pb(u);
	}
	for (int i = 1; i <= n; ++i) {
		a[i] = read(),b[i] = read(),c[i] = read();
		add(Sum,a[i]);p[i] = 1ll * b[i] * qpow(c[i],mod - 2) % mod;
	}
	for (int i = 1; i <= n; ++i)	a[i] = 1ll * a[i] * qpow(Sum,mod - 2) % mod;
	for (int x = 1; x <= n + 1; ++x) {
		treedp(1,0,x);
	}
	static int id[MaxN];
	for (int i = 1; i <= n + 1; ++i)	id[i-1] = i;
	largrange(F,ans,id,n);
	for (int i = 1; i <= n; ++i) {
		cout << ans[i] << '\n';
	}
	return 0;
}
posted @ 2022-11-17 15:52  y_dove  阅读(116)  评论(0编辑  收藏  举报