【题解】P5298 [PKUWC2018]Minimax
思路
线段树合并优化树形 dp.
值域 1e9 首先考虑离散化。然后发现需要维护每种权值的出现概率,于是可以考虑到一个简单的树形 dp:
设 \(f[i][j]\) 为结点 \(i\) 出现第 \(j\) 大的权值的概率,分类讨论:
-
该结点为叶结点
令该结点的权值出现概率为 \(1\) -
该结点只有一个子结点
直接继承子结点的概率 -
该结点有两个子结点
\(f[i][j] = f[l][j] \cdot [p_i \sum\limits_{k = 1}^{j - 1} f[r][k] + (1 - p_i) \sum\limits_{k = j + 1}^m f[r][k]] + f[r][j] \cdot [p_i \sum\limits_{k = 1}^{j - 1} f[l][k] + (1 - p_i) \sum\limits_{k = j + 1}^m f[l][k]\ ]\)
注意到第三种情况相当于:对于每个结点,在值域上求前缀和 / 后缀和,可以用一棵线段树直接维护。
这个 dp 实际上对于每个结点维护整个值域的所有信息,也就是对于每个结点都维护值域中每个权值的出现概率,于是可以考虑线段树合并优化。
具体地,观察转移方程。发现对于一个权值,实际上相当于在原本的概率基础上乘上一个系数,这个系数是子结点值域左右两侧的概率区间和。这个和可以在线段树合并递归的时候直接维护:
int merge(int a, int b, int l, int r, int mula, int mulb, int p)
{
if ((!a) && (!b)) return 0;
if (!a)
{
push_mul(b, mulb);
return b;
}
if (!b)
{
push_mul(a, mula);
return a;
}
push_down(a), push_down(b);
int mid = (l + r) >> 1;
int lsx = sum[ls[a]], rsx = sum[rs[a]], lsy = sum[ls[b]], rsy = sum[rs[b]];
ls[a] = merge(ls[a], ls[b], l, mid, (mula + rsy * (1 - p + mod) % mod) % mod, (mulb + rsx * (1 - p + mod) % mod) % mod, p);
rs[a] = merge(rs[a], rs[b], mid + 1, r, (mula + lsy * p % mod) % mod, (mulb + lsx * p % mod) % mod, p);
push_up(a);
return a;
}
于是直接上线段树合并就做完了。
时间复杂度约为 \(O(n \log n)\)
代码
#include <cstdio>
#include <algorithm>
using namespace std;
#define int long long
const int maxn = 3e5 + 5;
const int t_sz = maxn * 40;
const int mod = 998244353;
int n, cnt;
int m, seq[maxn];
int rt[maxn], resp[maxn];
int fa[maxn], son[maxn][2], deg[maxn], p[maxn];
int ls[t_sz], rs[t_sz], sum[t_sz], mul[t_sz];
inline int read()
{
int res = 0;
char ch = getchar();
while ((ch < '0') || (ch > '9')) ch = getchar();
while ((ch >= '0') && (ch <= '9'))
{
res = res * 10 + ch - '0';
ch = getchar();
}
return res;
}
int qpow(int base, int power)
{
int res = 1;
while (power)
{
if (power & 1) res = res * base % mod;
base = base * base % mod;
power >>= 1;
}
return res;
}
void push_up(int k) { sum[k] = (sum[ls[k]] + sum[rs[k]]) % mod; }
void push_mul(int k, int w)
{
if (!k) return;
sum[k] = sum[k] * w % mod;
mul[k] = mul[k] * w % mod;
}
void push_down(int k)
{
if (mul[k] == 1) return;
if (ls[k]) push_mul(ls[k], mul[k]);
if (rs[k]) push_mul(rs[k], mul[k]);
mul[k] = 1;
}
int new_node()
{
int res = ++cnt;
ls[res] = rs[res] = sum[res] = 0, mul[res] = 1;
return res;
}
void update(int &k, int l, int r, int p, int w)
{
if (!k) k = new_node();
if (l == r)
{
sum[k] = w;
return;
}
push_down(k);
int mid = (l + r) >> 1;
if (p <= mid) update(ls[k], l, mid, p, w);
else update(rs[k], mid + 1, r, p, w);
push_up(k);
}
int merge(int a, int b, int l, int r, int mula, int mulb, int p)
{
if ((!a) && (!b)) return 0;
if (!a)
{
push_mul(b, mulb);
return b;
}
if (!b)
{
push_mul(a, mula);
return a;
}
push_down(a), push_down(b);
int mid = (l + r) >> 1;
int lsx = sum[ls[a]], rsx = sum[rs[a]], lsy = sum[ls[b]], rsy = sum[rs[b]];
ls[a] = merge(ls[a], ls[b], l, mid, (mula + rsy * (1 - p + mod) % mod) % mod, (mulb + rsx * (1 - p + mod) % mod) % mod, p);
rs[a] = merge(rs[a], rs[b], mid + 1, r, (mula + lsy * p % mod) % mod, (mulb + lsx * p % mod) % mod, p);
push_up(a);
return a;
}
void dfs1(int u)
{
if (!deg[u]) update(rt[u], 1, m, p[u], 1);
else if (deg[u] == 1) dfs1(son[u][0]), rt[u] = rt[son[u][0]];
else dfs1(son[u][0]), dfs1(son[u][1]), rt[u] = merge(rt[son[u][0]], rt[son[u][1]], 1, m, 0, 0, p[u]);
}
void dfs2(int u, int l, int r)
{
if (!u) return;
if (l == r) { resp[l] = sum[u]; return; }
push_down(u);
int mid = (l + r) >> 1;
dfs2(ls[u], l, mid);
dfs2(rs[u], mid + 1, r);
}
signed main()
{
n = read();
for (int i = 1; i <= n; i++)
{
fa[i] = read();
if (fa[i]) son[fa[i]][deg[fa[i]]++] = i;
}
// for (int i = 1; i <= n; i++) printf("%d %d\n", son[i][0], son[i][1]);
for (int i = 1; i <= n; i++)
{
p[i] = read();
if (!deg[i]) seq[++m] = p[i];
else p[i] = p[i] * qpow(10000, mod - 2) % mod;
}
sort(seq + 1, seq + m + 1);
for (int i = 1; i <= n; i++)
if (!deg[i]) p[i] = lower_bound(seq + 1, seq + m + 1, p[i]) - seq;
// for (int i = 1; i <= n; i++) printf("%d ", p[i]);
dfs1(1);
dfs2(rt[1], 1, m);
int ans = 0;
// for (int i = 1; i <= m; i++) printf("%d ", resp[i]);
for (int i = 1; i <= m; i++) ans = (ans + i * seq[i] % mod * resp[i] % mod * resp[i] % mod) % mod;
printf("%lld\n", ans % mod);
return 0;
}