Luogu 3267 [JLOI2016/SHOI2016]侦察守卫

以后要记得复习鸭

BZOJ 4557

大佬的博客

状态十分好想,设$f_{x, i}$表示以覆盖完$x$为根的子树后还能向上覆盖$i$层的最小代价,$g_{x, i}$表示以$x$为根的子树下深度为$i$还没有被覆盖的最小代价。

那么对于每一个关键点,有初态:   $f_{x,0} = g_{x, 0} = val_x$。

对于不是关键点的点,有:$f_{x, i} = val_x$   $0 \leq i \leq d$   $f_{x, d + 1} = inf$。

然后就不会了

感觉关键是维护$f_x$和$g_x$的单调性,因为$f_{x, i}$一定不会大于$f_{x, i + 1}$,而$g_{x, i}$一定不会大于$g_{x, i - 1}$。

我们考虑对$x$的所有儿子$y$分开考虑,对于每一个$y$,有方程

    $f_{x, i} = min(f_{x, i} + g_{y, i}, g_{x, i + 1} + f_{y, i + 1})$。

前者表示在之前的计算中已经取了一个子结点能覆盖到$x$,现在这个结点只需要取$g_{y, i}$即可,后者表示取这个结点的$f_{y, i + 1}$,而剩下的代价最小为$g_{x, i + 1}$。

记得从上到下枚举$i$。

然后再扫一遍维护$f_{x}$的单调性。

接着考虑计算$g$,有$g_{x, 0} = f_{x, 0}$。

对于每一个$y$,有$g_{x, i} += g_{y, i - 1}$。

然后从下到上扫一遍维护$g_x$的单调性。

时间复杂度$O(nd)$。

Code:

#include <cstdio>
#include <cstring>
using namespace std;

const int N = 5e5 + 5;
const int M = 22;
const int inf = 1 << 30;

int n, m, d, a[N], tot = 0, head[N], f[N][M], g[N][M]; 
bool flag[N];

struct Edge {
    int to, nxt;
} e[N << 1];

inline void add(int from, int to) {
    e[++tot].to = to;
    e[tot].nxt = head[from];
    head[from] = tot;
}

inline void read(int &X) {
    X = 0; char ch = 0; int op = 1;
    for(; ch > '9' || ch < '0'; ch = getchar())
        if(ch == '-') op = -1;
    for(; ch >= '0' && ch <= '9'; ch = getchar())
        X = (X << 3) + (X << 1) + ch - 48;
    X *= op;
}

inline int min(int x, int y) {
    return x > y ? y : x;
}

inline void chkMin(int &x, int y) {
    if(y < x) x = y;
}

void dfs(int x, int fat) {
    if(flag[x]) g[x][0] = f[x][0] = a[x];
    for(int i = 1; i <= d; i++) f[x][i] = a[x];
    f[x][d + 1] = inf;

    for(int i = head[x]; i; i = e[i].nxt) {
        int y = e[i].to;
        if(y == fat) continue;

        dfs(y, x);

        for(int j = d; j >= 0; j--) 
            f[x][j] = min(f[x][j] + g[y][j], g[x][j + 1] + f[y][j + 1]);
        for(int j = d; j >= 0; j--) chkMin(f[x][j], f[x][j + 1]);

        g[x][0] = f[x][0];
        for(int j = 1; j <= d; j++) g[x][j] += g[y][j - 1];
        for(int j = 1; j <= d; j++) chkMin(g[x][j], g[x][j - 1]);
    }
}

int main() {
    read(n), read(d);
    for(int i = 1; i <= n; i++) read(a[i]);
    read(m);
    for(int x, i = 1; i <= m; i++) {
        read(x);
        flag[x] = 1;
    }
    for(int x, y, i = 1; i < n; i++) {
        read(x), read(y);
        add(x, y), add(y, x);
    }

    dfs(1, 0);

    printf("%d\n", f[1][0]);
    return 0;
}
View Code

 

posted @ 2018-10-15 21:02  CzxingcHen  阅读(184)  评论(0编辑  收藏  举报