P5024 [NOIP2018 提高组] 保卫王国

P5024 NOIP2018 提高组 保卫王国

f[u][0/1]表示只在以u为根节点的子树(包含u)中最小花费(u:不选/选)
g[u][0/1]表示不在以u为根节点的子树(不含u)中最小花费(u:不选/选)
w[u][i][0/1]表示u,u<<i不选/选,以u<<i(含)为子树除去以u为根节点的子树(不含)的最小花费
使用类似倍增求LCA的方法计算出答案

点击查看代码

#include <stdio.h>
#include <string.h>
const int N = 1e5 + 5, M = N << 1, logN = 17;
typedef long long LL; typedef signed char byte;
const LL inf = 0x3f3f3f3f3f3f3f3fLL;
LL min(LL x, LL y) { return x < y ? x : y; }
int n, m, p[N], h[N], e[M], nxt[M], idx;
LL f[N][2], g[N][2], w[N][logN][2][2];
int fa[N][logN], dep[N];
void add(int a, int b) { e[++ idx] = b, nxt[idx] = h[a], h[a] = idx; }
void dfs1(int u) { // 求出fa和f数组
	for(byte i = 1; i < logN && fa[u][i - 1]; i ++) fa[u][i] = fa[fa[u][i - 1]][i - 1];
	f[u][1] = p[u];
	for(int i = h[u], v; i; i = nxt[i]) {
		if((v = e[i]) == fa[u][0]) continue;
		dep[v] = dep[u] + 1, fa[v][0] = u, dfs1(v);
		f[u][1] += min(f[v][0], f[v][1]), f[u][0] += f[v][1]; // 类似没有上司的舞会
	}
}
void dfs2(int u) { // 求出g和w[][0]数组
	for(int i = h[u], v; i; i = nxt[i]) {
		if((v = e[i]) == fa[u][0]) continue;
		g[v][0] = g[u][1] + f[u][1] - min(f[v][0], f[v][1]); // 注意计算u不是v的儿子的贡献
		g[v][1] = min(g[v][0], g[u][0] + f[u][0] - f[v][1]);
		w[v][0][0][1] = w[v][0][1][1] = f[u][1] - min(f[v][0], f[v][1]);
		w[v][0][1][0] = f[u][0] - f[v][1];
		dfs2(v);
	}
}
LL solve(int x, int tx, int y, int ty) {
	if(dep[x] < dep[y]) x ^= y ^= x ^= y, tx ^= ty ^= tx ^= ty;
	static LL nx[2], ny[2], mx[2], my[2]; // nx,ny分别表示x,y在上以的时候的权值和;mx,my为备份数组
	nx[tx] = f[x][tx], nx[!tx] = inf;
	for(byte k = logN - 1; k >= 0; k --)
		if(dep[fa[x][k]] >= dep[y]) {
			mx[0] = nx[0], mx[1] = nx[1], nx[0] = nx[1] = inf;
			for(byte i = 0, j; i < 2; i ++) for(j = 0; j < 2; j ++)
				nx[j] = min(nx[j], w[x][k][i][j] + mx[i]);
			x = fa[x][k];
		}
	if(x == y) return nx[ty] + g[y][ty];
	ny[ty] = f[y][ty], ny[!ty] = inf;
	for(byte k = logN - 1; k >= 0; k --)
		if(fa[x][k] != fa[y][k]) {
			mx[0] = nx[0], mx[1] = nx[1], nx[0] = nx[1] = inf, my[0] = ny[0], my[1] = ny[1], ny[0] = ny[1] = inf;
			for(byte i = 0, j; i < 2; i ++) for(j = 0; j < 2; j ++)
				nx[j] = min(nx[j], w[x][k][i][j] + mx[i]), ny[j] = min(ny[j], w[y][k][i][j] + my[i]);
			x = fa[x][k], y = fa[y][k];
		}
	int lca = fa[x][0];
	return min(
		nx[1] + ny[1] + g[lca][0] + f[lca][0] - f[x][1] - f[y][1], // 不选lca和选lca
		min(nx[0], nx[1]) + min(ny[0], ny[1]) + g[lca][1] + f[lca][1] - min(f[x][0], f[x][1]) - min(f[y][0], f[y][1])
	);
}
int main() {
	scanf("%d%d%*s", &n, &m);
	for(int i = 1; i <= n; i ++) scanf("%d", p + i);
	for(int i = 1, a, b; i < n; i ++) scanf("%d%d", &a, &b), add(a, b), add(b, a);
	for(int i = 1; i <= n; i ++) memset(w[i], 0x3f, sizeof(w[i]));
	dep[1] = 1, dfs1(1), dfs2(1);
	for(int u = 1; u <= n; u ++)
		for(byte t = 1; t < logN && fa[u][t]; t ++) // 枚举点的状态倍增求出其余w
			for(byte i = 0, j, k; i < 2; i ++) for(j = 0; j < 2; j ++) for(k = 0; k < 2; k ++)
				w[u][t][i][j] = min(w[u][t][i][j], w[u][t - 1][i][k] + w[fa[u][t - 1]][t - 1][k][j]);
	for(int i = 0, x, tx, y, ty; i < m; i ++) {
		scanf("%d%d%d%d", &x, &tx, &y, &ty);
		if(!tx && !ty && (fa[x][0] == y || fa[y][0] == x)) puts("-1"); // 无解
		else printf("%lld\n", solve(x, tx, y, ty));
	}
	return 0;
}
posted @ 2022-09-28 10:09  azzc  阅读(29)  评论(0编辑  收藏  举报