[2022牛客暑期多校训练营3 D] Directed

题目简介

给定一棵 \(n\) 个点的有根树,1 号点为根。

随机选择 \(k\) 条边定向,使得边从儿子指向父亲。

问从 \(s\) 号点出发,随机游走,最早走到 1 号点的期望步数。

\(n \le 10^6, k < n\)

正解

感觉这道题像是一个拼题。

前置知识

我是觉得这篇博客写的真的很好,墙裂安利一波。

对于有根树设 \(f_i\) 为从 \(i\) 出发第一次到达其父亲的期望步数,它与 \(i\) 号点的子树大小 \(siz_i\) 的关系为:

\[f_i = 2 \times siz_i - 1 \]

现在就好做了,很容易先算出不删边的答案。

如果从 \(s\) 点出发到达根的路径上的点(根除外)为 \(p_1, p_2 \cdots p_m\)

这部分答案为:

\[\sum_{i = 1}^m{f_{p_i}} \]

再来看对 \(u\) 号点连向父亲的条边定向后,答案要减去多少。

这条边会对它到根路径上的所有 \(p_i\) 点的 \(f_{p_i}\) 产生影响(除非有碰到另外一个被删掉的边,这之上则是另外一条边产生影响)。

对于某一个被影响到的 \(p_i\)\(f_{p_i}\) 要减掉 \(2 \times siz_u\)

\(u\) 父边能够影响到 \(p_i\) 点的概率为:

\[\frac {\binom {n - 1 - (dep_{u} - dep_{p_i})} {k - 1}} {\binom {n - 1} {k}} \]

(即 \(u\)\(p_i\) 路径上除了 \(u\) 父边,其他所有边都没有被选上的概率)

虽然 \(u\) 影响了许多点,但是他们的深度是连续的,就可以用某个二项式系数的公式 \(O(1)\) 算掉,就没啦。

代码

#include <bits/stdc++.h>

using namespace std;

const int N = 1e6 + 6, mod = 998244353;
int n, k, s;
int fa[N], siz[N], dep[N];
int fac[N], ifac[N], is[N];
int head[N], ecnt;
struct edge {
    int nex, to;
} e[N << 1];

int fpm(int x, int y) {
    int res = 1;
    while(y) {
        if(y & 1) res = 1LL * res * x % mod;
        x = 1LL * x * x % mod, y >>= 1;
    }
    return res;
}
int perm(int x, int y) { return 1LL * fac[x] * ifac[x - y] % mod; }
int comb(int x, int y) { return x >= y ? 1LL * perm(x, y) * ifac[y] % mod : 0; }

int ans = 0;
void dfs(int u) {
    siz[u] = 1;
    dep[u] = dep[fa[u]] + 1;
    for(int i = head[u], v; i; i = e[i].nex) {
        v = e[i].to;
        if(v == fa[u]) continue;
        fa[v] = u, dfs(v), siz[u] += siz[v];
    }
}

void work(int u, int mdep) {
    if(fa[u] != 1) {
        /*
        for(int i = 2; i <= mdep; ++i)
            if(n - 1 - (dep[u] - i) >= k - 1)
                ans = (ans - 2LL * siz[u] * comb(n - 1 - (dep[u] - i), k - 1) % mod * 
                fpm(comb(n - 1, k), mod - 2) % mod + mod) % mod;
        */
        int l = 2, r = mdep;
        l = max(2, k - 1 + dep[u] + 1 - n);
        if(l <= r) {
            l = n - 1 - (dep[u] - l), r = n - 1 - (dep[u] - r);
            ans = (ans - 2LL * siz[u] * (comb(r + 1, k) - comb(l, k) + mod) % mod * 
            fpm(comb(n - 1, k), mod - 2) % mod + mod) % mod;
        }
    }
    if(is[u]) mdep = max(mdep, dep[u]);
    for(int i = head[u], v; i; i = e[i].nex) {
        v = e[i].to;
        if(v == fa[u]) continue;
        work(v, mdep);
    }
}

int main() {
    scanf("%d %d %d", &n, &k, &s);
    fac[0] = 1;
    for(int i = 1; i <= n; ++i) fac[i] = 1LL * i * fac[i - 1] % mod;
    ifac[n] = fpm(fac[n], mod - 2);
    for(int i = n; i; --i) ifac[i - 1] = 1LL * i * ifac[i] % mod;
    for(int i = 1, u, v; i < n; ++i) {
        scanf("%d %d", &u, &v);
        e[++ecnt] = (edge){head[u], v}, head[u] = ecnt;
        e[++ecnt] = (edge){head[v], u}, head[v] = ecnt;
    }
    dfs(1);
    if(s == 1) {
        puts("0");
        return 0;
    }
    int u = s;
    while(fa[u] != 1) {
        ans = (ans + siz[u] * 2 - 1) % mod;
        is[u] = 1;
        u = fa[u];
    }
    ans = (ans + siz[u] * 2 - 1) % mod;
    is[u] = 1;
    work(u, dep[u]);
    printf("%d\n", ans);
    system("pause");
    return 0;
}
posted @ 2022-08-08 21:41  Lskkkno1  阅读(86)  评论(0编辑  收藏  举报