Loading

P9433 [NAPC-#1] Stage5 - Conveyors (lca 维护树上路径)

P9433 [NAPC-#1] Stage5 - Conveyors

lca 维护树上路径

但是这题不是难在这里,考察的是分析问题答案构成的能力。我们可以从数据范围出发。

\(s = t,k=n\)

每条边都要走两遍,显然是树上所有边权和 \(\times 2\)

\(k = n\)

可以构造一种走法,使得 \(t\) 先到 \(s\),按照上面的走法走完之后回到 \(s\),最后走回 \(t\),这样做答案是 树上所有边权和 \(\times 2-dis(s,t)\)

无特殊限制

那么经过上面的思考,可以知道我们一定要走的边一定有包含 \(k\) 个关键点的最小连通块,可以预处理出来边权和 \(sum\)。这时候考虑 \(s\)\(t\) 对答案的影响。如果 \(s\) 在连通块内,那么没影响,如果不在,那么我们要走最短的距离到达连通块中,然后继续用之前的方法走到 k 个点\(t\) 同理。

那么如何处理每个点到连通块的最短距离呢?当然可以建源点,跑最短路求出。但是有更简单的做法:显然连通块中一定有关键点,如果我们将其中一个关键点作为根,那么连通块一定是树上最上面的一片,这时候剩下的点都在连通块之下,可以直接倍增往上跳到第一个在连通块中的点。

求出 \(s\)\(t\) 的最近点 \(u\)\(v\) 后,最后答案可以表示为 \(sum\times 2-dis(u,v)+dis(s,u)+dis(v,t)\)

复杂度 \(O(n\log n)\)

#include <bits/stdc++.h>
#define pii std::pair<int, int>
#define fi first
#define se second
#define pb push_back

typedef long long i64;
const int N = 1e5 + 10;
int n, q, k, rt, sum;
int dep[N], anc[N][20], vis[N], dis[N];
std::vector<pii> e[N];
void dfs(int u, int fa) {
	dep[u] = dep[fa] + 1;
	anc[u][0] = fa;
	for(int i = 1; i <= 19; i++) anc[u][i] = anc[anc[u][i - 1]][i - 1];
	for(auto v : e[u]) {
		if(v.fi == fa) continue;
		dis[v.fi] = dis[u] + v.se;
		dfs(v.fi, u);
		if(vis[v.fi]) sum += v.se;
		vis[u] |= vis[v.fi];
	}
}
int lca(int u, int v) {
	if(dep[u] < dep[v]) std::swap(u, v);
	for(int i = 19; i >= 0; i--) if(dep[anc[u][i]] >= dep[v]) u = anc[u][i];
	if(u == v) return u;
	for(int i = 19; i >= 0; i--) if(anc[u][i] != anc[v][i]) u = anc[u][i], v = anc[v][i];
	return anc[u][0];
}
int get(int u) {
	if(vis[u]) return u;
	for(int i = 19; i >= 0; i--) {
		if(anc[u][i] && !vis[anc[u][i]]) u = anc[u][i];
	}
	return anc[u][0];
}
int dist(int u, int v) {
	int rt = lca(u, v);
	return dis[u] + dis[v] - dis[rt] * 2;
}
void Solve() {
	std::cin >> n >> q >> k;

	for(int i = 1; i < n; i++) {
		int u, v, w;
		std::cin >> u >> v >> w;
		e[u].pb({v, w}), e[v].pb({u, w});
	}
	for(int i = 1; i <= k; i++) {
		int x;
		std::cin >> x;
		vis[x] = 1;
		rt = x;
	}
	dfs(rt, 0);
	while(q--) {
		int s, t;
		std::cin >> s >> t;
		int u = get(s), v = get(t);
		std::cout << sum * 2 - dist(u, v) + (dis[s] - dis[u]) + (dis[t] - dis[v]) << "\n";
	}
}
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    
	Solve();

	return 0;
}
posted @ 2024-04-12 17:49  Fire_Raku  阅读(44)  评论(0编辑  收藏  举报