倍增LCA学习笔记
前言
“倍增”,作为一种二进制拆分思想,广泛用于各中算法,如$ST$表,求解$LCA$等等...今天,我们仅讨论用该思想来求解树上两个节点的$LCA$(最近公共祖先)
“倍增”是什么东西?
倍增就是“成倍增加”的意思,比如$1$倍增后变成了$2$,$2$倍增后就变成了$4$,$4$变成$8$,以此类推...
实现
一直向上LCA
在讲真正的倍增之前,我们先来说说最朴素的$LCA$,对于需要求解的两个点$(x,y)$,我们最先能想到的方法就是两个点先到达同一深度,然后一直往上跳父亲,知道两个点跳到同一个点上,这个点就是$LCA$。
int LCA (int x, int y) {
if (depth[x] < depth[y]) swap(x, y);
while(depth[x] != depth[y]) x = fa[x];
while(x != y) x = fa[x], y = fa[y];
return x;
}
不难发现,这种算法的时间开销很大,我们想办法来优化它。
倍增LCA
就如同$ST$表一样,我们不妨设$f[i][j]$表示树上编号为$i$的节点向上跳$2^j$个节点后所达到的节点,如同$ST$表的预处理,我们很容易发现如何预处理出这个$f$数组:
f[i][j] = f[f[i][j-1]][j-1];
显然,$i$往上跳$2{j-1}$次之后再跳$2$次之后就相当于$i$往上跳$2^j$次,我们可以借此来优化,利用二进制优化背包的思想那样,将跳的次数二进制拆分。
于是,我们改写一下之前的代码
int LCA (int x, int y) {
if (depth[x] < depth[y]) swap(x, y);
for (int i = LogN; i >= 0; --i)
if (depth[f[x][i]] >= depth[y])
x = f[x][i];
if (x == y) return x;
for (int i = LogN; i >= 0; --i)
if (f[x][i] != f[y][i])
x = f[x][i], y = f[y][i];
return f[x][0];
}
这样一来,速度就快很多了,由原来的$O(Depth)$变成了现在的$O(log_2(Depth))$
代码
#include <cstdio>
#include <cstring>
typedef int ll;
const ll N = 5e5 + 10, M = 5e5 + 10, LogN = 25;
ll n, m, s, depth[N], f[N][LogN], a, b, c;
ll from[N], to[M << 1], nxt[M << 1], cnt, tmp, Log[N];
inline void swap (ll &a, ll &b) {tmp = a, a = b, b = tmp;}
//链式前向星加边
void addEdge (ll u, ll v) {
to[++cnt] = v, nxt[cnt] = from[u], from[u] = cnt;
}
//计算深度&计算祖先
void doit (ll u, ll fa) {
depth[u] = depth[fa] + 1;
for (register ll i = 1; i <= Log[n]; ++i) {
if ((1 << i) >= depth[u]) break;
f[u][i] = f[f[u][i - 1]][i - 1];
}
for (register ll i = from[u]; i; i = nxt[i]) {
ll v = to[i];
if (v == fa) continue;
f[v][0] = u;
doit (v, u);
}
}
//计算LCA
inline ll LCA (ll x, ll y) {
if (depth[x] < depth[y]) swap(x, y);
//我们默认x为更深的那个点
for (register ll i = 0; i <= Log[n]; ++i)
if (depth[f[x][i]] >= depth[y])
x = f[x][i];
//将x跳到和y同一深度上
if (x == y) return x;
for (register ll i = Log[n]; i >= 0; --i)
if (f[x][i] != f[y][i])
x = f[x][i], y = f[y][i];
//一起向上跳
return f[x][0];
//不难看出,此时两个点均在其LCA的下方,往上跳一次即可
}
int main () {
scanf ("%d%d", &n, &m);//n节点数 m询问次数
Log[0] = -1;
for (register ll i = 1, u, v; i < n; ++i) {
scanf ("%d%d", &u, &v);
addEdge (u, v); addEdge(v, u);
Log[i] = Log[i >> 1] + 1;
}
Log[n] = Log[n >> 1] + 1;
doit (1, 0);
while (m--) {
scanf ("%d%d", &a, &b);
printf ("%d\n", LCA(a, b)));
}
return 0;
}