最近公共祖先 LCA
原创建时间:2018-08-07 14:08:52
LCA 的概念
在图论和计算机科学中,最近公共祖先(英语:lowest common ancestor)是指在一个树或者有向无环图中同时拥有v和w作为后代的最深的节点。
——Wikipedia
简单的来说,就是两个节点v和w的最近的祖先节点
如下图
6和7的LCA是2,3和7的LCA是1
LCA 的求法
暴力求解
让他们一步一步往上爬,直到相遇
显然,这样的算法会T到飞起
所以我们要使用倍增
优化
倍增求解
所谓倍增,就是按2的倍数来增大,也就是跳 1、2、4 、8 、16、32 ...
但是在这里,我们要考虑从大到小跳
因为如果我们从小到大跳,就会出现要「回溯」的情况,因为我们不一定能精准地跳,而从大到小跳可以避开这种情况
对于上面这一棵更复杂的树,我们考虑17和18的LCA
17 ->(跳4) 3
18 ->(跳4) 5 ->(跳1) -> 3
是不是快多了,跳的次数大大减小
时间复杂度\(O(nlogn)\)
LCA 的代码 & 实现流程
实现流程
首先我们要记录各个点的深度\(depth[\ ]\)和它们\(2^i\)级的祖先\(father[\ ][\ ]\)
用\(depth[i]\)表示\(i\)点的深度,\(father[i][j]\)表示\(i\)点的\(2^i\)级的祖先
// 预处理
void dfsInit(int root, int fa) {
depth[root] = depth[fa] + 1;
father[root][0] = fa;
for (int i = 1; (1 << i) <= depth[root]; ++i) {
father[root][i] = father[father[root][i-1]][i-1];
}
for (int e = head[root]; e; e = edge[e].next) {
if (edge[e].prev != fa) dfsInit(edge[e].prev, root);
}
}
接着我们就可以找LCA辣
对了,我们可以让它跑得更快
// 提前预处理出log2i + 1的值
for (int i = 1; i <= n; ++i) {
lg[i] = lg[i-1] + (1 << lg[i-1] == i);
}
在求 LCA 之前,我们先让两个节点蹦到同一层
但是跳的时候不能直接跳到 LCA 上,要跳到 LCA - 1 上,再输出 当前的父节点 就行了
因为直接蹦到 LCA 上可能会出现「误判」,比如上图中\(4\)和\(8\),若不判断,则在跳的时候会输出1,但是答案是3
所以我们就可以让它们跳到\(2\)和\(5\),然后输出父节点
int LCA(int x, int y) {
// 我们设x的深度大于y的深度
if (depth[x] < depth[y]) swap(x, y);
while (depth[x] > depth[y])
x = father[x][lg[depth[x] - depth[y]] - 1];
if (x == y) return x; // x 是 y 的祖先
for (int i = lg[depth[x]]; i >= 0; --i) {
if (father[x][i] != father[y][i]) x = father[x][i], y = father[y][i];
// 不相等就往上跳
}
return father[x][0];
}
完整代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cctype>
using namespace std;
const int MAXN = 500000 + 10;
const int MAXM = 500000 + 10;
struct Edge {
int prev, next;
} edge[MAXM * 2];
int head[MAXN], father[MAXN][22], lg[MAXN], depth[MAXN];
int cnt, n, m, s;
inline int getint() {
int s = 0, x = 1;
char ch = getchar();
while (!isdigit(ch)) {
if (ch == '-') x = -1;
ch = getchar();
}
while (isdigit(ch)) {
s = s * 10 + ch - '0';
ch = getchar();
}
return s * x;
}
inline void putint(int x, bool returnValue) {
if (x < 0) {
x = -x;
putchar('-');
}
if (x >= 10) putint(x / 10, false);
putchar(x % 10 + '0');
if (returnValue) putchar('\n');
}
inline void addEdge(int prev, int next) {
edge[++cnt].prev = prev;
edge[cnt].next = head[next];
head[next] = cnt;
}
// 预处理
void dfsInit(int root, int fa) {
depth[root] = depth[fa] + 1;
father[root][0] = fa;
for (int i = 1; (1 << i) <= depth[root]; ++i) {
father[root][i] = father[father[root][i-1]][i-1];
}
for (int e = head[root]; e; e = edge[e].next) {
if (edge[e].prev != fa) dfsInit(edge[e].prev, root);
}
}
int LCA(int x, int y) {
// 我们设x的深度大于y的深度
if (depth[x] < depth[y]) swap(x, y);
while (depth[x] > depth[y])
x = father[x][lg[depth[x] - depth[y]] - 1];
if (x == y) return x; // x 是 y 的祖先
for (int i = lg[depth[x]]; i >= 0; --i) {
if (father[x][i] != father[y][i]) x = father[x][i], y = father[y][i];
// 不相等就往上跳
}
return father[x][0];
}
int main(int argc, char *const argv[]) {
n = getint(), m = getint(), s = getint();
for (int i = 1; i < n; ++i) {
int prev = getint(), next = getint();
addEdge(prev, next);
addEdge(next, prev);
}
dfsInit(s, 0);
for (int i = 1; i <= n; ++i) {
lg[i] = lg[i-1] + (1 << lg[i-1] == i);
}
for (int i = 1; i <= m; ++i) {
int x = getint(), y = getint();
putint(LCA(x, y), true);
}
return 0;
}