【题解】CF494C Helping People
思路
ST 表 + 树形 dp.
首先注意到 \(q \leq 5 \times 10^3\) 并且区间互不相交,这实际上给出了一个很强的结论:
考虑把每个区间向它的子区间连边,得到的一定是一个森林。
只要我们钦定存在一个概率为 \(0\) 的操作 \([1, n]\),那么最终得到的一定是一棵树。
构造这棵树可以考虑使用栈。
首先将所有区间优先按左端点升序排列,左端点相同则按右端点降序排列。此时任意区间的子区间均在其后。
使用一个栈存下当前区间的所有父区间。每右移一位就将所有右端点在当前区间左端点左侧的区间移除。
因为左端点升序,所以该区间一定是当前区间最小的父区间。
然后将当前区间入栈。
建树时间复杂度 \(O(n \log n)\)
考虑到 \(E(\max) \neq \max(E)\),直接将期望设进状态是不好做的。
只能考虑使用期望的定义,考虑每种最大值出现的概率。
钦定最大值求概率依旧很困难,考虑弱化成钦定最大值不超过某定值,这样最后的答案容斥一下就行。
于是考虑令 \(f[u][i]\) 为结点 \(u\) 子树中的操作完成后,该结点对应区间的最大值不超过 \(i\) 的概率。
转移容易推导:\(f[u][i] = p_u \prod\limits_{v \in son(u)} f[v][i - 1] + (1 - p_u) \prod\limits_{v \in son(u)} f[v][i]\).
但是第二维的值域是 \([1, 10^9]\),必须优化一下 /fn
注意到每次操作至多使最大值增加 \(1\),也就是 \(q\) 次操作后最大值的增量不超过 \(q\),于是令第二维变成在原有区间最大值基础上的增量就行。
令 \(w(u)\) 为结点 \(u\) 对应区间的初始最大值,有转移:
\(f[u][i] = p_u \prod\limits_{v \in son(u)} f[v][i + w(u) - w(v) - 1] + (1 - p_u) \prod\limits_{v \in son(u)} f[v][i + w(u) - w(v)]\).
特别地,当 \(i = 0\) 时只取后一半转移。
答案为 \(f[1][0] \cdot w(1) + \sum\limits_{i = 1}^q (f[1][i] - f[1][i - 1]) (w(1) + i)\).
于是可以在 \(O(n \log n + q^2)\) 的复杂度内解决此题。
代码
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
typedef double db;
const int maxn = 1e5 + 5;
const int maxq = 5e3 + 5;
const int lg_sz = 20;
struct item
{
int l, r, mxv; db p;
bool operator < (const item& rhs) const
{
if (l != rhs.l) return (l < rhs.l);
return (r > rhs.r);
}
} rg[maxq];
inline int max(const int &a, const int &b) { return (a >= b ? a : b); }
int n, q;
int a[maxn];
int top, stk[maxq];
db f[maxq][maxq];
vector<int> g[maxq];
namespace ST
{
int lg[maxn];
int f[maxn][lg_sz];
void build(int *a, int n)
{
lg[0] = lg[1] = 0;
for (int i = 2; i <= n; i++) lg[i] = lg[i >> 1] + 1;
for (int i = 1; i <= n; i++) f[i][0] = a[i];
for (int j = 1; (1 << j) <= n; j++)
for (int i = 1; i + (1 << j) - 1 <= n; i++)
f[i][j] = max(f[i][j - 1], f[i + (1 << (j - 1))][j - 1]);
}
int query(int l, int r)
{
int k = lg[r - l + 1];
return max(f[l][k], f[r - (1 << k) + 1][k]);
}
}
void dfs(int u)
{
// printf("dfs %d\n", u);
for (int v : g[u]) dfs(v);
for (int i = 0; i <= q; i++)
{
db curp = 1.0 - rg[u].p;
for (int v : g[u]) curp *= f[v][min(q, i + rg[u].mxv - rg[v].mxv)];
f[u][i] += curp;
}
for (int i = 1; i <= q; i++)
{
db curp = rg[u].p;
for (int v : g[u]) curp *= f[v][min(q, i + rg[u].mxv - rg[v].mxv - 1)];
f[u][i] += curp;
}
}
int main()
{
scanf("%d%d", &n, &q);
for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
ST::build(a, n);
for (int i = 1; i <= q; i++)
{
scanf("%d%d%lf", &rg[i].l, &rg[i].r, &rg[i].p);
rg[i].mxv = ST::query(rg[i].l, rg[i].r);
}
rg[++q] = (item){1, n, ST::query(1, n), 0};
sort(rg + 1, rg + q + 1);
stk[top = 1] = 1;
for (int i = 2; i <= q; i++)
{
while (top && (rg[stk[top]].r < rg[i].l)) top--;
g[stk[top]].push_back(i);
stk[++top] = i;
}
dfs(1);
db ans = 0;
// for (int i = 0; i <= q; i++, puts(""))
// for (int j = 0; j <= q; j++)
// printf("debug f[%d][%d] = %.10lf\n", i, j, f[i][j]);
for (int i = 0; i <= q; i++) ans += (f[1][i] - (i ? f[1][i - 1] : 0)) * (rg[1].mxv + i);
printf("%.10lf\n", ans);
return 0;
}