Live2D

Solution -「PKUWC 2018」「洛谷 P5298」Minimax

\(\mathscr{Description}\)

  Link.

  给定一棵二叉树,每片叶子有一个权值,所有权值互不相同。每个非叶结点 \(u\) 有一个概率 \(p_u\in(0,1)\),表示 \(u\) 的权值以 \(p_u\) 的概率取儿子最大权值,以 \(1-p_u\) 的概率取儿子最小权值。求根节点取到每种权值的概率(以一定形式压缩输出)。答案模 \(998244353\)

\(\mathscr{Solution}\)

  令 \(f(u,i)\) 表示 \(u\) 处取到全局第 \(i\) 大的权值的概率,设 \(u\) 的左右儿子为 \(l,r\),显然

\[f(u,i)=f(l,i)\left(p_u\sum_{j<i}f(r,j)+(1-p_u)\sum_{j>i}f(r,j)\right)+f(r,i)\left(p_u\sum_{j<i}f(l,j)+(1-p_u)\sum_{j>i}f(l,j)\right). \]

这是个左右区间的交叉贡献,用权值线段树维护 \(f(u)\),我们居然能直接将 \(f(l)\)\(f(r)\) 的树“混合”直接求出 \(f(u)\)

  注意,例如左对右有贡献,直接在右子树上打乘法或加法之类的标记是不良的——不同时间戳下的标记合并方式不同。因而,我们可以暴力地保证,当且仅当线段树结点 \(u\) 的后代不会被打上当前时间戳的标记(即,\(u\) 子树内的贡献因子都一样了),我们才给它打乘法标记;否则仅累加系数递归传递。

  所有贡献完成后,我们再把 \(f'(l)\)\(f'(r)\) 合并(对应点贡献相加),就得到了 \(f(u)\)。本质上就是做了两次线段树合并,所以复杂度是 \(\mathcal O(n\log n)\)

  那么这件事情告诉我们线段树无所不能,所有看似暴力的 DP 都拿来试一试。(

\(\mathscr{Code}\)

/*+Rainybunny+*/

#include <bits/stdc++.h>

#define rep(i, l, r) for (int i = l, rep##i = r; i <= rep##i; ++i)
#define per(i, r, l) for (int i = r, per##i = l; i >= per##i; --i)

const int MAXN = 3e5, MOD = 998244353, INV1E4 = 796898467;
int n, siz[MAXN + 5], ch[MAXN + 5][2], val[MAXN + 5];
int mxv, dc[MAXN + 5], root[MAXN + 5];

inline int mul(const int u, const int v) { return 1ll * u * v % MOD; }
inline void subeq(int& u, const int v) { (u -= v) < 0 && (u += MOD); }
inline int sub(int u, const int v) { return (u -= v) < 0 ? u + MOD : u; }
inline void addeq(int& u, const int v) { (u += v) >= MOD && (u -= MOD); }
inline int add(int u, const int v) { return (u += v) < MOD ? u : u - MOD; }

struct SegmentTree {
    static const int MAXND = 3e6;
    int node, ch[MAXND][2], sum[MAXND], tag[MAXND];

    inline void pushup(const int u) {
        sum[u] = add(sum[ch[u][0]], sum[ch[u][1]]);
    }

    inline void pushml(const int u, const int v) {
        assert(u);
        sum[u] = mul(sum[u], v), tag[u] = mul(tag[u], v);
    }

    inline void pushdn(const int u) {
        if (tag[u] != 1) {
            if (ch[u][0]) pushml(ch[u][0], tag[u]);
            if (ch[u][1]) pushml(ch[u][1], tag[u]);
            tag[u] = 1;
        }
    }

    inline void insert(int& u, const int l, const int r, const int x) {
        u = ++node, tag[u] = sum[u] = 1;
        if (l == r) return ;
        int mid = l + r >> 1;
        if (x <= mid) insert(ch[u][0], l, mid, x);
        else insert(ch[u][1], mid + 1, r, x);
    }

    inline void mix(const int u, const int v,
      const int su, const int sv, const int p) {
        if (!u && !v) return ;
        if (u && !v) return pushml(u, su);
        if (v && !u) return pushml(v, sv);
        if (!ch[u][0] && !ch[u][1]) return pushml(u, su), pushml(v, sv);
        pushdn(u), pushdn(v);
        int ul = sum[ch[u][0]], ur = sum[ch[u][1]];
        int vl = sum[ch[v][0]], vr = sum[ch[v][1]], q = sub(1, p);
        mix(ch[u][0], ch[v][0], add(su, mul(q, vr)), add(sv, mul(q, ur)), p);
        mix(ch[u][1], ch[v][1], add(su, mul(p, vl)), add(sv, mul(p, ul)), p);
        pushup(u), pushup(v);
    }

    inline void merge(int& u, const int v) {
        if (!u || !v) return void(u |= v);
        if (!ch[u][0] && !ch[u][1]) return addeq(sum[u], sum[v]);
        pushdn(u), pushdn(v), addeq(sum[u], sum[v]);
        merge(ch[u][0], ch[v][0]), merge(ch[u][1], ch[v][1]);
    }

    inline int answer(const int u, const int l, const int r) {
        if (l == r) return mul(mul(l, dc[l]), mul(sum[u], sum[u]));
        int mid = l + r >> 1, ret = 0; pushdn(u);
        addeq(ret, answer(ch[u][0], l, mid));
        addeq(ret, answer(ch[u][1], mid + 1, r));
        return ret;
    }
} sgt;

inline void solve(const int u) {
    if (!ch[u][0]) {
        val[u] = std::lower_bound(dc + 1, dc + mxv + 1, val[u]) - dc;
        sgt.insert(root[u], 1, mxv, val[u]);
    } else if (!ch[u][1]) {
        solve(ch[u][0]), root[u] = root[ch[u][0]];
    } else {
        solve(ch[u][0]), solve(ch[u][1]);
        sgt.mix(root[ch[u][0]], root[ch[u][1]], 0, 0, val[u]);
        sgt.merge(root[u] = root[ch[u][0]], root[ch[u][1]]);
    }
}

int main() {
    std::ios::sync_with_stdio(false), std::cin.tie(0);

    std::cin >> n;
    rep (i, 1, n) {
        int fa; std::cin >> fa;
        if (fa) ch[fa][!!ch[fa][0]] = i;
    }
    rep (i, 1, n) {
        std::cin >> val[i];
        if (ch[i][0]) val[i] = mul(val[i], INV1E4);
        else dc[++mxv] = val[i];
    }

    std::sort(dc + 1, dc + mxv + 1);
    mxv = std::unique(dc + 1, dc + mxv + 1) - dc - 1;
    solve(1);

    std::cout << sgt.answer(root[1], 1, mxv) << '\n';
    return 0;
}

posted @ 2022-02-23 22:48  Rainybunny  阅读(49)  评论(0编辑  收藏  举报