Loading

【题解】P2605 [ZJOI2010]基站选址

题意

在数轴的正半轴上有 \(N\) 个坐标为整数的点,第 \(i\) 个点的位置为 \(D_i\)。当且仅当在 \([D_i - S_i, D_i + S_i]\) 中存在一个被选中的点时,称点 \(i\) 被覆盖。现在可以选中 \(K\) 个点,选中第 \(i\) 个点的代价为 \(C_i\)。对于点 \(i\),若其未被覆盖,则需要额外付出 \(W_i\) 的代价。试求代价之和的最小值。

\(N \leq 2 \times 10^4, K \leq \min(10^2, N), D_i, S_i \leq 10^9, C_i, W_i \leq 10^4\)

思路

线段树优化 dp。

显然可以考虑一个朴素的 dp。如果设 \(f[i][j]\) 表示前 \(i\) 个点中选择 \(j\) 个点的最小代价和,则需要枚举最后一个被选中的点,时间复杂度不优。考虑设 \(f[i][j]\) 表示前 \(i\) 个点中,选择 \(j\) 个点且选择点 \(i\) 的最小代价和。

简单地,有状态转移方程:

\(f[i][j] = \min[f[k][j - 1] + cost(k, i), k \in [1, i - 1]\)

其中 \(cost(k, i)\) 表示 \([k + 1, i - 1]\) 中无法被覆盖的点的代价和。

暴力求 \(cost(k, i)\) 的复杂度为 \(O(N)\),总复杂度为 \(O(N^2K)\),于是考虑快速求 \(cost(k, i)\)

容易发现,对于每一个点 \(x\),必然都存在且仅存在一个连续的区间 \([l_x, r_x]\),使得从该区间中任意选择一个点都可以覆盖点 \(x\)。考虑二分预处理出该区间。

不妨考虑 \(i\) 向右移动 \(1\) 后的影响。若影响转移,由于 \(f[k][j - 1]\) 为定值,因此必然为 \(cost(k, i)\) 改变。对于 \([k + 1, i - 1]\) 中的任意一点 \(p\),如果 \(k\) 可以覆盖 \(p\),那么无论 \(i\)\([i, n]\) 中取何值,\(p\) 都必然不会对 \(cost(k, i)\) 产生贡献。因此变化只有可能为 \(i\) 向右移动后,原本 \([k + 1, i - 1]\) 只有 \(i\) 可以覆盖的某些点无法被覆盖。

倒推回去,如果 \([k + 1, i - 1]\) 中存在只能被 \(i\) 覆盖而无法被 \(i + 1\) 覆盖的点,令其为 \(x\),则必然有 \(k < l_x, r_x = i\)。因此 \(i\) 向右移动时,我们只需要考虑每一个 \(r_x = i\) 的点 \(x\),然后令从 \([f[1][j - 1], f[l_x - 1][j - 1]]\) 转移的代价增加 \(w_x\),代表此时点 \(x\) 无法被覆盖即可。

我们发现上面的操作实质上是区间加法,区间求最小值,于是考虑用线段树维护。具体地,用线段树维护 \(f[k][j - 1] + cost(k, i)\) 的区间最小值。预处理出 \(f[1][i]\) 的答案后,从 \(j = 2\) 开始枚举,每次建树时令 \(x\) 的权值为 \(f[x][j - 1]\),然后按照上面的过程操作。

为方便地求出答案,可以考虑在 \(+\infty\) 位置虚拟一个点 \(x\),不选中该点需要的代价为 \(+\infty\),选中该点的代价和该点可以覆盖的范围均为 \(0\)。令它处于下标 \(n + 1\),那么 \(\min(f[n + 1][k]), k \in [1, K]\) 即为答案。

时间复杂度 \(O(NK \log N)\)

代码

#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;

const int maxn = 2e4 + 5;
const int maxk = 1e2 + 5;
const int inf = 0x3f3f3f3f;

struct node
{
    int l, r, val, lazy;
} tree[maxn << 2];

int n, k;
int st[maxn], ed[maxn], f[maxn];
int d[maxn], c[maxn], s[maxn], w[maxn];
vector<int> pos[maxn];

inline int read()
{
    int res = 0, flag = 1;
    char ch = getchar();
    while ((ch < '0') || (ch > '9'))
    {
        if (ch == '-') flag = -1;
        ch = getchar();
    }
    while ((ch >= '0') && (ch <= '9'))
    {
        res = res * 10 + ch - '0';
        ch = getchar();
    }
    return res * flag;
}

inline int min(int a, int b) { return (a <= b ? a : b); }

inline void push_up(int k) { tree[k].val = min(tree[k << 1].val, tree[k << 1 | 1].val); }

inline void push_down(int k)
{
    tree[k << 1].val += tree[k].lazy;
    tree[k << 1 | 1].val += tree[k].lazy;
    tree[k << 1].lazy += tree[k].lazy;
    tree[k << 1 | 1].lazy += tree[k].lazy;
    tree[k].lazy = 0;
}

inline void build(int k, int l, int r)
{
    tree[k].l = l;
    tree[k].r = r;
    tree[k].lazy = 0;
    if (l == r)
    {
        tree[k].val = f[l];
        return;
    }
    int mid = (l + r) >> 1;
    build(k << 1, l, mid);
    build(k << 1 | 1, mid + 1, r);
    push_up(k);
}

inline void update(int k, int l, int r, int w)
{
    if ((tree[k].l >= l) && (tree[k].r <= r))
    {
        tree[k].val += w;
        tree[k].lazy += w;
        return;
    }
    push_down(k);
    int mid = (tree[k].l + tree[k].r) >> 1;
    if (l <= mid) update(k << 1, l, r, w);
    if (r > mid) update(k << 1 | 1, l, r, w);
    push_up(k);
}

inline int query(int k, int l, int r)
{
    if ((tree[k].l >= l) && (tree[k].r <= r)) return tree[k].val;
    push_down(k);
    int mid = (tree[k].l + tree[k].r) >> 1, res = inf;
    if (l <= mid) res = min(res, query(k << 1, l, r));
    if (r > mid) res = min(res, query(k << 1 | 1, l, r));
    return res;
}

int main()
{
    int ans = inf;
    n = read(), k = read();
    for (int i = 2; i <= n; i++) d[i] = read();
    for (int i = 1; i <= n; i++) c[i] = read();
    for (int i = 1; i <= n; i++) s[i] = read();
    for (int i = 1; i <= n; i++) w[i] = read();
    n++, k++;
    d[n] = w[n] = inf;
    for (int i = 1; i <= n; i++)
    {
        st[i] = lower_bound(d + 1, d + n + 1, d[i] - s[i]) - d;
        ed[i] = upper_bound(d + 1, d + n + 1, d[i] + s[i]) - d - 1;
        pos[ed[i]].push_back(i);
    }
    int cur = 0;
    for (int i = 1; i <= n; i++)
    {
        f[i] = c[i] + cur;
        for (int p : pos[i]) cur += w[p];
    }
    for (int j = 2; j <= k; j++)
    {
        build(1, 1, n);
        for (int i = 1; i <= n; i++)
        {
            if (i > 1) f[i] = query(1, 1, i - 1) + c[i];
            for (int p : pos[i])
                if (st[p] > 1) update(1, 1, st[p] - 1, w[p]);
        }
        ans = min(ans, f[n]);
    }
    printf("%d\n", ans);
    return 0;
}
posted @ 2022-08-09 23:03  kymru  阅读(40)  评论(0编辑  收藏  举报