[2022牛客暑期多校训练营3 D] Directed
题目简介
给定一棵 \(n\) 个点的有根树,1 号点为根。
随机选择 \(k\) 条边定向,使得边从儿子指向父亲。
问从 \(s\) 号点出发,随机游走,最早走到 1 号点的期望步数。
\(n \le 10^6, k < n\)
正解
感觉这道题像是一个拼题。
前置知识
我是觉得这篇博客写的真的很好,墙裂安利一波。
对于有根树设 \(f_i\) 为从 \(i\) 出发第一次到达其父亲的期望步数,它与 \(i\) 号点的子树大小 \(siz_i\) 的关系为:
\[f_i = 2 \times siz_i - 1
\]
现在就好做了,很容易先算出不删边的答案。
如果从 \(s\) 点出发到达根的路径上的点(根除外)为 \(p_1, p_2 \cdots p_m\)。
这部分答案为:
\[\sum_{i = 1}^m{f_{p_i}}
\]
再来看对 \(u\) 号点连向父亲的条边定向后,答案要减去多少。
这条边会对它到根路径上的所有 \(p_i\) 点的 \(f_{p_i}\) 产生影响(除非有碰到另外一个被删掉的边,这之上则是另外一条边产生影响)。
对于某一个被影响到的 \(p_i\),\(f_{p_i}\) 要减掉 \(2 \times siz_u\)。
\(u\) 父边能够影响到 \(p_i\) 点的概率为:
\[\frac {\binom {n - 1 - (dep_{u} - dep_{p_i})} {k - 1}} {\binom {n - 1} {k}}
\]
(即 \(u\) 到 \(p_i\) 路径上除了 \(u\) 父边,其他所有边都没有被选上的概率)
虽然 \(u\) 影响了许多点,但是他们的深度是连续的,就可以用某个二项式系数的公式 \(O(1)\) 算掉,就没啦。
代码
#include <bits/stdc++.h>
using namespace std;
const int N = 1e6 + 6, mod = 998244353;
int n, k, s;
int fa[N], siz[N], dep[N];
int fac[N], ifac[N], is[N];
int head[N], ecnt;
struct edge {
int nex, to;
} e[N << 1];
int fpm(int x, int y) {
int res = 1;
while(y) {
if(y & 1) res = 1LL * res * x % mod;
x = 1LL * x * x % mod, y >>= 1;
}
return res;
}
int perm(int x, int y) { return 1LL * fac[x] * ifac[x - y] % mod; }
int comb(int x, int y) { return x >= y ? 1LL * perm(x, y) * ifac[y] % mod : 0; }
int ans = 0;
void dfs(int u) {
siz[u] = 1;
dep[u] = dep[fa[u]] + 1;
for(int i = head[u], v; i; i = e[i].nex) {
v = e[i].to;
if(v == fa[u]) continue;
fa[v] = u, dfs(v), siz[u] += siz[v];
}
}
void work(int u, int mdep) {
if(fa[u] != 1) {
/*
for(int i = 2; i <= mdep; ++i)
if(n - 1 - (dep[u] - i) >= k - 1)
ans = (ans - 2LL * siz[u] * comb(n - 1 - (dep[u] - i), k - 1) % mod *
fpm(comb(n - 1, k), mod - 2) % mod + mod) % mod;
*/
int l = 2, r = mdep;
l = max(2, k - 1 + dep[u] + 1 - n);
if(l <= r) {
l = n - 1 - (dep[u] - l), r = n - 1 - (dep[u] - r);
ans = (ans - 2LL * siz[u] * (comb(r + 1, k) - comb(l, k) + mod) % mod *
fpm(comb(n - 1, k), mod - 2) % mod + mod) % mod;
}
}
if(is[u]) mdep = max(mdep, dep[u]);
for(int i = head[u], v; i; i = e[i].nex) {
v = e[i].to;
if(v == fa[u]) continue;
work(v, mdep);
}
}
int main() {
scanf("%d %d %d", &n, &k, &s);
fac[0] = 1;
for(int i = 1; i <= n; ++i) fac[i] = 1LL * i * fac[i - 1] % mod;
ifac[n] = fpm(fac[n], mod - 2);
for(int i = n; i; --i) ifac[i - 1] = 1LL * i * ifac[i] % mod;
for(int i = 1, u, v; i < n; ++i) {
scanf("%d %d", &u, &v);
e[++ecnt] = (edge){head[u], v}, head[u] = ecnt;
e[++ecnt] = (edge){head[v], u}, head[v] = ecnt;
}
dfs(1);
if(s == 1) {
puts("0");
return 0;
}
int u = s;
while(fa[u] != 1) {
ans = (ans + siz[u] * 2 - 1) % mod;
is[u] = 1;
u = fa[u];
}
ans = (ans + siz[u] * 2 - 1) % mod;
is[u] = 1;
work(u, dep[u]);
printf("%d\n", ans);
system("pause");
return 0;
}