Loading

【题解】P5298 [PKUWC2018]Minimax

P5298 [PKUWC2018]Minimax

思路

线段树合并优化树形 dp.

值域 1e9 首先考虑离散化。然后发现需要维护每种权值的出现概率,于是可以考虑到一个简单的树形 dp:

\(f[i][j]\) 为结点 \(i\) 出现第 \(j\) 大的权值的概率,分类讨论:

  1. 该结点为叶结点
    令该结点的权值出现概率为 \(1\)

  2. 该结点只有一个子结点
    直接继承子结点的概率

  3. 该结点有两个子结点
    \(f[i][j] = f[l][j] \cdot [p_i \sum\limits_{k = 1}^{j - 1} f[r][k] + (1 - p_i) \sum\limits_{k = j + 1}^m f[r][k]] + f[r][j] \cdot [p_i \sum\limits_{k = 1}^{j - 1} f[l][k] + (1 - p_i) \sum\limits_{k = j + 1}^m f[l][k]\ ]\)

注意到第三种情况相当于:对于每个结点,在值域上求前缀和 / 后缀和,可以用一棵线段树直接维护。

这个 dp 实际上对于每个结点维护整个值域的所有信息,也就是对于每个结点都维护值域中每个权值的出现概率,于是可以考虑线段树合并优化。

具体地,观察转移方程。发现对于一个权值,实际上相当于在原本的概率基础上乘上一个系数,这个系数是子结点值域左右两侧的概率区间和。这个和可以在线段树合并递归的时候直接维护:

int merge(int a, int b, int l, int r, int mula, int mulb, int p)
{
    if ((!a) && (!b)) return 0;
    if (!a)
    {
        push_mul(b, mulb);
        return b;
    }
    if (!b)
    {
        push_mul(a, mula);
        return a;
    }
    push_down(a), push_down(b);
    int mid = (l + r) >> 1;
    int lsx = sum[ls[a]], rsx = sum[rs[a]], lsy = sum[ls[b]], rsy = sum[rs[b]];
    ls[a] = merge(ls[a], ls[b], l, mid, (mula + rsy * (1 - p + mod) % mod) % mod, (mulb + rsx * (1 - p + mod) % mod) % mod, p);
    rs[a] = merge(rs[a], rs[b], mid + 1, r, (mula + lsy * p % mod) % mod, (mulb + lsx * p % mod) % mod, p);
    push_up(a);
    return a;
}

于是直接上线段树合并就做完了。

时间复杂度约为 \(O(n \log n)\)

代码

#include <cstdio>
#include <algorithm>
using namespace std;

#define int long long

const int maxn = 3e5 + 5;
const int t_sz = maxn * 40;
const int mod = 998244353;

int n, cnt;
int m, seq[maxn];
int rt[maxn], resp[maxn];
int fa[maxn], son[maxn][2], deg[maxn], p[maxn];
int ls[t_sz], rs[t_sz], sum[t_sz], mul[t_sz];

inline int read()
{
    int res = 0;
    char ch = getchar();
    while ((ch < '0') || (ch > '9')) ch = getchar();
    while ((ch >= '0') && (ch <= '9'))
    {
        res = res * 10 + ch - '0';
        ch = getchar();
    }
    return res;
}

int qpow(int base, int power)
{
    int res = 1;
    while (power)
    {
        if (power & 1) res = res * base % mod;
        base = base * base % mod;
        power >>= 1;
    }
    return res;
}

void push_up(int k) { sum[k] = (sum[ls[k]] + sum[rs[k]]) % mod; }

void push_mul(int k, int w)
{
    if (!k) return;
    sum[k] = sum[k] * w % mod;
    mul[k] = mul[k] * w % mod;
}

void push_down(int k)
{
    if (mul[k] == 1) return;
    if (ls[k]) push_mul(ls[k], mul[k]);
    if (rs[k]) push_mul(rs[k], mul[k]);
    mul[k] = 1;
}

int new_node()
{
    int res = ++cnt;
    ls[res] = rs[res] = sum[res] = 0, mul[res] = 1;
    return res;
}

void update(int &k, int l, int r, int p, int w)
{
    if (!k) k = new_node();
    if (l == r)
    {
        sum[k] = w;
        return;
    }
    push_down(k);
    int mid = (l + r) >> 1;
    if (p <= mid) update(ls[k], l, mid, p, w);
    else update(rs[k], mid + 1, r, p, w);
    push_up(k);
}

int merge(int a, int b, int l, int r, int mula, int mulb, int p)
{
    if ((!a) && (!b)) return 0;
    if (!a)
    {
        push_mul(b, mulb);
        return b;
    }
    if (!b)
    {
        push_mul(a, mula);
        return a;
    }
    push_down(a), push_down(b);
    int mid = (l + r) >> 1;
    int lsx = sum[ls[a]], rsx = sum[rs[a]], lsy = sum[ls[b]], rsy = sum[rs[b]];
    ls[a] = merge(ls[a], ls[b], l, mid, (mula + rsy * (1 - p + mod) % mod) % mod, (mulb + rsx * (1 - p + mod) % mod) % mod, p);
    rs[a] = merge(rs[a], rs[b], mid + 1, r, (mula + lsy * p % mod) % mod, (mulb + lsx * p % mod) % mod, p);
    push_up(a);
    return a;
}

void dfs1(int u)
{
    if (!deg[u]) update(rt[u], 1, m, p[u], 1);
    else if (deg[u] == 1) dfs1(son[u][0]), rt[u] = rt[son[u][0]];
    else dfs1(son[u][0]), dfs1(son[u][1]), rt[u] = merge(rt[son[u][0]], rt[son[u][1]], 1, m, 0, 0, p[u]);
}

void dfs2(int u, int l, int r)
{
    if (!u) return;
    if (l == r) { resp[l] = sum[u]; return; }
    push_down(u);
    int mid = (l + r) >> 1;
    dfs2(ls[u], l, mid);
    dfs2(rs[u], mid + 1, r);
}

signed main()
{
    n = read();
    for (int i = 1; i <= n; i++)
    {
        fa[i] = read();
        if (fa[i]) son[fa[i]][deg[fa[i]]++] = i;
    }
    // for (int i = 1; i <= n; i++) printf("%d %d\n", son[i][0], son[i][1]);
    for (int i = 1; i <= n; i++)
    {
        p[i] = read();
        if (!deg[i]) seq[++m] = p[i];
        else p[i] = p[i] * qpow(10000, mod - 2) % mod;
    }
    sort(seq + 1, seq + m + 1);
    for (int i = 1; i <= n; i++)
        if (!deg[i]) p[i] = lower_bound(seq + 1, seq + m + 1, p[i]) - seq;
    // for (int i = 1; i <= n; i++) printf("%d ", p[i]);
    dfs1(1);
    dfs2(rt[1], 1, m);
    int ans = 0;
    // for (int i = 1; i <= m; i++) printf("%d ", resp[i]);
    for (int i = 1; i <= m; i++) ans = (ans + i * seq[i] % mod * resp[i] % mod * resp[i] % mod) % mod;
    printf("%lld\n", ans % mod);
    return 0;
}
posted @ 2022-12-28 04:23  kymru  阅读(60)  评论(0编辑  收藏  举报