Loading

【题解】CF494C Helping People

思路

ST 表 + 树形 dp.

首先注意到 \(q \leq 5 \times 10^3\) 并且区间互不相交,这实际上给出了一个很强的结论:

考虑把每个区间向它的子区间连边,得到的一定是一个森林。

只要我们钦定存在一个概率为 \(0\) 的操作 \([1, n]\),那么最终得到的一定是一棵树。

构造这棵树可以考虑使用栈。

首先将所有区间优先按左端点升序排列,左端点相同则按右端点降序排列。此时任意区间的子区间均在其后。

使用一个栈存下当前区间的所有父区间。每右移一位就将所有右端点在当前区间左端点左侧的区间移除。

因为左端点升序,所以该区间一定是当前区间最小的父区间。

然后将当前区间入栈。

建树时间复杂度 \(O(n \log n)\)


考虑到 \(E(\max) \neq \max(E)\),直接将期望设进状态是不好做的。

只能考虑使用期望的定义,考虑每种最大值出现的概率。

钦定最大值求概率依旧很困难,考虑弱化成钦定最大值不超过某定值,这样最后的答案容斥一下就行。

于是考虑令 \(f[u][i]\) 为结点 \(u\) 子树中的操作完成后,该结点对应区间的最大值不超过 \(i\) 的概率。

转移容易推导:\(f[u][i] = p_u \prod\limits_{v \in son(u)} f[v][i - 1] + (1 - p_u) \prod\limits_{v \in son(u)} f[v][i]\).

但是第二维的值域是 \([1, 10^9]\),必须优化一下 /fn

注意到每次操作至多使最大值增加 \(1\),也就是 \(q\) 次操作后最大值的增量不超过 \(q\),于是令第二维变成在原有区间最大值基础上的增量就行。

\(w(u)\) 为结点 \(u\) 对应区间的初始最大值,有转移:

\(f[u][i] = p_u \prod\limits_{v \in son(u)} f[v][i + w(u) - w(v) - 1] + (1 - p_u) \prod\limits_{v \in son(u)} f[v][i + w(u) - w(v)]\).

特别地,当 \(i = 0\) 时只取后一半转移。

答案为 \(f[1][0] \cdot w(1) + \sum\limits_{i = 1}^q (f[1][i] - f[1][i - 1]) (w(1) + i)\).

于是可以在 \(O(n \log n + q^2)\) 的复杂度内解决此题。

代码

#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;

typedef double db;

const int maxn = 1e5 + 5;
const int maxq = 5e3 + 5;
const int lg_sz = 20;

struct item
{
    int l, r, mxv; db p;

    bool operator < (const item& rhs) const
    {
        if (l != rhs.l) return (l < rhs.l);
        return (r > rhs.r);
    }
} rg[maxq];

inline int max(const int &a, const int &b) { return (a >= b ? a : b); }

int n, q;
int a[maxn];
int top, stk[maxq];
db f[maxq][maxq];
vector<int> g[maxq];

namespace ST
{
    int lg[maxn];
    int f[maxn][lg_sz];

    void build(int *a, int n)
    {
        lg[0] = lg[1] = 0;
        for (int i = 2; i <= n; i++) lg[i] = lg[i >> 1] + 1;
        for (int i = 1; i <= n; i++) f[i][0] = a[i];
        for (int j = 1; (1 << j) <= n; j++)
            for (int i = 1; i + (1 << j) - 1 <= n; i++)
                f[i][j] = max(f[i][j - 1], f[i + (1 << (j - 1))][j - 1]);
    }

    int query(int l, int r)
    {
        int k = lg[r - l + 1];
        return max(f[l][k], f[r - (1 << k) + 1][k]);
    }
}

void dfs(int u)
{
    // printf("dfs %d\n", u);
    for (int v : g[u]) dfs(v);
    for (int i = 0; i <= q; i++)
    {
        db curp = 1.0 - rg[u].p;
        for (int v : g[u]) curp *= f[v][min(q, i + rg[u].mxv - rg[v].mxv)];
        f[u][i] += curp;
    }
    for (int i = 1; i <= q; i++)
    {
        db curp = rg[u].p;
        for (int v : g[u]) curp *= f[v][min(q, i + rg[u].mxv - rg[v].mxv - 1)];
        f[u][i] += curp;
    }
}

int main()
{
    scanf("%d%d", &n, &q);
    for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
    ST::build(a, n);
    for (int i = 1; i <= q; i++)
    {
        scanf("%d%d%lf", &rg[i].l, &rg[i].r, &rg[i].p);
        rg[i].mxv = ST::query(rg[i].l, rg[i].r);
    }
    rg[++q] = (item){1, n, ST::query(1, n), 0};
    sort(rg + 1, rg + q + 1);
    stk[top = 1] = 1;
    for (int i = 2; i <= q; i++)
    {
        while (top && (rg[stk[top]].r < rg[i].l)) top--;
        g[stk[top]].push_back(i);
        stk[++top] = i;
    }
    dfs(1);
    db ans = 0;
    // for (int i = 0; i <= q; i++, puts(""))
    //     for (int j = 0; j <= q; j++)
    //         printf("debug f[%d][%d] = %.10lf\n", i, j, f[i][j]);
    for (int i = 0; i <= q; i++) ans += (f[1][i] - (i ? f[1][i - 1] : 0)) * (rg[1].mxv + i);
    printf("%.10lf\n", ans);
    return 0;
}
posted @ 2023-03-02 21:31  kymru  阅读(23)  评论(0编辑  收藏  举报