luogu 模板:https://www.luogu.com.cn/problem/P3379
树剖求 \(LCA\)
时间复杂度：预处理 \(O(n)\)（两次 DFS），单次查询 \(O(\log n)\)，总计 \(O(n + q\log n)\)
#include<bits/stdc++.h>
using namespace std;
using LL = long long;
/// Heavy-light decomposition of a tree whose vertices are numbered 1..n,
/// used to answer lowest-common-ancestor (LCA) queries.
///
/// Usage: call add() once per edge, then dfs1(root), then dfs2(root, root);
/// afterwards lca(u, v) may be called any number of times.
///
/// Vertex 0 is a sentinel: the root's parent is left at 0, and dep[0] == siz[0] == 0.
/// son[u] == 0 means "u has no heavy child".
///
/// Both traversals use an explicit stack instead of recursion, so very deep
/// trees (e.g. a path of 5e5 vertices) cannot overflow the call stack.
struct HLD{
    std::vector<std::vector<int>> e;  // e[u]: neighbours of u
    std::vector<int> top;     // top[u]: shallowest vertex of the heavy chain containing u
    std::vector<int> dep;     // dep[u]: depth of u; the start vertex of dfs1 gets depth 1
    std::vector<int> parent;  // parent[u]: parent of u (0 for the root)
    std::vector<int> siz;     // siz[u]: number of vertices in u's subtree
    std::vector<int> son;     // son[u]: child of u with the largest subtree, or 0 if none

    /// Allocate zero-initialised storage for vertices 0..n (index 0 is the sentinel).
    explicit HLD(int n)
        : e(n + 1), top(n + 1), dep(n + 1), parent(n + 1), siz(n + 1), son(n + 1) {}

    /// Add the undirected edge u-v.
    void add(int u, int v){
        e[u].push_back(v);
        e[v].push_back(u);
    }

    /// Root the tree at u: fill parent[], dep[], siz[] and son[].
    /// Reads parent[u], which must be 0 for the root (its default value).
    void dfs1(int u){
        // Pass 1, top-down: assign parent and depth, and record a preorder.
        std::vector<int> order;
        order.reserve(e.size());
        std::vector<int> stk = {u};
        while (!stk.empty()){
            int x = stk.back();
            stk.pop_back();
            order.push_back(x);
            dep[x] = dep[parent[x]] + 1;
            for (int v : e[x]){
                if (v == parent[x]) continue;
                parent[v] = x;
                stk.push_back(v);
            }
        }
        // Pass 2, bottom-up: every child appears after its parent in `order`,
        // so a reverse sweep sees each child's subtree size already final.
        for (auto it = order.rbegin(); it != order.rend(); ++it){
            int x = *it;
            siz[x] = 1;
            son[x] = 0;  // reset so a repeated call cannot keep a stale heavy child
            // Children are scanned in adjacency order; the first child with a
            // strictly larger subtree wins ties, as in a recursive DFS.
            for (int v : e[x]){
                if (v == parent[x]) continue;
                siz[x] += siz[v];
                if (siz[v] > siz[son[x]]) son[x] = v;
            }
        }
    }

    /// Fill top[] for u's subtree, where `up` is the chain head assigned to u.
    /// Requires dfs1 to have run (it reads parent[] and son[]).
    void dfs2(int u, int up){
        top[u] = up;
        std::vector<int> stk = {u};
        while (!stk.empty()){
            int x = stk.back();
            stk.pop_back();
            for (int v : e[x]){
                if (v == parent[x]) continue;
                // The heavy child continues x's chain; a light child starts a new one.
                top[v] = (v == son[x]) ? top[x] : v;
                stk.push_back(v);
            }
        }
    }

    /// Return the lowest common ancestor of u and v (requires dfs1 and dfs2).
    /// While u and v lie on different heavy chains, the one whose chain head is
    /// deeper jumps to the parent of that head. Once both share a chain, the
    /// shallower of the two is the LCA.
    int lca(int u, int v){
        while (top[u] != top[v]){
            if (dep[top[u]] > dep[top[v]]){
                u = parent[top[u]];
            }
            else{
                v = parent[top[v]];
            }
        }
        return dep[u] < dep[v] ? u : v;
    }
};
// Read the tree and q queries from stdin; print one LCA per query.
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n, m, s;  // vertex count, query count, root
    cin >> n >> m >> s;

    HLD t(n);
    for (int k = 1; k < n; ++k){  // a tree on n vertices has n - 1 edges
        int a, b;
        cin >> a >> b;
        t.add(a, b);
    }

    // Build the decomposition from root s before answering anything.
    t.dfs1(s);
    t.dfs2(s, s);

    for (int q = m; q > 0; --q){
        int a, b;
        cin >> a >> b;
        cout << t.lca(a, b) << "\n";
    }
    return 0;
}
倍增求 \(LCA\)
时间复杂度：预处理 \(O(n\log n)\)，单次查询 \(O(\log n)\)，总计 \(O((n + q)\log n)\)
#include <bits/stdc++.h>
using namespace std;
const int N = 500000 + 10;  // array capacity; presumably sized for n <= 5e5 plus slack
int n, m, root;             // vertex count, query count, root vertex
int d[N];                   // d[u]: depth of u (root = 1; sentinel vertex 0 stays 0)
int p[N][30];               // p[u][i]: the 2^i-th ancestor of u, or 0 once past the root
int lg[N];                  // lg[i]: floor(log2 i) + 1, filled in main()
vector<int> g[N];           // g[u]: neighbours of u
void dfs(int u, int fa){
p[u][0] = fa;
d[u] = d[fa] + 1;
for (int i = 1; i <= lg[d[u]]; i ++ )
p[u][i] = p[p[u][i - 1]][i - 1];
// u 的 2^i 的祖先等于 u 的 2^(i-1) 的祖先的 2^(i-1) 的祖先
for (auto v : g[u])
if (v != fa)
dfs(v, u);
}
int lca(int x, int y){
if(d[x] < d[y]) swap(x, y);
while (d[x] > d[y])
x = p[x][lg[d[x] - d[y]] - 1];
if (x == y) return x;
for (int k = lg[d[x]] - 1; k >= 0; k -- )
if (p[x][k] != p[y][k]){
x = p[x][k];
y = p[y][k];
}
return p[x][0];
}
// Read the tree and m queries from stdin; print one LCA per query.
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n >> m >> root;
    for (int k = n - 1; k > 0; -- k){  // n - 1 edges
        int a, b;
        cin >> a >> b;
        g[a].push_back(b);
        g[b].push_back(a);
    }

    // lg[i] = floor(log2 i) + 1: it grows by one exactly when i is a power of two.
    for (int i = 1; i <= n; ++ i)
        lg[i] = lg[i - 1] + (((1 << lg[i - 1]) == i) ? 1 : 0);

    dfs(root, 0);  // the root's parent is the sentinel vertex 0

    for (int q = 0; q < m; ++ q){
        int x, y;
        cin >> x >> y;
        cout << lca(x, y) << "\n";
    }
    return 0;
}