Solution -「PKUWC 2018」「洛谷 P5298」Minimax
\(\mathscr{Description}\)
Link.
给定一棵二叉树,每片叶子有一个权值,所有权值互不相同。每个非叶结点 \(u\) 有一个概率 \(p_u\in(0,1)\),表示 \(u\) 的权值以 \(p_u\) 的概率取儿子最大权值,以 \(1-p_u\) 的概率取儿子最小权值。求根节点取到每种权值的概率(以一定形式压缩输出)。答案模 \(998244353\)。
\(\mathscr{Solution}\)
令 \(f(u,i)\) 表示 \(u\) 处取到全局第 \(i\) 大的权值的概率,设 \(u\) 的左右儿子为 \(l,r\),显然
\[f(u,i)=f(l,i)\left(p_u\sum_{j<i}f(r,j)+(1-p_u)\sum_{j>i}f(r,j)\right)+f(r,i)\left(p_u\sum_{j<i}f(l,j)+(1-p_u)\sum_{j>i}f(l,j)\right).
\]
这是个左右区间的交叉贡献,用权值线段树维护 \(f(u)\),我们居然能直接将 \(f(l)\) 和 \(f(r)\) 的树“混合”直接求出 \(f(u)\)!
注意,例如左对右有贡献,直接在右子树上打乘法或加法之类的标记是不良的——不同时间戳下的标记合并方式不同。因而,我们可以暴力地保证,当且仅当线段树结点 \(u\) 的后代不会被打上当前时间戳的标记(即,\(u\) 子树内的贡献因子都一样了),我们才给它打乘法标记;否则仅累加系数递归传递。
所有贡献完成后,我们再把 \(f'(l)\) 和 \(f'(r)\) 合并(对应点贡献相加),就得到了 \(f(u)\)。本质上就是做了两次线段树合并,所以复杂度是 \(\mathcal O(n\log n)\)。
那么这件事情告诉我们线段树无所不能,所有看似暴力的 DP 都拿来试一试。(
\(\mathscr{Code}\)
/*+Rainybunny+*/
#include <bits/stdc++.h>
#define rep(i, l, r) for (int i = l, rep##i = r; i <= rep##i; ++i)
#define per(i, r, l) for (int i = r, per##i = l; i >= per##i; --i)
const int MAXN = 3e5, MOD = 998244353, INV1E4 = 796898467;
int n, siz[MAXN + 5], ch[MAXN + 5][2], val[MAXN + 5];
int mxv, dc[MAXN + 5], root[MAXN + 5];
inline int mul(const int u, const int v) { return 1ll * u * v % MOD; }
inline void subeq(int& u, const int v) { (u -= v) < 0 && (u += MOD); }
inline int sub(int u, const int v) { return (u -= v) < 0 ? u + MOD : u; }
inline void addeq(int& u, const int v) { (u += v) >= MOD && (u -= MOD); }
inline int add(int u, const int v) { return (u += v) < MOD ? u : u - MOD; }
struct SegmentTree {
static const int MAXND = 3e6;
int node, ch[MAXND][2], sum[MAXND], tag[MAXND];
inline void pushup(const int u) {
sum[u] = add(sum[ch[u][0]], sum[ch[u][1]]);
}
inline void pushml(const int u, const int v) {
assert(u);
sum[u] = mul(sum[u], v), tag[u] = mul(tag[u], v);
}
inline void pushdn(const int u) {
if (tag[u] != 1) {
if (ch[u][0]) pushml(ch[u][0], tag[u]);
if (ch[u][1]) pushml(ch[u][1], tag[u]);
tag[u] = 1;
}
}
inline void insert(int& u, const int l, const int r, const int x) {
u = ++node, tag[u] = sum[u] = 1;
if (l == r) return ;
int mid = l + r >> 1;
if (x <= mid) insert(ch[u][0], l, mid, x);
else insert(ch[u][1], mid + 1, r, x);
}
inline void mix(const int u, const int v,
const int su, const int sv, const int p) {
if (!u && !v) return ;
if (u && !v) return pushml(u, su);
if (v && !u) return pushml(v, sv);
if (!ch[u][0] && !ch[u][1]) return pushml(u, su), pushml(v, sv);
pushdn(u), pushdn(v);
int ul = sum[ch[u][0]], ur = sum[ch[u][1]];
int vl = sum[ch[v][0]], vr = sum[ch[v][1]], q = sub(1, p);
mix(ch[u][0], ch[v][0], add(su, mul(q, vr)), add(sv, mul(q, ur)), p);
mix(ch[u][1], ch[v][1], add(su, mul(p, vl)), add(sv, mul(p, ul)), p);
pushup(u), pushup(v);
}
inline void merge(int& u, const int v) {
if (!u || !v) return void(u |= v);
if (!ch[u][0] && !ch[u][1]) return addeq(sum[u], sum[v]);
pushdn(u), pushdn(v), addeq(sum[u], sum[v]);
merge(ch[u][0], ch[v][0]), merge(ch[u][1], ch[v][1]);
}
inline int answer(const int u, const int l, const int r) {
if (l == r) return mul(mul(l, dc[l]), mul(sum[u], sum[u]));
int mid = l + r >> 1, ret = 0; pushdn(u);
addeq(ret, answer(ch[u][0], l, mid));
addeq(ret, answer(ch[u][1], mid + 1, r));
return ret;
}
} sgt;
inline void solve(const int u) {
if (!ch[u][0]) {
val[u] = std::lower_bound(dc + 1, dc + mxv + 1, val[u]) - dc;
sgt.insert(root[u], 1, mxv, val[u]);
} else if (!ch[u][1]) {
solve(ch[u][0]), root[u] = root[ch[u][0]];
} else {
solve(ch[u][0]), solve(ch[u][1]);
sgt.mix(root[ch[u][0]], root[ch[u][1]], 0, 0, val[u]);
sgt.merge(root[u] = root[ch[u][0]], root[ch[u][1]]);
}
}
int main() {
std::ios::sync_with_stdio(false), std::cin.tie(0);
std::cin >> n;
rep (i, 1, n) {
int fa; std::cin >> fa;
if (fa) ch[fa][!!ch[fa][0]] = i;
}
rep (i, 1, n) {
std::cin >> val[i];
if (ch[i][0]) val[i] = mul(val[i], INV1E4);
else dc[++mxv] = val[i];
}
std::sort(dc + 1, dc + mxv + 1);
mxv = std::unique(dc + 1, dc + mxv + 1) - dc - 1;
solve(1);
std::cout << sgt.answer(root[1], 1, mxv) << '\n';
return 0;
}