Luogu 3267 [JLOI2016/SHOI2016]侦察守卫
以后要记得复习鸭
BZOJ 4557
状态十分好想,设$f_{x, i}$表示以覆盖完$x$为根的子树后还能向上覆盖$i$层的最小代价,$g_{x, i}$表示以$x$为根的子树下深度为$i$还没有被覆盖的最小代价。
那么对于每一个关键点,有初态: $f_{x,0} = g_{x, 0} = val_x$。
对于不是关键点的点,有:$f_{x, i} = val_x$ $0 \leq i \leq d$ $f_{x, d + 1} = inf$。
然后就不会了
感觉关键是维护$f_x$和$g_x$的单调性,因为$f_{x, i}$一定不会大于$f_{x, i + 1}$,而$g_{x, i}$一定不会大于$g_{x, i - 1}$。
我们考虑对$x$的所有儿子$y$分开考虑,对于每一个$y$,有方程
$f_{x, i} = min(f_{x, i} + g_{y, i}, g_{x, i + 1} + f_{y, i + 1})$。
前者表示在之前的计算中已经取了一个子结点能覆盖到$x$,现在这个结点只需要取$g_{y, i}$即可,后者表示取这个结点的$f_{y, i + 1}$,而剩下的代价最小为$g_{x, i + 1}$。
记得从上到下枚举$i$。
然后再扫一遍维护$f_{x}$的单调性。
接着考虑计算$g$,有$g_{x, 0} = f_{x, 0}$。
对于每一个$y$,有$g_{x, i} += g_{y, i - 1}$。
然后从下到上扫一遍维护$g_x$的单调性。
时间复杂度$O(nd)$。
Code:
#include <cstdio> #include <cstring> using namespace std; const int N = 5e5 + 5; const int M = 22; const int inf = 1 << 30; int n, m, d, a[N], tot = 0, head[N], f[N][M], g[N][M]; bool flag[N]; struct Edge { int to, nxt; } e[N << 1]; inline void add(int from, int to) { e[++tot].to = to; e[tot].nxt = head[from]; head[from] = tot; } inline void read(int &X) { X = 0; char ch = 0; int op = 1; for(; ch > '9' || ch < '0'; ch = getchar()) if(ch == '-') op = -1; for(; ch >= '0' && ch <= '9'; ch = getchar()) X = (X << 3) + (X << 1) + ch - 48; X *= op; } inline int min(int x, int y) { return x > y ? y : x; } inline void chkMin(int &x, int y) { if(y < x) x = y; } void dfs(int x, int fat) { if(flag[x]) g[x][0] = f[x][0] = a[x]; for(int i = 1; i <= d; i++) f[x][i] = a[x]; f[x][d + 1] = inf; for(int i = head[x]; i; i = e[i].nxt) { int y = e[i].to; if(y == fat) continue; dfs(y, x); for(int j = d; j >= 0; j--) f[x][j] = min(f[x][j] + g[y][j], g[x][j + 1] + f[y][j + 1]); for(int j = d; j >= 0; j--) chkMin(f[x][j], f[x][j + 1]); g[x][0] = f[x][0]; for(int j = 1; j <= d; j++) g[x][j] += g[y][j - 1]; for(int j = 1; j <= d; j++) chkMin(g[x][j], g[x][j - 1]); } } int main() { read(n), read(d); for(int i = 1; i <= n; i++) read(a[i]); read(m); for(int x, i = 1; i <= m; i++) { read(x); flag[x] = 1; } for(int x, y, i = 1; i < n; i++) { read(x), read(y); add(x, y), add(y, x); } dfs(1, 0); printf("%d\n", f[1][0]); return 0; }