Live2D

Solution -「GLR-R2」教材运送

\(\mathcal{Description}\)

  Link.

  给定一棵包含 \(n\) 个点,有点权和边权的树。设当前位置 \(s\)(初始时 \(s=1\)),每次在 \(n\) 个结点内随机选择目标结点 \(t\),付出「\(s\)\(t\) 的简单路径上的边权之和」\(\times\)\(t\) 的点权」的代价,标记(可以重复标记)点 \(t\) 并把 \(s\) 置为 \(t\)。求每个点至少被标记一次时(其中 \(1\) 号结点一开始就被标记)代价之和的期望。答案对 \(998244353\) 取模。

  \(n\le10^6\)

\(\mathcal{Solution}\)

  首先,有期望意义下的 Min-Max 容斥的公式:

\[E(\max(S))=\sum_{T\subseteq S\land T\not=\varnothing}(-1)^{|T|-1}E(\min(T)) \]

  对于本题,\(S=\{2,3,\cdots,n\}\),要求的答案等价于标记最后一个未标记点时代价的期望,那么枚举上式中的 \(T\),并设 \(|T|=m\),我们只需要对于每个 \(T\),求出从结点 \(1\) 出发,标记 \(T\) 集合内任意一个点的期望代价就行。

  考虑一个朴素的 DP:令 \(f_T(u)\) 表示现在在 \(u\) 点(\(u\) 点已标记)时,标记 \(T\) 内任意一点的期望代价。显然:

\[f_T(u)=\begin{cases}0&u\in T\\ \frac{1}n\left( \sum_{v=1}^nf_T(v)+\operatorname{dist}(u,v) \right)&u\not\in T \end{cases} \]

  其中 \(\operatorname{dist}(u,v)\) 即表示题意中把 \(u\) 置为 \(v\) 的代价。到此,你就可以获得 \(10\) 分的好成绩啦!


  接下来,取出一个 \(u\not\in T\)\(f_T(u)\) 来研究:

\[f_T(u)=\frac{1}n\left( \sum_{v=1}^nf_T(v)+\operatorname{dist}(u,v) \right)\\ \Rightarrow nf_T(u)-\sum_{v=1}^nf_T(v)=\sum_{v=1}^n\operatorname{dist}(u,v) \]

  令 \(s=\sum_{v\not\in T}f(v)\)\(w(u)=\sum_{v=1}^n\operatorname{dist}(u,v)\)\(C=S-T=\{c_1,c_2,\cdots,c_{n-m}\}\)(其中恒有 \(c_1=1\))列出共 \(|C|\) 个等式:

\[\left. \begin{matrix} nf_T(c_1)-s=w(c_1)\\ nf_T(c_2)-s=w(c_2)\\ \cdots\\ nf_T(c_{n-m})-s=w(c_{n-m}) \end{matrix} \right\}n-m\text{ equations in total.} \]

  左右分别相加得到:

\[ns-(n-m)s=\sum_{v\not\in T}w(v)\\ \Rightarrow s=\frac{\sum_{v\not\in T}w(v)}{m} \]

  于是乎,要求的 \(f_T(1)\) 就有:

\[f_T(c_1)=\frac{s+w(1)}{n} \]


  此后,把 \(f_T(1)\) 带入答案的式子里:

\[\begin{aligned} &\sum_{T\subseteq S\land T\not=\varnothing}(-1)^{|T|-1}f_T(1)\\ =&\sum_{T\subseteq S\land T\not=\varnothing}(-1)^{|T|-1}\left( \frac{w(1)}{n}+\frac{\sum_{u\not\in T}w(u)}{n|T|} \right)\\ =&\frac{1}n\sum_{m=1}^{n-1}(-1)^{m-1}\left[ \binom{n-1}mw(1)+\frac{1}m\sum_{T\subseteq S\land|T|=m}\sum_{u\not\in T}w(u) \right] \end{aligned} \]

  令 \(g(m)=\sum_{T\subseteq S\land|T|=m}\sum_{v\not\in T}w(v)\),单独考虑结点 \(1\),它必然不属于 \(T\);再考虑其他结点的贡献次数,可以得出:

\[g(m)=\binom{n-1}{m}w(1)+\binom{n-2}m\sum_{u=2}^nw(u) \]

  最后,只需要求出 \(\sum_{u=1}^nw(u)\)。分别考虑每条边 \((u,v,b)\in E\) 的贡献。将这条边删去,记此时 \(u\) 所在联通块的结点个数为 \(p_u\),结点点权之和为 \(q_u\)\(v\) 同理。则:

\[\sum_{u=1}^nw(u)=\sum_{(u,v)\in E}b(u,v)(p_uq_v+p_vq_u) \]

  综上,求出这一系列式子,问题就以 \(\mathcal O(n)\) 的复杂度解决啦!

\(\mathcal{Code}\)

/* Clearink */

#include <cstdio>

inline char fgc () {
	static char buf[1 << 17], *p = buf, *q = buf;
	return p == q && ( q = buf + fread ( p = buf, 1, 1 << 17, stdin ), p == q ) ? EOF : *p++;
}

inline int rint () {
	int x = 0; char s = fgc ();
	for ( ; s < '0' || '9' < s; s = fgc () );
	for ( ; '0' <= s && s <= '9'; s = fgc () ) x = x * 10 + ( s ^ '0' );
	return x;
}

const int MAXN = 1e6, MOD = 998244353;
int n, ecnt, head[MAXN + 5], val[MAXN + 5], fac[MAXN + 5], ifac[MAXN + 5];
int fa[MAXN + 5], siz[MAXN + 5], sum[MAXN + 5], dist[MAXN + 5]; // dist[i]==dist(1,i).
int g[MAXN + 5];

inline int mul ( const long long a, const int b ) { return a * b % MOD; }
inline int sub ( int a, const int b ) { return ( a -= b ) < 0 ? a + MOD : a; }
inline int add ( int a, const int b ) { return ( a += b ) < MOD ? a : a - MOD; }
inline void addeq ( int& a, const int b ) { ( a += b ) >= MOD && ( a -= MOD, 0 ); }

struct Edge { int to, cst, nxt; } graph[MAXN * 2 + 5];

inline void link ( const int s, const int t, const int c ) {
	graph[++ecnt] = { t, c, head[s] };
	head[s] = ecnt;
}

inline int qkpow ( int a, int b ) {
	int ret = 1;
	for ( ; b; a = mul ( a, a ), b >>= 1 ) ret = mul ( ret, b & 1 ? a : 1 );
	return ret;
}

inline void init () {
	fac[0] = 1;
	for ( int i = 1; i <= n; ++i ) fac[i] = mul ( i, fac[i - 1] );
	ifac[n] = qkpow ( fac[n], MOD - 2 );
	for ( int i = n - 1; ~i; -- i ) ifac[i] = mul ( i + 1ll, ifac[i + 1] );
}

inline int inv ( const int x ) { return mul ( fac[x - 1], ifac[x] ); }

inline int comb ( const int n, const int m ) {
	return n < m ? 0 : mul ( fac[n], mul ( ifac[m], ifac[n - m] ) );
}

inline void dfs ( const int u ) {
	siz[u] = 1, sum[u] = val[u];
	for ( int i = head[u], v; i; i = graph[i].nxt ) {
		if ( ( v = graph[i].to ) ^ fa[u] ) {
			dist[v] = add ( dist[fa[v] = u], graph[i].cst );
			dfs ( v ), siz[u] += siz[v], addeq ( sum[u], sum[v] );
		}
	}
}

int main () {
	n = rint ();
	int vs = 0;
	for ( int i = 1; i <= n; ++i ) vs = add ( vs, val[i] = rint () );
	for ( int i = 1, u, v, w; i < n; ++i ) {
		u = rint (), v = rint (), w = rint ();
		link ( u, v, w ), link ( v, u, w );
	}
	init (), dfs ( 1 );
	int S = 0;
	for ( int u = 1; u <= n; ++u ) {
		for ( int i = head[u], v; i; i = graph[i].nxt ) {
			if ( ( v = graph[i].to ) ^ fa[u] ) {
				addeq ( S, mul ( graph[i].cst,
					add ( mul ( siz[v], sub ( vs, sum[v] ) ),
						mul ( n - siz[v], sum[v] ) ) ) );
			}
		}
	}
	for ( int i = 2; i <= n; ++i ) addeq ( g[n - 1], mul ( val[i], dist[i] ) );
	for ( int i = n - 2; i; --i ) {
		g[i] = add ( mul ( comb ( n - 1, i ), g[n - 1] ),
			mul ( comb ( n - 2, i ), sub ( S, g[n - 1] ) ) );
	}
	int ans = 0;
	for ( int i = 1; i < n; ++i ) {
		ans = ( i & 1 ? add : sub )( ans,
			add ( mul ( comb ( n - 1, i ), g[n - 1] ), mul ( inv ( i ), g[i] ) ) );
	}
	printf ( "%d\n", mul ( ans, inv ( n ) ) );
	return 0;
}
posted @ 2020-12-20 23:17  Rainybunny  阅读(145)  评论(0编辑  收藏  举报