洛谷 P3267 [JLOI2016/SHOI2016]侦察守卫(树形dp)

题面

luogu

题解

树形\(dp\)

\(f[x][y]\) 表示 \(x\) 的子树中 \(y\) 层以下的所有点都已经覆盖完、上面的 \(y\) 层还需要覆盖时的最小代价。
\(g[x][y]\) 表示 \(x\) 子树中所有点都已经覆盖完,并且 \(x\) 还能向上覆盖 \(y\) 层的最小代价。

对于边 \(u \to v\)(\(u\) 是 \(v\) 的父亲):
\(g[u][j] = min(g[u][j]+f[v][j], g[v][j+1]+f[u][j+1])\)
\(f[u][j] = \sum_{v \in son(u)} f[v][j-1]\)

\(g[u][j] = min(g[u][j], g[u][j+1])\)
\(f[u][j] = min(f[u][j], f[u][j-1])\)

Code

#include<bits/stdc++.h>

#define LL long long
#define RG register

using namespace std;
/**
 * Reads one optionally negative decimal integer from stdin into x.
 *
 * Skips every character that is neither a digit nor '-', then consumes the
 * digit run (plus the single character that terminates it). If the input
 * ends before any digit is found, x is left at 0 and the function returns
 * instead of spinning forever.
 *
 * @param x destination; overwritten with the parsed value (0 if none).
 */
template<class T> inline void read(T &x) {
    x = 0;
    // getchar() returns int so that EOF is distinguishable from every byte.
    int c = getchar();
    bool f = 0;
    while (c != EOF && c != '-' && (c < '0' || c > '9')) c = getchar();
    if (c == '-') c = getchar(), f = 1;
    // Group (c - 48) so the digit value is added, not the raw character code.
    while (c >= '0' && c <= '9') x = x*10+(c-48), c = getchar();
    x = f ? -x : x;
    return ;
}
/**
 * Writes the decimal representation of the integer x to stdout.
 *
 * Digits are taken from the (possibly negative) remainders directly, so the
 * value is never negated and the most negative value of a signed type is
 * printed correctly.
 *
 * @param x integer value to print; no trailing newline is emitted.
 */
template<class T> inline void write(T x) {
    if (!x) {putchar(48);return ;}
    const bool negative = x < 0;
    if (negative) putchar('-');
    // 40 slots hold the digits of any integer type up to 128 bits wide.
    char digits[40];
    int len = 0;
    while (x != 0) {
        const int digit = static_cast<int>(x % 10);
        digits[len++] = static_cast<char>(48 + (digit < 0 ? -digit : digit));
        x /= 10;
    }
    // Digits were collected least-significant first; emit them in reverse.
    while (len > 0) putchar(digits[--len]);
    return ;
}
// Capacity of every per-node array (nodes are indexed from 1).
const int N = 500010;
// Sentinel cost used for states that cannot be reached.
const int INF = 1000000000;
// Number of nodes, as read from the input.
int n;
// Coverage radius: the DP tables are indexed by layer 0..d+1.
int d;
// w[i]: cost attached to node i, as read from the input.
int w[N];
// One directed edge of the adjacency list (linked forward-list layout).
// Member order matters: edges are initialised positionally as {to, next}.
struct node {
    int to;    // endpoint this edge leads to
    int next;  // index in G of the next edge leaving the same node, 0 = none
};
// Edge pool; sized 2*N because each undirected edge is stored in both directions.
node G[N<<1];
// last[x]: index in G of the most recently added edge leaving x (0 = none).
int last[N];
// Number of edges stored in G so far; edge indices start at 1.
int gl;
// vis[x]: true for the nodes listed in the input after the node costs;
// dfs() charges w[x] to the base state of every such node.
bool vis[N];
/**
 * Appends the directed edge x -> y to the adjacency list.
 *
 * The new edge becomes the head of x's list, so edges are later visited in
 * reverse insertion order.
 *
 * @param x source node
 * @param y destination node
 */
void add(int x, int y) {
    // Plain member assignment instead of a compound literal, which is a
    // compiler extension in C++ rather than standard syntax.
    ++gl;
    G[gl].to = y;
    G[gl].next = last[x];
    last[x] = gl;
}
// f[x][y]: minimum cost such that every node of x's subtree below the top y
//          layers is covered, while the top y layers may still need coverage.
int f[N][22];
// g[x][y]: minimum cost such that x's whole subtree is covered and x can
//          still cover y more layers upward.
// The second dimension holds layers 0..d+1, so d must not exceed 20.
int g[N][22];
/**
 * Fills f[x][*] and g[x][*] for every node x in the subtree of u.
 *
 * The traversal is iterative: the subtree is first flattened into a list in
 * which each parent precedes its children, then the list is processed
 * backwards so that all children are finished before their parent. This
 * keeps the call stack flat; a recursive walk would need one frame per level
 * and could exhaust the stack on a path-shaped tree of up to N nodes.
 *
 * Relies on f and g starting out zero-initialised (they are file-scope
 * arrays) and on the adjacency list describing a tree.
 *
 * @param u  root of the subtree to process
 * @param fa parent of u (a value that is not a node id, e.g. 0, for the root)
 */
void dfs(int u, int fa) {
    // Phase 1: list (node, parent) pairs so each parent precedes its children.
    std::vector<std::pair<int, int> > order;
    order.push_back(std::make_pair(u, fa));
    for (std::size_t head = 0; head < order.size(); ++head) {
        // Copy before pushing: push_back may reallocate the vector.
        const int x = order[head].first;
        const int p = order[head].second;
        for (int e = last[x]; e; e = G[e].next) {
            const int v = G[e].to;
            if (v != p) order.push_back(std::make_pair(v, x));
        }
    }

    // Phase 2: process nodes children-first.
    for (std::size_t idx = order.size(); idx-- > 0; ) {
        const int x = order[idx].first;
        const int p = order[idx].second;

        // Base state of x before any child is merged.
        if (vis[x]) f[x][0] = g[x][0] = w[x];
        for (int j = 1; j <= d; j++) g[x][j] = w[x];
        g[x][d+1] = INF;  // reaching further than d layers is impossible

        // Merge the children one at a time, in adjacency-list order.
        for (int e = last[x]; e; e = G[e].next) {
            const int v = G[e].to;
            if (v == p) continue;
            // Either x's side keeps reach j and v only needs its part below
            // layer j covered, or v's side reaches j+1 and covers x's top layers.
            for (int j = 0; j <= d; j++)
                g[x][j] = std::min(g[x][j]+f[v][j], f[x][j+1]+g[v][j+1]);
            // A larger upward reach also satisfies any smaller requirement.
            for (int j = d; j >= 0; j--) g[x][j] = std::min(g[x][j], g[x][j+1]);
            f[x][0] = g[x][0];
            // Leaving x's top j layers open leaves v's top j-1 layers open.
            for (int j = 1; j <= d; j++) f[x][j] += f[v][j-1];
            // Covering more than required is always allowed.
            for (int j = 1; j <= d; j++) f[x][j] = std::min(f[x][j], f[x][j-1]);
        }
    }
    return ;
}

/**
 * Reads the instance from stdin, runs the tree DP from node 1 and prints
 * g[1][0].
 *
 * Input order: n, d, the n node costs, a count m followed by m node ids
 * (flagged in vis), then n-1 undirected edges.
 */
int main() {
    read(n);
    read(d);
    for (int id = 1; id <= n; ++id) read(w[id]);

    int marked_count;
    read(marked_count);
    while (marked_count-- > 0) {
        int node_id;
        read(node_id);
        vis[node_id] = true;
    }

    // A tree on n nodes has n-1 edges; store each one in both directions.
    for (int remaining = n - 1; remaining > 0; --remaining) {
        int a, b;
        read(a);
        read(b);
        add(a, b);
        add(b, a);
    }

    dfs(1, 0);
    printf("%d\n", g[1][0]);
    return 0;
}

posted @ 2019-01-09 09:29  zzy2005  阅读(151)  评论(0)  编辑  收藏  举报