P9433 [NAPC-#1] Stage5 - Conveyors (lca 维护树上路径)
P9433 [NAPC-#1] Stage5 - Conveyors
lca 维护树上路径
但是这题不是难在这里,考察的是分析问题答案构成的能力。我们可以从数据范围出发。
\(s = t,k=n\)
每条边都要走两遍,显然是树上所有边权和 \(\times 2\)。
\(k = n\)
可以构造一种走法,使得 \(t\) 先到 \(s\),按照上面的走法走完之后回到 \(s\),最后走回 \(t\),这样做答案是 树上所有边权和 \(\times 2-dis(s,t)\)。
无特殊限制
那么经过上面的思考,可以知道我们一定要走的边一定有包含 \(k\) 个关键点的最小连通块,可以预处理出来边权和 \(sum\)。这时候考虑 \(s\) 和 \(t\) 对答案的影响。如果 \(s\) 在连通块内,那么没影响,如果不在,那么我们要走最短的距离到达连通块中,然后继续用之前的方法走到 k 个点;\(t\) 同理。
那么如何处理每个点到连通块的最短距离呢?当然可以建源点,跑最短路求出。但是有更简单的做法:显然连通块中一定有关键点,如果我们将其中一个关键点作为根,那么连通块一定是树上最上面的一片,这时候剩下的点都在连通块之下,可以直接倍增往上跳到第一个在连通块中的点。
求出 \(s\) 和 \(t\) 的最近点 \(u\) 和 \(v\) 后,最后答案可以表示为 \(sum\times 2-dis(u,v)+dis(s,u)+dis(v,t)\)。
复杂度 \(O(n\log n)\)。
#include <bits/stdc++.h>
#define pii std::pair<int, int>
#define fi first
#define se second
#define pb push_back
typedef long long i64;
const int N = 1e5 + 10;
int n, q, k, rt, sum;
int dep[N], anc[N][20], vis[N], dis[N];
std::vector<pii> e[N];
void dfs(int u, int fa) {
dep[u] = dep[fa] + 1;
anc[u][0] = fa;
for(int i = 1; i <= 19; i++) anc[u][i] = anc[anc[u][i - 1]][i - 1];
for(auto v : e[u]) {
if(v.fi == fa) continue;
dis[v.fi] = dis[u] + v.se;
dfs(v.fi, u);
if(vis[v.fi]) sum += v.se;
vis[u] |= vis[v.fi];
}
}
int lca(int u, int v) {
if(dep[u] < dep[v]) std::swap(u, v);
for(int i = 19; i >= 0; i--) if(dep[anc[u][i]] >= dep[v]) u = anc[u][i];
if(u == v) return u;
for(int i = 19; i >= 0; i--) if(anc[u][i] != anc[v][i]) u = anc[u][i], v = anc[v][i];
return anc[u][0];
}
int get(int u) {
if(vis[u]) return u;
for(int i = 19; i >= 0; i--) {
if(anc[u][i] && !vis[anc[u][i]]) u = anc[u][i];
}
return anc[u][0];
}
int dist(int u, int v) {
int rt = lca(u, v);
return dis[u] + dis[v] - dis[rt] * 2;
}
void Solve() {
std::cin >> n >> q >> k;
for(int i = 1; i < n; i++) {
int u, v, w;
std::cin >> u >> v >> w;
e[u].pb({v, w}), e[v].pb({u, w});
}
for(int i = 1; i <= k; i++) {
int x;
std::cin >> x;
vis[x] = 1;
rt = x;
}
dfs(rt, 0);
while(q--) {
int s, t;
std::cin >> s >> t;
int u = get(s), v = get(t);
std::cout << sum * 2 - dist(u, v) + (dis[s] - dis[u]) + (dis[t] - dis[v]) << "\n";
}
}
int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
Solve();
return 0;
}