树上问题/简单算法 LCA【最近公共祖先】
概念引入
- 最近公共祖先简称
(Lowest Common Ancestor)。两个节点的最近公共祖先,就是这两个点的公共祖先里面,离根最远的那个。
在下面的说明中,我们设两个节点分别为
算法详解:
朴素算法:
先将节点
那么易知时间复杂度为
倍增算法:
对于朴素算法中的:
“然后两个一起一个一个往上跳”
我们可以用倍增的方式完成跳跃.
那么如何实现呢?
我们知道,任何一个非负整数都可以进行二进制拆分.例如:
那么如果需要向上跳的次数为
( 对于将
接下来考虑如何实现
我们设
- 预处理
数组:
我们知道:
那么我们可以将向上走
则有:
因为
for(int i = 1;i <= log2(n);i++){
for(int j = 1;j <= n;j++){
fa[j][i] = fa[fa[j][i - 1]][i - 1];
}
}
- 将
, 提升到同一高度:
主要思想是将
if(dep[x] < dep[y]) swap(x,y);//保证x的深度大于y,方便后面的计算
int delta = dep[x] - dep[y];
for(int i = 0;i <= log2(n);i++){
if((1 << i) & delta) x = fa[x][i];//向上跳,本质是对delta进行二进制拆分
}
, 同时向上跳
这里有一个细节,即是将
for(int i = log2(n);i >= 0;i--){
if(fa[x][i] != fa[y][i]){
//Leap!
x = fa[x][i];
y = fa[y][i];
}
}
我们可以将第二步和第三步封装成一个函数:
int lca(int x,int y){
if(dep[x] < dep[y]) swap(x,y);
int delta = dep[x] - dep[y];
for(int i = 0;i <= log2(n);i++){
if((1 << i) & delta) x = fa[x][i];
}
if(x == y){
return x;
}
for(int i = log2(n);i >= 0;i--){
if(fa[x][i] != fa[y][i]){//这里最终跳跃到的是lca(x,y)的子节点
x = fa[x][i];
y = fa[y][i];
}
}
return fa[x][0];
}
那么我们的倍增算法就好了,可知其时间复杂度为
这里附上完整代码:
#include<bits/stdc++.h>
using namespace std;
const int MAXN = 5e5 + 7;
const int LOG = 30;
int n,m,r;
int dep[MAXN];
int fa[MAXN][LOG];
bool vis[MAXN];
vector<int> tree[MAXN];
void dfs(int root){
vis[root] = true;
if(tree[root].size() == 0){
return;
}
for(int to : tree[root]){
if(!vis[to]){
dep[to] = dep[root] + 1;
fa[to][0] = root;
dfs(to);
}
}
}
int lca(int x,int y){
//Leap to a same depth:
int dx = dep[x],dy = dep[y];
if(dep[x] != dep[y]){
if(dx < dy){
swap(dx,dy);
swap(x,y);
}
int delta = dx - dy;
for(int i = 0;i <= LOG - 2;i++){
if((1 << i) & delta) x = fa[x][i];
}
}
if(x == y){
return x;
}
//Leap to the child node of lca(x,y)
for(int i = LOG - 2;i >= 0;i--){
if(fa[x][i] != fa[y][i]) {
x = fa[x][i];
y = fa[y][i];
}
}
return fa[x][0];
}
int main(){
scanf("%d%d%d", &n, &m, &r);
for(int i = 1;i < n;i++){
int x,y;
scanf("%d%d", &x, &y);
tree[x].push_back(y);
tree[y].push_back(x);
}
dep[r] = 1;
fa[r][0] = 0;
//Pre-Processing
dfs(r);
for(int i = 1;i <= LOG - 2;i++){
for(int j = 1;j <= n;j++){
fa[j][i] = fa[fa[j][i - 1]][i - 1];
}
}
for(int i = 1;i <= m;i++){
int x,y;
scanf("%d%d", &x, &y);
printf("%d\n", lca(x,y));
}
return 0;
}
本文来自博客园,作者:wyl123ly,转载请注明原文链接:https://www.cnblogs.com/wyl123ly/p/18304549
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步