Description

有一些高度为 h 的树在数轴上。每次选择剩下的树中最左边或是最右边的树推倒(各 50% 概率),往左倒有 p 的概率,往右倒 1-p。

一棵树倒了,如果挨到的另一棵树与该数的距离严格小于h,那么它也会往同方向倒。

问所有树都被推倒后的期望覆盖长度?

要注意的一点是每棵树占的是一个点,相邻点之间只有一段距离!

Solution

我们定义 \(\mathtt{f[x][y][l][r]}\) 为:边界为 x 与 y 的推倒期望长度。

关于 l 与 r 则有:

  • \(\mathtt{l==0}\):x 向左推倒没有限制。(所谓限制就是 \(\mathtt{x-1}\) 那一位向右推倒并与 x 向左推倒有相交的区域)
  • \(\mathtt{l==1}\):x 向左推倒有限制。
  • \(\mathtt{r==0}\):y 向右推倒有限制。
  • \(\mathtt{r==1}\):y 向右推倒没有限制。

我们再记录 \(\mathtt{L[i]}\)\(\mathtt{R[i]}\) 分别表示 i 向左推倒能殃及到第几号树,向右推倒能殃及第几号树,这个 \(\mathtt{O(n)}\) 就可实现。

然后记忆化搜索就行了。这个时间复杂度应该是 \(\mathtt{O(n^2)}\) 的?(左边每向右推进一棵树,右边粗略有 \(\mathtt{O(n)}\) 的树可枚举)

Code

#include <cstdio>
#include <iostream>
#include <algorithm>
using namespace std;

const int N = 2005;

int n, h, pos[N], L[N], R[N];
double f[N][N][2][2], p;

int read() {
    int x = 0, f = 1; char s;
    while((s = getchar()) < '0' || s > '9') if(s == '-') f = -1;
    while(s >= '0' && s <= '9') {x = (x << 1) + (x << 3) + (s ^ 48); s = getchar();}
    return x * f;
}

int check(const int op, const int x, const int dir) {
    if(! op) {
        if(! dir) return min(pos[x] - pos[x - 1], h);
        return min(max(pos[x] - pos[x - 1] - h, 0), h);
    }
    else {
        if(! dir) return min(max(pos[x + 1] - pos[x] - h, 0), h);
        return min(pos[x + 1] - pos[x], h);
    }
}

double dfs(const int x, const int y, const int l, const int r) {
    if(x > y) return 0;
    if(f[x][y][l][r]) return f[x][y][l][r];
    double &ans = f[x][y][l][r];
    ans += 0.5 * p * (dfs(x + 1, y, 0, r) + check(0, x, l));
    if(R[x] + 1 <= y) ans += 0.5 * (1 - p) * (dfs(R[x] + 1, y, 1, r) + pos[R[x]] - pos[x] + h);
    else ans += 0.5 * (1 - p) * (pos[y] - pos[x] + check(1, y, r));
    ans += 0.5 * (1 - p) * (dfs(x, y - 1, l, 1) + check(1, y, r));
    if(x <= L[y] - 1) ans += 0.5 * p * (dfs(x, L[y] - 1, l, 0) + pos[y] - pos[L[y]] + h);
    else ans += 0.5 * p * (pos[y] - pos[x] + check(0, x, l));
    return ans;
}

int main() {
    n = read(), h = read(), scanf("%lf", &p);
    for(int i = 1; i <= n; ++ i) pos[i] = read();
    sort(pos + 1, pos + n + 1);
    pos[0] = pos[1] - h; pos[n + 1] = pos[n] + h;
    L[1] = 1; R[n] = n;
    for(int i = 2; i <= n; ++ i)
        if(pos[i] - pos[i - 1] < h) L[i] = L[i - 1];
        else L[i] = i;
    for(int i = n - 1; i >= 1; -- i)
        if(pos[i + 1] - pos[i] < h) R[i] = R[i + 1];
        else R[i] = i;
    printf("%.10f\n", dfs(1, n, 0, 1));
    return 0;
}
posted on 2020-04-12 15:48  Oxide  阅读(124)  评论(0编辑  收藏  举报