[LOJ][JLOI / SHOI2016]侦察守卫

链接

用 f[i][j] 表示：在以 i 为根的子树中，与 i 距离不小于 j 的需要覆盖的点都已被覆盖（距离小于 j 的点可以尚未被覆盖）时的最小代价
用 g[i][j] 表示：以 i 为根的子树内需要覆盖的点已被完全覆盖，并且还能再向子树外覆盖 j 距离时的最小代价

注意边界条件：对需要覆盖的点 i，初值为 f[i][0] = g[i][0] = w[i]；此外对任意点 i，有 g[i][j] = w[i]（1 ≤ j ≤ d），g[i][d + 1] = INF


#include<iostream>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<map>
#include<queue>
#include<vector>
#include<cstdio>

using namespace std;

// Shorthand for a 64-bit signed integer.
typedef long long ll;

// Capacity of the per-node arrays declared in this file.
const int N = 500005;
// Mod and G are not referenced by the visible code; presumably leftovers from
// a number-theoretic-transform template -- confirm before removing.
const int Mod = 998244353;
const int G = 3;
// Large sentinel used by dfs as an "impossible" cost.
const int INF = 0x3f3f3f3f;

// Pi, computed once at start-up; not referenced by the visible code.
double PI = acos(-1);

// Pair of doubles (x, y). The operators defined in this file add, subtract and
// multiply these values as complex numbers with x as the real part and y as
// the imaginary part.
struct Virt{
    double x;
    double y;

    // Both components default to 0.0; a single argument sets only x.
    Virt(double _x = 0.0 , double _y = 0.0){
        x = _x;
        y = _y;
    }
};

// Component-wise sum of two Virt values.
Virt operator + (Virt x , Virt y){
    const double sum_x = x.x + y.x;
    const double sum_y = x.y + y.y;
    return Virt(sum_x , sum_y);
}
// Component-wise difference of two Virt values (first minus second).
Virt operator - (Virt x , Virt y){
    const double diff_x = x.x - y.x;
    const double diff_y = x.y - y.y;
    return Virt(diff_x , diff_y);
}
// Complex product: (x.x + i*x.y) * (y.x + i*y.y), with the .x member holding
// the real part and the .y member the imaginary part.
Virt operator * (Virt x , Virt y){
    const double real_part = x.x * y.x - x.y * y.y;
    const double imag_part = x.x * y.y + x.y * y.x;
    return Virt(real_part , imag_part);
}

// Reads the next decimal integer from stdin.
//
// Every character before the first digit is skipped; if a '-' is seen among
// the skipped characters the result is negated. Consecutive digits are then
// accumulated in base 10.
//
// Returns the parsed value, or 0 when the input ends before any digit is
// found (previously this case looped forever).
int read(){
	int x = 0 , f = 1;
	// getchar() returns int so that EOF is distinguishable from every real
	// character; storing it in a char loses that on unsigned-char platforms.
	int ch = getchar();
	while(ch != EOF && (ch < '0' || ch > '9')){
		if(ch == '-') f = -1;
		ch = getchar();
	}
	// EOF is negative, so it also terminates the digit loop.
	while(ch >= '0' && ch <= '9'){
		x = x * 10 + (ch - '0');
		ch = getchar();
	}
	return x * f;
}

// n       : number of nodes; main fills w[1..n].
// m       : number of marked node indices read by main.
// d       : distance bound; dfs uses it as the upper limit of the DP index.
// num     : index of the most recently used slot of e[] (slot 0 is never used,
//           so 0 can terminate an adjacency list).
// head[u] : index in e[] of the first adjacency entry of node u, 0 if none.
int n , m , d , num , head[N];

// DP tables (f, g) and per-node weights (w).
// dfs reads and writes columns 0 .. d + 1 of f and g, so each row needs d + 2
// entries. 22 columns support d up to 20; with the former 21 columns, d == 20
// made column d + 1 spill into the next node's row.
int f[N][22] , g[N][22] , w[N];

// vis[v] is true for every node index listed as marked in the input.
bool vis[N];

// One directed adjacency-list entry.
//   pos : node this entry points to.
//   nx  : index in e[] of the next entry of the same list, 0 at the end.
struct edge{
	int pos;
	int nx;
};

// Entry pool; every add() call consumes two slots.
edge e[N << 1];

// Inserts the undirected edge (u, v) by pushing one entry at the front of
// u's adjacency list and one at the front of v's.
void add(int  u , int v){
	// entry u -> v
	++num;
	e[num].pos = v;
	e[num].nx = head[u];
	head[u] = num;

	// entry v -> u
	++num;
	e[num].pos = u;
	e[num].nx = head[v];
	head[v] = num;
}

// Tree DP over the subtree of `now`, entered from parent `pre` (0 for the root).
//
// Per the author's notes at the top of the file, f[i][j] is the cost when the
// part of i's subtree within distance j below i may still be uncovered, and
// g[i][j] is the cost when i's subtree is fully covered and the cover reaches
// j further upward. Children are merged into the tables of `now` one at a time.
//
// Columns 0 .. d + 1 of f[now] and g[now] are accessed.
void dfs(int now , int pre){
	int* const fNow = f[now];
	int* const gNow = g[now];
	const int cost = w[now];

	// Initial values before any child is merged.
	if(vis[now]){
		fNow[0] = cost;
		gNow[0] = cost;
	}
	for(int j = 1 ; j <= d ; ++j) gNow[j] = cost;
	gNow[d + 1] = INF;  // sentinel: reaching d + 1 upward is impossible

	for(int id = head[now] ; id != 0 ; id = e[id].nx){
		const int child = e[id].pos;
		if(child != pre){
			dfs(child , now);
			const int* const fChild = f[child];
			const int* const gChild = g[child];

			// Either `now` keeps providing reach j and the child only has to
			// handle what lies deeper, or the child provides reach j + 1.
			for(int j = d ; j >= 0 ; --j){
				const int keepOwn = gNow[j] + fChild[j];
				const int useChild = gChild[j + 1] + fNow[j + 1];
				gNow[j] = min(keepOwn , useChild);
			}
			// A larger reach also satisfies a smaller one: suffix minimum.
			for(int j = d ; j >= 0 ; --j){
				if(gNow[j + 1] < gNow[j]) gNow[j] = gNow[j + 1];
			}

			fNow[0] = gNow[0];
			// Distance j below `now` is distance j - 1 below the child.
			for(int j = 1 ; j <= d + 1 ; ++j) fNow[j] += fChild[j - 1];
			// Allowing more uncovered depth never costs more: prefix minimum.
			for(int j = 1 ; j <= d + 1 ; ++j){
				if(fNow[j - 1] < fNow[j]) fNow[j] = fNow[j - 1];
			}
		}
	}
}

// Reads the instance from stdin, runs the tree DP from node 1 and prints
// f[1][0].
//
// Input order: n, d, then w[1..n], then m followed by m marked node indices,
// then n - 1 edges given as pairs of endpoints.
int main(){
	n = read();
	d = read();
	for(int node = 1 ; node <= n ; ++node) w[node] = read();

	m = read();
	for(int k = 0 ; k < m ; ++k){
		const int marked = read();
		vis[marked] = true;
	}

	for(int k = 1 ; k < n ; ++k){
		const int u = read();
		const int v = read();
		add(u , v);
	}

	dfs(1 , 0);
	printf("%d\n" , f[1][0]);
	return 0;
}

posted @ 2018-04-05 01:58  FranceDisco  阅读(119)  评论(0)  编辑  收藏  举报