AcWing 1171. 距离
\(AcWing\) \(1171\). 距离
一、题目描述
给出 \(n\) 个点的一棵树,多次询问两点之间的 最短距离。
注意:
- 边是无向的。
- 所有节点的编号是 \(1,2,…,n\)。
输入格式
第一行为两个整数 \(n\) 和 \(m\)。\(n\) 表示点数,\(m\) 表示询问次数;
下来 \(n−1\) 行,每行三个整数 \(x,y,k\),表示点 \(x\) 和点 \(y\) 之间存在一条边长度为 \(k\);
再接下来 \(m\) 行,每行两个整数 \(x,y\),表示询问点 \(x\) 到点 \(y\) 的最短距离。
树中结点编号从 \(1\) 到 \(n\)。
输出格式
共 \(m\) 行,对于每次询问,输出一行询问结果。
数据范围
\(2≤n≤10^4\)
\(1≤m≤2×10^4\)
\(0<k≤100\)
\(1≤x,y≤n\)
输入样例1:
2 2
1 2 100
1 2
2 1
输出样例1:
100
100
输入样例2:
3 2
1 2 10
3 1 15
1 2
3 2
输出样例2:
10
25
二、倍增算法
此题就是模板基础上的简单扩展,\(x,y\) 到 \(lca(x,y)=z\)的最短距离,可以转化为源点(任意点均可)到
\(Code\)
#include <bits/stdc++.h>
using namespace std;
const int N = 20010, M = 40010;
int n, m;
int f[N][16], depth[N];
int dist[N]; // 距离1号点的距离
// 邻接表
int e[M], h[N], idx, w[M], ne[M];
void add(int a, int b, int c) {
e[idx] = b, ne[idx] = h[a], w[idx] = c, h[a] = idx++;
}
void bfs() {
// 1号点是源点
depth[1] = 1;
queue<int> q;
q.push(1);
while (q.size()) {
int u = q.front();
q.pop();
for (int i = h[u]; ~i; i = ne[i]) {
int v = e[i];
if (!depth[v]) {
q.push(v);
depth[v] = depth[u] + 1;
dist[v] = dist[u] + w[i];
f[v][0] = u; // 父亲大人
for (int k = 1; k <= 15; k++) // 记录倍增数组
f[v][k] = f[f[v][k - 1]][k - 1];
}
}
}
}
// 最近公共祖先
int lca(int a, int b) {
if (depth[a] < depth[b]) swap(a, b);
// 对齐
for (int k = 15; k >= 0; k--)
if (depth[f[a][k]] >= depth[b])
a = f[a][k];
if (a == b) return a;
// 齐步走
for (int k = 15; k >= 0; k--)
if (f[a][k] != f[b][k])
a = f[a][k], b = f[b][k];
// 返回父亲
return f[a][0];
}
int main() {
memset(h, -1, sizeof h);
scanf("%d %d", &n, &m);
int a, b, c;
// n-1条边
for (int i = 1; i < n; i++) {
scanf("%d %d %d", &a, &b, &c);
add(a, b, c), add(b, a, c);
}
bfs();
while (m--) {
scanf("%d %d", &a, &b);
int t = lca(a, b);
int ans = dist[a] + dist[b] - dist[t] * 2;
printf("%d\n", ans);
}
return 0;
}
三、\(Tarjan\)算法计算\(LCA\)
本题考察\(LCA\)的\(Tarjan\)算法。\(Tarjan\)算法是一个 离线算法,一次性读入,计算后再一次性输出,算法的时间复杂度是\(O(n + m)\)。
算法原理
设\(x\)和\(y\)的\(LCA\)是\(r\)(暂时假设\(r、x、y\)是三个不同的节点),则\(x\)和\(y\)一定处于以\(r\)为根的不同子树中,并且可以得出:处在\(r\)不同子树中的任意两个节点的\(LCA\)都是\(r\),所以在遍历以\(r\)为根的子树时,只要能够判断出两个节点分处于\(r\)的不同子树中,就可以将这两个节点的\(LCA\)标记为\(r\)。
如何判断\(x\)和\(y\)是否处在\(r\)的不同子树中呢?在对以\(r\)为根的子树做\(dfs\)的过程中,如果\(y\)所在的子树已经遍历完了,之后又遍历到\(x\)时,就可以说明\(x\)和\(y\)不在同一棵子树了。
对树的节点进行状态划分:
- \(0\) :还未遍历到的节点
- \(1\) :该节点已经遍历到了,但是其子树还没有完成遍历回溯完
- \(2\) :该节点以及其子树均已遍历回溯完
注:\(2\) 这个状态在代码实现中被省略,没用上
在\(dfs\)过程中,第一次遍历到\(r\)时,\(r\)的状态转化为\(1\),并且,\(r\)的祖先节点的状态也都是\(1\)。当\(y\)所在的子树全部遍历回溯完后,\(y\)到\(r\)的路径中,除了\(r\)以外的其他节点的状态均是\(2\)。
换言之,\(x\)和\(y\)的\(LCA\)就是\(y\)向上回溯到第一个状态为\(1\)的节点。
\(dfs\)遍历完\(y\)所在的子树并且遍历完\(x\)及其子树时各节点的状态如上图所示。此时,\(x\)的子树刚刚全部遍历回溯完成,然后发现\(y\)的状态是\(2\),于是\(y\)向上回溯,发现了第一个标记为状态\(1\)的\(r\)节点,也就是\(x\)和\(y\)的\(LCA\)节点。原理也就是之前所说的,\(y\)所在的子树遍历完了,但是\(LCA\)节点\(r\)状态肯定还是\(1\),因为\(r\)还有其他子树没有遍历完,后面再遍历到\(x\)所在的子树时,一方面就说明了\(x\)和\(y\)在\(r\)的不同子树中,另一方面也定位到了\(x\)和\(y\)分属不同子树的根节点\(r\)。
为了提高回溯查找\(LCA\)的效率,可以 使用并查集优化,即一个节点状态转化为\(2\)时,就可以将其合并到其父节点所在的集合中,这样一来,当\(y\)所在的子树全部变为状态\(2\)时,他们也都被合并到\(r\)所在的集合了,就有了\(y\)所在的并查集的根结点就是\(r\),也就是\(x\)和\(y\)的\(LCA\)节点。
特殊情况:\(r\)和\(x\)重合,即\(x\)与\(y\)的\(LCA\)就是\(x\),此时在遍历完\(x\)的所有子树后,\(x\)的状态即将转化为\(2\)时,\(y\)也被合并到以\(x\)为根的并查集中了,此时\(x\)就是\(LCA\)节点。所以我们可以在\(x\)的子树均已遍历回溯完成之际,对\(x\)与状态为\(2\)的\(y\)节点求\(LCA\)。
综上所述,\(lca(x,y)=find(y)\),其中\(find\)函数就是并查集的查找当前集合根节点的函数。并且如果要求\(x\)与\(y\)之间的距离:
注意:并查集的合并操作一定要在当前节点的所有子树都已经遍历回溯完成的情况下,所以要写在\(tarjan\)函数调用的后面,否则像\(r\)节点还没有遍历回溯完就被合并到了\(r\)的父节点所在的集合,后面再对\(y\)求并查集的根节点时就不会返回\(r\)节点了,就会引起错误。
#include <bits/stdc++.h>
using namespace std;
const int N = 10010, M = N << 1;
typedef pair<int, int> PII;
// 查询数组,first:对端节点号,second:问题序号
// 比如:q[2]={5,10} 表示10号问题,计算2和5之间的最短距离
vector<PII> query[N];
int dist[N]; // dist[u]记录从出发点S到u的距离
int res[M]; // 结果数组,有多少个问题就有多少个res[i]
// 链式前向星
int e[M], h[N], idx, w[M], ne[M];
void add(int a, int b, int c = 0) {
e[idx] = b, ne[idx] = h[a], w[idx] = c, h[a] = idx++;
}
// 并查集
int p[N];
int find(int x) {
if (x != p[x]) p[x] = find(p[x]);
return p[x];
}
int st[N]; // 0:未入栈, 1:在栈中, 2:已出栈
void tarjan(int u) {
// ① 标识u已访问
st[u] = 1;
// ② 枚举与u临边相连并且没有访问过的点
for (int i = h[u]; ~i; i = ne[i]) {
int v = e[i];
if (!st[v]) {
// 扩展:更新距离
dist[v] = dist[u] + w[i];
// 深搜
tarjan(v);
// ③ v加入u家族
p[v] = u;
}
}
// ④ 枚举已完成访问的点,记录lca或题目要求的结果
for (auto q : query[u]) {
int v = q.first, id = q.second;
if (st[v]) res[id] = dist[u] + dist[v] - 2 * dist[find(v)];
}
}
int main() {
int n, m; // n个结点,m次询问
scanf("%d%d", &n, &m);
memset(h, -1, sizeof h); // 初始化链式前向星
for (int i = 1; i <= n; i++) p[i] = i; // 并查集初始化
for (int i = 1; i < n; i++) { // 树有n-1条边
int a, b, c;
scanf("%d%d%d", &a, &b, &c);
add(a, b, c), add(b, a, c); // 无向图
}
// Tarjan算法是离线算法,一次性读入所有的问题,最终一并回答
for (int i = 0; i < m; i++) { // m个询问
int a, b;
scanf("%d%d", &a, &b); // 表示询问点 a 到点 b 的最短距离
query[a].push_back({b, i}), query[b].push_back({a, i}); // 不知道谁先被遍历 所以正反都记一下着
}
// tarjan算法求LCA
tarjan(1);
// 回答m个问题
for (int i = 0; i < m; i++) printf("%d\n", res[i]);
return 0;
}