洛谷 P5298 [PKUWC2018]Minimax 题解
一、题目:
二、思路:
看到这道题第一眼,肯定是想从 \(\sum_{i=1}^m i\times V_i\times D_i^2\) 这个式子入手,但是很遗憾,由于这个式子太过诡异,我们无从下手。
因此我们不得不考虑把优化复杂度的重心放在如何快速求解每种可能的 \(V_i\) 的出现概率上。
先来考虑一种朴素的 DP。设状态 \(F(x,j)\) 表示 \(x\) 这个节点出现权值 \(j\) 的概率,再设左儿子为 \(ls\),右儿子为 \(rs\),那么显然有状态转移方程 $$F(x,j) = F(ls, j)\times p_x \times \sum_{k=1}^{j-1}F(rs,k)+F(rs,j)\times p_x\times \sum_{k=1}^{j-1}F(ls,k) + F(ls,j)\times (1-p_x)\times \sum_{k=j+1}^{maxx} F(rs,k) + F(rs,j)\times(1-p_x)\times \sum_{k=j+1}^{maxx} F(ls,k)$$
然后我们发现这个式子可以用线段树合并来实现快速转移,具体来说我们需要实现一个区间求和、区间乘法的权值线段树(可合并)。
然后注意开始的时候离散化一下,就没了。
三、注意:
我在实现的时候有一点小问题,就是我们发现动态开点的线段树在实现区间修改的时候,一般会在 pushdown 操作中写:如果没有左儿子,就新开一个节点;如果没有右儿子,就新开一个节点。
但是如果用这样的方法写线段树合并,就会遇到始终找不到回溯条件的情况(这是因为每次在 pushdown 的时候都会新开节点)。于是就很麻烦。但幸好这道题的 pushdown 不用这么干。我们把这种 pushdown 换一种写法:如果有左儿子,就把懒标记下放给左儿子;如果有右儿子,就把懒标记下方给右儿子。这样写为什么是正确的呢?这是因为如果一个节点本来就是空的,那它的 sum 值就是0,那自然不用乘了!
四、代码:
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
inline int read(void) {
int x = 0, f = 1; char ch = getchar();
while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }
while (ch >= '0' && ch <= '9') { x = x * 10 + ch - '0'; ch = getchar(); }
return f * x;
}
const int maxn = 3e5 + 5, mod = 998244353;
int n;
int child[maxn][2], tot, maxx, rt[maxn];
long long w[maxn], b[maxn], p[maxn], Div, ans;
inline long long power(long long a, long long b) {
long long res = 1;
for (; b; b >>= 1) {
if (b & 1) res = res * a % mod;
a = a * a % mod;
}
return res;
}
inline void lsh(void) {
sort(b + 1, b + tot + 1);
maxx = tot;
for (int i = 1; i <= n; ++ i) {
if (!child[i][0] && !child[i][1]) {
w[i] = lower_bound(b + 1, b + tot + 1, w[i]) - b;
}
}
}
namespace Tr {
const int maxm = 10000000;
int son[maxm][2], sz;
long long sum[maxm], mult[maxm];
inline int New(void) {
mult[++ sz] = 1;
return sz;
}
inline void pushup(int o) {
sum[o] = (sum[son[o][0]] + sum[son[o][1]]) % mod;
}
inline void Mult(int o, long long v) {
(mult[o] *= v) %= mod;
(sum[o] *= v) %= mod;
}
inline void pushdown(int o) {
if (son[o][0]) Mult(son[o][0], mult[o]);
if (son[o][1]) Mult(son[o][1], mult[o]);
mult[o] = 1;
}
void update(int &o, int l, int r, int q, long long v) {
if (!o) o = New();
if (l == r) { sum[o] += v; return; }
int mid = (l + r) >> 1;
if (q <= mid) update(son[o][0], l, mid, q, v);
if (q > mid) update(son[o][1], mid + 1, r, q, v);
pushup(o);
}
long long query(int o, int l, int r, int q) {
if (l == r) { return sum[o]; }
int mid = (l + r) >> 1;
pushdown(o);
if (q <= mid) return query(son[o][0], l, mid, q);
else return query(son[o][1], mid + 1, r, q);
}
int change(int o1, int o2, int l, int r, long long LM, long long RM, long long P) {
if (!o1 || !o2) {
if (!o1) Mult(o2, RM);
if (!o2) Mult(o1, LM);
return o1 + o2;
}
int mid = (l + r) >> 1;
pushdown(o1); pushdown(o2);
long long LX = sum[son[o1][0]], RX = sum[son[o1][1]], LY = sum[son[o2][0]], RY = sum[son[o2][1]];
son[o1][0] = change(son[o1][0], son[o2][0], l, mid, (LM + RY * (1 - P + mod) % mod) % mod, (RM + RX * (1 - P + mod) % mod) % mod, P);
son[o1][1] = change(son[o1][1], son[o2][1], mid + 1, r, (LM + LY * P % mod) % mod, (RM + LX * P % mod) % mod, P);
pushup(o1);
return o1;
}
void solve(int o, int l, int r) {
if (l == r) {
(ans += b[l] * l % mod * sum[o] % mod * sum[o] % mod) %= mod;
return;
}
int mid = (l + r) >> 1;
pushdown(o);
solve(son[o][0], l, mid);
solve(son[o][1], mid + 1, r);
}
}
void dfs(int x) {
if (!child[x][0] && !child[x][1]) {
Tr::update(rt[x], 1, maxx, w[x], 1);
return;
}
if (child[x][0] && !child[x][1]) {
dfs(child[x][0]);
swap(rt[x], rt[child[x][0]]);
return;
}
dfs(child[x][0]); dfs(child[x][1]);
rt[x] = Tr::change(rt[child[x][0]], rt[child[x][1]], 1, maxx, 0, 0, p[x]);
}
int main() {
n = read();
for (int i = 1; i <= n; ++ i) {
int f = read();
if (!child[f][0]) child[f][0] = i;
else child[f][1] = i;
}
Div = power(10000, mod - 2);
for (int i = 1; i <= n; ++ i) {
if (!child[i][0] && !child[i][1]) {
w[i] = read(); b[++ tot] = w[i];
}
else {
p[i] = read() * Div % mod;
}
}
lsh();
dfs(1);
Tr::solve(rt[1], 1, maxx);
printf("%lld\n", ans);
return 0;
}