LOJ 2537 「PKUWC2018」Minimax
BZOJ 5461。
线段树合并优化$dp$。
假设所有离散之后的权值$\in [1, m]$,对于一个点$x$它的权值是$i$的概率是$f(x, i)$,那么
1、假如这个点只有一个儿子$y$,那么$f(x, i) = f(y, i)$。
2、假如这个点有两个儿子$y, z$,那么
$$f(x, i) = f(y, i)\sum_{j = 1}^{m}f(z, j)(p_x[i \geq j] + (1 - p_x)[i \leq j]) + f(z, i)\sum_{j = 1}^{m}f(y, j)(p_x[i \geq j] + (1 - p_x)[i \leq j])$$
看上去$f(y, i) * f(z, i)$这一项被算了两次,但是题目保证了所有权值互不相同,所以似乎并没有问题。
$$f(x, i) = f(y, i)(p_x\sum_{j = 1}^{i}f(z, j) + (1 - p_x)\sum_{j = i}^{m}f(z, j)) + f(z, i)(p_x\sum_{j = 1}^{i}f(y, j) + (1 - p_x)\sum_{j = i}^{m}f(y, j))$$
这样子就把$f(x, i)$表示成了$f(y, i)$乘上$f(z)$的贡献加上$f(z, i)$乘上$f(y)$的贡献,对于一个特定的$x$,$p_x$是一个定值,只要在线段树合并的时候算出向下的贡献累加到结点中即可。
注意到当合并的时候会出现一个结点是空结点的情况,这时候其实相当于整段区间乘上了一段定值,可以采用在线段树上打标记的形式来维护。
以后线段树合并都直接合并到原来的结点上去,不要开新节点,这样子好写而且空间小,注意合并之后原来的权值可能会发生变化,所以开一个变量把原来的值记录下来。
时间复杂度$O(nlogn)$。
Code:
#include <cstdio> #include <cstring> #include <algorithm> using namespace std; typedef long long ll; const int N = 3e5 + 5; const int M = 1e7 + 5; const ll P = 998244353LL; int n, fa[N], son[N][3], m = 0; ll a[N], buc[N]; bool isLeaf[N]; template <typename T> inline void read(T &X) { X = 0; char ch = 0; T op = 1; for (; ch > '9'|| ch < '0'; ch = getchar()) if (ch == '-') op = -1; for (; ch >= '0' && ch <= '9'; ch = getchar()) X = (X << 3) + (X << 1) + ch - 48; X *= op; } template <typename T> inline void inc(T &x, T y) { x += y; if (x >= P) x -= P; } template <typename T> inline void sub(T &x, T y) { x -= y; if (x < 0) x += P; } inline ll fpow(ll x, ll y) { ll res = 1; for (x %= P; y > 0; y >>= 1) { if (y & 1) res = res * x % P; x = x * x % P; } return res; } void dfs(int x) { isLeaf[x] = 1; for (int i = 1; i <= son[x][0]; i++) { isLeaf[x] = 0; dfs(son[x][i]); } } namespace SegT { struct Node { int lc, rc; ll sum, tag; } s[M]; #define lc(p) s[p].lc #define rc(p) s[p].rc #define sum(p) s[p].sum #define tag(p) s[p].tag #define mid ((l + r) >> 1) int root[N], nodeCnt = 0; inline void down(int p) { if (tag(p) == 1LL) return; if (lc(p)) { (sum(lc(p)) *= tag(p)) %= P; (tag(lc(p)) *= tag(p)) %= P; } if (rc(p)) { (sum(rc(p)) *= tag(p)) %= P; (tag(rc(p)) *= tag(p)) %= P; } tag(p) = 1; } void ins(int &p, int l, int r, int x) { p = ++nodeCnt, tag(p) = sum(p) = 1; if (l == r) return; if (x <= mid) ins(lc(p), l, mid, x); else ins(rc(p), mid + 1, r, x); } int merge(int u, int v, ll su, ll sv, ll k) { if (!u) { (tag(v) *= su) %= P; (sum(v) *= su) %= P; return v; } if (!v) { (tag(u) *= sv) %= P; (sum(u) *= sv) %= P; return u; } down(u), down(v); ll k1 = k, k2 = 1, rsu = sum(rc(u)), lsu = sum(lc(u)), rsv = sum(rc(v)), lsv = sum(lc(v)); sub(k2, k); lc(u) = merge(lc(u), lc(v), (su + k2 * rsu % P) % P, (sv + k2 * rsv % P) % P, k); rc(u) = merge(rc(u), rc(v), (su + k1 * lsu % P) % P, (sv + k1 * lsv % P) % P, k); sum(u) = 0; inc(sum(u), sum(lc(u))), inc(sum(u), sum(rc(u))); return u; } ll query(int p, int l, int r, int x) { if (l == r) return sum(p); down(p); if (x <= mid) return query(lc(p), l, mid, x); else return query(rc(p), mid + 1, r, x); } void deb(int rt) { for (int i = 1; i <= m; i++) printf("%lld%c", query(rt, 1, m, i), i == m ? '\n' : ' '); } } using namespace SegT; void solve(int x) { if (isLeaf[x]) return; for (int i = 1; i <= son[x][0]; i++) solve(son[x][i]); if (son[x][0] == 1) root[x] = root[son[x][1]]; if (son[x][0] == 2) root[x] = SegT :: merge(root[son[x][1]], root[son[x][2]], 0, 0, a[x] % P); } int main() { #ifndef ONLINE_JUDGE freopen("5.in", "r", stdin); #endif read(n); for (int i = 1; i <= n; i++) { read(fa[i]); if (fa[i] != 0) son[fa[i]][++son[fa[i]][0]] = i; } dfs(1); ll inv10000 = fpow(10000LL, P - 2); for (int i = 1; i <= n; i++) { read(a[i]); if (isLeaf[i]) buc[++m] = a[i]; else a[i] = a[i] * inv10000 % P; } sort(buc + 1, buc + 1 + m); for (int i = 1; i <= n; i++) { if (!isLeaf[i]) continue; a[i] = lower_bound(buc + 1, buc + 1 + m, a[i]) - buc; ins(root[i], 1, m, a[i]); } solve(1); ll ans = 0; for (int i = 1; i <= m; i++) { ll p = query(root[1], 1, m, i); inc(ans, 1LL * i * p % P * p % P * (buc[i] % P) % P); } printf("%lld\n", ans); /* for (int i = 1; i <= m; i++) { ll p = query(root[1], 1, m, i); printf("%lld%c", p, i == m ? '\n' : ' '); } */ return 0; }