P5643 [PKUWC2018]随机游走 min-max容斥+FWT

直接求不好求,我们考虑 \(min-max\) 容斥:\(\displaystyle E(max(S))=\sum_{T \subseteq S}(-1)^{|T|+1}E(min(T))\)

其中 \(S\) 为到达相应的点花费时间的集合, \(max(S)\) 为到过所有点的时间, \(min(S)\) 为到过一个点的时间。

然后就变成了给定一个集合 \(S\) ,求 \(min(S)\) .

我们考虑 \(DP\) ,设 \(f[i]\) 为从 \(i\) 点开始,到过 \(S\) 中一个点的期望时间。

\(i \in S\) ,则 \(f[i]=0\)

否则 \(\displaystyle f[i]=\frac{f[Fa[i]]+\sum f[son]}{du[i]}+1\)

这时我们可以暴力高斯消元了,但这里有个小技巧:树上路径期望问题可以把每个节点的 \(dp\) 值表示 \(a \times f[Fa[i]] + b\)的形式

然后就可以化简一下式子。

\(\displaystyle f[i]=\frac{f[Fa[i]]+\sum f[son]}{du[i]}+1\)

\(\displaystyle f[i]=\frac{f[Fa[i]]+suma \times f[i]+sumb}{du[i]}+1\)

其中 \(\displaystyle suma=\sum a[son] \space\space\space\space\space\space sumb=\sum b[son]\)

\(\displaystyle (du[i]-suma)f[i]=f[Fa[i]]+sumb+du[i]\)

\(\displaystyle f[i]=\frac{1}{du[i]-suma}f[Fa[i]]+\frac{sumb+du[i]}{du[i]-suma}\)

对比 \(a \times f[Fa[i]] + b\)
可得
\(\displaystyle a[i]=\frac{1}{du[i]-suma}\)\(\displaystyle b[i]=\frac{sumb+du[i]}{du[i]-suma}\)

我们就求得了\(E(min(T))\),用 \(FWT\) 快速求子集和即可

#include<iostream>
#include<cstdio>
#define LL long long
using namespace std;
int n, q, root, all, x, y, tot, k, s;
const int N = 19, mod = 998244353;
int head[N], to[N << 1], nt[N << 1], du[N], A[N], B[N], f[1 << 18 | 1];
void add(int f, int t)
{
	to[++tot] = t; nt[tot] = head[f]; head[f] = tot;
}
LL ksm(LL a, LL b, LL mod)
{
	LL res = 1; a %= mod;
	for (; b; b >>= 1, a = a * a % mod)
		if (b & 1)res = res * a % mod;
	return res;
}
void dfs(int x, int fa, int s)
{
	if (s & (1 << (x - 1)))return;
	int sumA = 0, sumB = 0;
	for (int i = head[x]; i; i = nt[i])
		if (to[i] != fa)
		{
			dfs(to[i], x, s);
			(sumA += A[to[i]]) %= mod;
			(sumB += B[to[i]]) %= mod;
		}
	int inv = ksm(du[x] - sumA, mod - 2, mod);
	A[x] = inv; B[x] = (LL)inv * (sumB + du[x]) % mod;
}
int pan(int x)
{
	int res = 0;
	while (x)res += (x & 1), x >>= 1;
	return res & 1 ? 1 : -1;
}
int main()
{
	cin >> n >> q >> root; all = 1 << n;
	for (int i = 1; i < n; ++i)
	{
		scanf("%d%d", &x, &y);
		add(x, y); add(y, x); ++du[x]; ++du[y];
	}
	for (int s = 1; s < all; ++s)
	{
		for (int i = 1; i <= n; ++i)A[i] = B[i] = 0;
		dfs(root, 0, s);
		f[s] = (pan(s) * B[root] + mod) % mod;
	}
	for (int mid = 1; mid < all; mid <<= 1)
		for (int j = 0, len = mid << 1; j < all; j += len)
			for (int k = j; k < j + mid; ++k)
				(f[k + mid] += f[k]) %= mod;
	while (q--)
	{
		scanf("%d", &k); s = 0;
		for (int i = 1; i <= k; ++i)scanf("%d", &x), s |= 1 << (x - 1);
		printf("%d\n", f[s]);
	}
}
posted @ 2020-04-02 19:04  wljss  阅读(228)  评论(0编辑  收藏  举报