21牛客9G - Glass Balls (树上概率dp)

21牛客9G - Glass Balls (树上概率dp)

题目

source

题解

UPD:
队友写得题解真好啊,简单清晰,推荐:2021牛客暑期多校训练营9 G (树上概率dp 对于存在不合法的情况dp启示)
可以直接令分数为状态,从v转移一步到u,贡献是score[v]再加上v上面的那1个球,所以是score[u]=p*(score[v]+1)。关键在于这里的p是条件概率,即在所有合法情况中的占比。具体见上面的博客,这样只需一次dfs。


对于从\(u\)点出发掉到\(v\)点的球来说,它的贡献是\(dep[u]-dep[v]\)。设对于一个固定的局面,掉到\(v\)点的球的球的个数为\(cnt[v]\),那么所有球的贡献为(即该局面的分数)为:

\[\sum\limits_{i=1}^{n}{dep[i]-\sum\limits_{i=1}^{n}{cnt[i]\cdot dep[i]}} \]

因此,只要分别求出深度总和的期望每个结点掉下去球数的期望即可,可以用树上dp计算。这里有几点要注意的:

  • 局面有合法的情况和非法的情况,因此在转移状态时注意确保的是从合法的子状态以合法的过程转移过来。
  • 树上dp一般计算的是子树的结果,在合并统计答案时要考虑上子树外部分的影响,这也是为什么往往需要两个dfs计算down和up的原因。

从题目中可以容易推得,每个结点的子节点中至多只有一个结点不是“储存点”,否则就是非法的。

\(dp[i]\)​为\(i\)​的子树中到\(i\)​的球数的期望;\(down[i]\)​为点\(i\)​的子树为合法局面的概率;\(up[i]\)​为整棵树在点\(i\)​​为“储存点”时且除去了\(down[i]\)​​的合法概率。这里的\(up[i]\)​是为了\(i\)​子树中到点\(i\)​的球数的期望转换为整棵树中从\(i\)​掉下去的球数的期望,即\(cnt[i]=up[i] \cdot dp[i]\)​。

显然,深度总和的期望就是整棵树合法的概率乘上深度的总和,即\(down[1] \cdot \sum\limits_{i=1}^n{dep[i]}\)​。

\(down\)\(up\)的转移都比较简单,主要是\(dp\)的转移。设\(P\)为“储存点的概率”,\(t\)为点p子结点的个数。

  • 子结点都是“储存点”,且子节点都合法,此时\(u\)中只有1个球,这种情况的贡献为:

\[dp[u]=1 \cdot P^t \cdot \prod_{v {\rm 是}u{\rm的子节点}} {down[v]} \]

  • 子结点\(v\)​不是”储存点“,且子节点都合法,此时\(u\)中除了本身的1个球,还有来自\(dp[v]\)那么多的球,这种情况的贡献为:

\[dp[u]=dp[v]\cdot P^{t-1}\cdot (1-P)\cdot \prod_{v' {\rm 是}u{\rm的子节点且}v'\neq v}{down[v']}+1\cdot P^{t-1}\cdot (1-P)\cdot\prod_{v' {\rm 是}u{\rm的子节点}} {down[v']} \]

最终答案为:\(down[1] \cdot \sum\limits_{i=1}^n{dep[i]}-\sum\limits_{i=1}^{n}{up[i] \cdot dp[i]\cdot dep[i]}\)

#include <bits/stdc++.h>

#define endl '\n'
#define IOS std::ios::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define mp make_pair
#define seteps(N) fixed << setprecision(N) 
typedef long long ll;

using namespace std;
/*-----------------------------------------------------------------*/

ll gcd(ll a, ll b) {return b ? gcd(b, a % b) : a;}
#define INF 0x3f3f3f3f

const int N = 5e5 + 10;
const int M = 998244353;
const double eps = 1e-5;

ll down[N];
ll up[N];
ll dp[N];
int dep[N];
ll po;
vector<int> np[N];

inline ll qpow(ll a, ll b, ll m) {
    ll res = 1;
    while(b) {
        if(b & 1) res = (res * a) % m;
        a = (a * a) % m;
        b = b >> 1;
    }
    return res;
}

void dfs(int p, int fa, int d) {
    dep[p] = d;
    for(int nt : np[p]) {
        if(nt == fa) continue;
        dfs(nt, p, d + 1);
    }
}

void caldown(int p, int fa) {
    ll lp = 1;
    int num = 0;
    for(int nt : np[p]) {
        if(nt == fa) continue;
        num++;
        caldown(nt, p);
        lp = lp * down[nt] % M;
    }
    if(num)
        lp = lp * (qpow(po, num - 1, M) * (1 - po + M) % M * num % M + qpow(po, num, M)) % M;
    down[p] = lp;
}

void calup(int p, int fa) {
    int num = 0;
    for(int nt : np[p]) {
        if(nt == fa) continue;
        num++;
        up[nt] = down[1] * qpow(down[nt], M - 2, M) % M;
    }
    if(num) {
        ll tp = (qpow(po, num - 1, M) * (1 - po + M) % M * num % M + qpow(po, num, M)) % M;
        for(int nt : np[p]) {
            if(nt == fa) continue;
            up[nt] = up[nt] * qpow(tp, M - 2, M) % M;
            up[nt] = up[nt] * (qpow(po, num - 1, M) * (1 - po + M) % M * (num - 1) % M + qpow(po, num, M)) % M;
            calup(nt, p);
        }
    }
}

void solve(int p, int fa) {
    int num = 0;
    ll lp = 1;
    for(int nt : np[p]) {
        if(nt == fa) continue;
        num++;
        lp = lp * down[nt] % M;
        solve(nt, p);
    }
    dp[p] = qpow(po, num, M) * lp % M;
    if(num)
        for(int nt : np[p]) {
            if(nt == fa) continue;
            // 注意后面1的贡献
            // 不要写成(dp[nt] + 1) * qpow(po, num - 1, M) % M * (1 - po + M) % M * lp % M * qpow(down[nt], M - 2, M) % M
            dp[p] += dp[nt] * qpow(po, num - 1, M) % M * (1 - po + M) % M * lp % M * qpow(down[nt], M - 2, M) % M + qpow(po, num - 1, M) % M * (1 - po + M) % M * lp % M;
            // 也可以写成
            // dp[p] += (dp[nt] + down[nt]) * qpow(po, num - 1, M) % M * (1 - po + M) % M * lp % M * qpow(down[nt], M - 2, M) % M;
            
            dp[p] %= M;
        }
}

int main() {
    IOS;
    up[1] = 1;
    int n;
    cin >> n >> po;
    for(int i = 2; i <= n; i++) {
        int f;
        cin >> f;
        np[i].push_back(f);
        np[f].push_back(i);
    }
    dfs(1, 0, 1);
    caldown(1, 0);
    calup(1, 0);
    solve(1, 0);
    ll ans = 0;
    ll tp = down[1];
    for(int i = 1; i <= n; i++) {
        ans = (ans + (tp - up[i] * (dp[i]) % M + M) * dep[i] % M) % M;
    }
    cout << ans << endl;
}
posted @ 2021-08-15 14:20  limil  阅读(124)  评论(0编辑  收藏  举报