[NOIP2018]保卫王国

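倍增写法：先用树形 DP 求出每个子树内的最大权独立集，再倍增预处理“从某点向上跳 2^i 步”的状态转移；每次询问沿倍增路径合并，单次 O(log n)。答案 = 总权值 − 强制状态下的最大权独立集（无解时输出 -1）。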

#include <bits/stdc++.h>

// Loop / update helpers. Every macro argument is parenthesized so that compound
// expressions (ternaries, comparisons, ...) expand with the intended precedence.
#define rep(i, a, b) for (int i = (a), i##end = (b); i <= i##end; ++i)
#define per(i, a, b) for (int i = (a), i##end = (b); i >= i##end; --i)
#define rep0(i, a) for (int i = 0, i##end = (a); i < i##end; ++i)
#define per0(i, a) for (int i = static_cast<int>(a) - 1; ~i; --i)
#define chkmax(a, b) ((a) = std::max((a), (b)))
#define chkmin(a, b) ((a) = std::min((a), (b)))
// NOTE(review): the next two macros rewrite EVERY later identifier spelled
// `x` / `y` (e.g. the parameters of jump/dp and the locals in main become
// `first` / `second`). That only works because the renaming is consistent;
// they are kept for compatibility, but new code should not rely on them.
#define x first
#define y second
#define pb push_back
#define mp std::make_pair
#define enter putchar('\n')

// 64-bit integer type used for the DP values and the weight total.
using ll = long long;

// Unqualified max / swap are used throughout the solution.
using std::max;
using std::swap;

// Reads the next integer from stdin, skipping any non-digit characters before
// it. A '-' among the skipped characters makes the result negative.
// Returns 0 if the end of input is reached before any digit.
int read() {
	int w = 0, f = 1;
	int c = getchar();  // int, not char: must be able to hold EOF
	while (c != EOF && !isdigit(c)) {
		if (c == '-') f = -1;  // previously `f == -1`, a comparison with no effect
		c = getchar();
	}
	while (isdigit(c)) w = w * 10 + (c - '0'), c = getchar();  // isdigit(EOF) is false
	return w * f;
}

constexpr int N = 100005;      // bound for every vertex-indexed array
constexpr ll oo = 1LL << 50;   // "infinity" magnitude; -oo marks forbidden DP states

int n, m;    // number of vertices and number of queries (read in main)
int p[N];    // p[i]: weight of vertex i
ll sgm = 0;  // sum of all vertex weights

// Adjacency lists stored in flat arrays. G[u] is the index in e[] of the most
// recently added edge out of u (main initializes G to -1 = "no edge"), and
// e[i].nxt links to the previous edge out of the same vertex.
struct Edge { int v, nxt; } e[N*2];
int G[N], edges = 0;

// Appends the directed edge u -> v to u's list.
void adde(int u, int v) {
	// Standard aggregate initialization; the former `(Edge){...}` was a C99
	// compound literal, which C++ only accepts as a compiler extension.
	e[edges] = Edge{v, G[u]};
	G[u] = edges++;
}

// Tree-DP tables (filled by dfs1 / dfs2, read by jump / dp):
//   f[s][u]        max weight of an independent set inside u's subtree with u
//                  excluded (s = 0) or included (s = 1).
//   g[a][b][i][u]  best extra weight from the part of h[i][u]'s subtree lying
//                  outside u's subtree, with u in state a and h[i][u] in state b.
//   h[i][u]        the 2^i-th ancestor of u (0 once past the root).
//   d[u]           depth of u (root = 1).
// h only ever holds vertex ids, so it is int rather than ll (saves ~7 MB).
ll f[2][N], g[2][2][18][N]; int h[18][N], d[N];
// Post-order tree DP over the subtree rooted at u (fa = parent; main passes 0
// for the root). Afterwards f[0][u] / f[1][u] hold the maximum total weight of
// an independent set inside u's subtree with u excluded / included.
void dfs1(int u, int fa) {
	for (int eid = G[u]; eid != -1; eid = e[eid].nxt) {
		const int child = e[eid].v;
		if (child == fa) continue;
		dfs1(child, u);
		// u excluded: each child may take whichever state is better.
		f[0][u] += max(f[0][child], f[1][child]);
		// u included: every child must be excluded.
		f[1][u] += f[0][child];
	}
	f[1][u] += p[u];
}
void dfs2(int u, int fa) {
	d[u] = d[h[0][u] = fa]+1;
	g[0][0][0][u] = g[1][0][0][u] = f[0][fa] - max(f[0][u], f[1][u]);
	g[0][1][0][u] = f[1][fa] - f[0][u];
	g[1][1][0][u] = -oo;
	rep(i, 1, 17) {
		h[i][u] = h[i-1][h[i-1][u]];
		g[0][0][i][u] = max(g[0][0][i-1][u]+g[0][0][i-1][h[i-1][u]], g[0][1][i-1][u]+g[1][0][i-1][h[i-1][u]]);
		g[0][1][i][u] = max(g[0][0][i-1][u]+g[0][1][i-1][h[i-1][u]], g[0][1][i-1][u]+g[1][1][i-1][h[i-1][u]]);
		g[1][0][i][u] = max(g[1][0][i-1][u]+g[0][0][i-1][h[i-1][u]], g[1][1][i-1][u]+g[1][0][i-1][h[i-1][u]]);
		g[1][1][i][u] = max(g[1][0][i-1][u]+g[0][1][i-1][h[i-1][u]], g[1][1][i-1][u]+g[1][1][i-1][h[i-1][u]]);
	}
	for (int i = G[u]; ~i; i = e[i].nxt)
		if (e[i].v != fa) dfs2(e[i].v, u);
}

// Lifts the DP pair (f0, f1), the best values with vertex x excluded / included,
// by 2^i levels using the g transfer table, and moves x to its 2^i-th ancestor.
void jump(ll &f0, ll &f1, int &x, int i) {
	const ll up0 = max(f0 + g[0][0][i][x], f1 + g[1][0][i][x]);
	const ll up1 = max(f0 + g[0][1][i][x], f1 + g[1][1][i][x]);
	x = h[i][x];
	f0 = up0;
	f1 = up1;
}
// Answers one query: the maximum weight of an independent set of the whole tree
// when vertex x is forced into state a and vertex y into state b
// (state 1 = in the set, 0 = not in it). Forbidden states are seeded with -oo;
// main treats a negative result as "no valid choice".
ll dp(int x, int a, int y, int b) {
	// Make x the deeper of the two vertices.
	if (d[x] < d[y]) swap(x, y), swap(a, b);
	ll f0, f1;
	// Start from x's subtree value, with the state x may not take set to -oo.
	a ? (f0 = -oo, f1 = f[1][x]) : (f0 = f[0][x], f1 = -oo);
	// Lift x to y's depth, decomposing the depth gap into powers of two.
	for (int i = 0, j = d[x]-d[y]; j; j >>= 1, i++)
		if (j & 1) jump(f0, f1, x, i);
	if (x == y) {
		// y is an ancestor of x: also forbid the state y may not take,
		// then lift all the way to the root (depth 1).
		b ? (f0 = -oo) : (f1 = -oo);
		for (int i = 0, j = d[x]-1; j; j >>= 1, i++)
			if (j & 1) jump(f0, f1, x, i);
		return max(f0, f1);
	}
	ll g0, g1;
	// Same seeding for y.
	b ? (g0 = -oo, g1 = f[1][y]) : (g0 = f[0][y], g1 = -oo);
	// Lift both until they are the two children of their lowest common ancestor.
	per0(i, 18)
		if (h[i][x] != h[i][y])	jump(f0, f1, x, i), jump(g0, g1, y, i);
	// u is the LCA: in f[.][u], replace the unconstrained contributions of x and y
	// with their constrained values. h0 = u excluded, h1 = u included (so x and y
	// must both be excluded).
	int u = h[0][x]; ll h0 = f[0][u]-max(f[0][x], f[1][x])-max(f[0][y], f[1][y])+max(f0, f1)+max(g0, g1), h1 = f[1][u]-f[0][x]-f[0][y]+f0+g0;
	// Lift u to the root and take the better root state.
	for (int i = 0, j = d[u]-1; j; j >>= 1, i++)
		if (j & 1) jump(h0, h1, u, i);
	return max(h0, h1);
}

int main() {
	n = read(), m = read(); read();
	rep(i, 1, n) p[i] = read(), sgm += p[i];
	memset(G, -1, sizeof G);
	rep(i, 2, n) { int u = read(), v = read(); adde(u, v), adde(v, u); }
	dfs1(1, 0), dfs2(1, 0);
	while (m--) {
		int x = read(), a = read(), y = read(), b = read();
		ll tot = dp(x, !a, y, !b);
		printf("%lld\n", tot < 0 ? -1 : sgm-tot);
	}
	return 0;
}

动态 DP（ddp）写法

这部分代码暂未完成（鸽了）。
posted @ 2020-11-03 22:18  AC-Evil  阅读(105)  评论(0)  编辑  收藏  举报