LCA

LCA最近公共祖先

最近公共祖先简称 LCA(Lowest Common Ancestor)。两个节点的最近公共祖先,就是这两个点的公共祖先里面,离根最远的那个。

倍增法

暴力求解太慢,这里先摘记一种做法-倍增法
时间复杂度:$O(n \log_2 n + m \log_2 n) $

对于每一个节点,我们先通过dfs预先处理出当前节点向根移动2的幂次方的节点编号(有种st表的感觉)。转移方程:\(fa[u][i]=fa[fa[u][i-1]][i-1],i\ge 1\)
关键数组:深度数组dep和父节点?数组fa

void dfs(int x,int f){
	dep[x]=dep[f]+1;
	fa[x][0]=f;
	for(int i=1;i<=N;++i){
		fa[x][i]=fa[fa[x][i-1]][i-1];
	}
	for(int i=head[x];i;i=p[i].next){
		if(p[i].to!=f) dfs(p[i].to,x);
	}
	return;
}

对于询问,首先让u与v处于同一深度,此处排除当前u和v即为最近公共祖先的情况。之后让u和v一起向根节点游,\(fa[u][i]!=fa[v][i]\),则继续向上游。直到最后答案为\(fa[v][0]\)

int lca(int u,int v){
	if(dep[u]<dep[v]) swap(u,v);//使u深度大
	for(int i=N;i>=0;--i){
		if(dep[fa[u][i]]>=dep[v]){//为了处于同一深度
			u=fa[u][i];
		}
	}
	if(u==v) return v;//当前节点即为最近公共祖先
	for(int i=N;i>=0;--i){
		if(fa[u][i]!=fa[v][i]){//一起向上不跳到公共点
			u=fa[u][i];
			v=fa[v][i];
		}
	}
	return fa[v][0];
}

树链剖分法

可以利用树链剖分来求 LCA
简单的 LCA 问题仅用到 5 个数组:

  • fa[u]:存节点 u 的父节点

  • dep[u]:存节点 u 的深度

  • son[u]:存节点 u 的重儿子

  • sz[u]:存以节点 u 为根的子树的节点个数

  • top[u]:存 u 所在的重链的顶点,即链头

利用两个 dfs 处理这五个数组(4 + 1)
(下面是用邻接表实现的)
dfs1 :预处理数组 dep、sz、fa、son,初始为 dfs1 (根节点, 0)
从根开始遍历每个节点,统计深度、父节点和节点数初值,最主要就是递归后判断重儿子的位置

void dfs1(int u, int f){//预处理dep、sz、fa、son
	dep[u] = dep[f] + 1;
	sz[u] = 1;
	fa[u] = f;
	for(auto v : e[u]){
		if(v == f) continue;
		dfs1(v, u);
		sz[u] += sz[v];
		if(sz[v] > sz[son[u]]) son[u] = v;
	}
	return ;
}

dfs2 :预处理数组 top,初始为 dfs2 (根节点,根节点)
首先赋值链头,如果没有重儿子直接返回,再 dfs 重儿子(重儿子的链头是当前的链头),最后遍历轻儿子,但是轻儿子的链头是自己

void dfs2(int u, int t){//预处理top
	top[u] = t;
	if(!son[u]) return ;
	dfs2(son[u], t);
	for(auto v : e[u]){
		if(v == fa[u] || v == son[u]) continue;
		dfs2(v, v);
	}
	return ;
}

lca

  • 当两个节点 x, y 位于一条重链上时,深度小的即为 LCA
  • 当两个节点 x, y 位于不同的重链上时,让 x 和 y 向上跳,直到位于同一条重链为止。跳的话链头深度大的节点跳到当前链头的父节点
int lca(int u,int v){
	while(top[u] != top[v]){
		if(dep[top[u]] < dep[top[v]]) swap(u, v);
		u = fa[top[u]];
	}
	return dep[u] < dep[v] ? u : v;
}

应用

树上两点之间的最短距离

\(dis(x,y) = dep[x] + dep[y] - dep[lca(x,y)] * 2\) 或者 \(dist(x,y) = dis[x] + dis[y] - dis[lca(x,y)] * 2\)

树上差分

对于树上某一路径的修改可以视为树上的区间修改,可以利用差分将其变为端点修改
我们把一条路径 $(u,v) $ 分为两部分:$u -> lca(u, v) $ 和 $ lca(u, v) -> v $,这样每条路径都可以视为一个区间处理。
比如说将这个区间所有的节点权值 + x,那么进行如下操作即可:
令 $l = lca(u, v) $,则 $d[u] += x; d[v] += x; d[l] -= x; d[f[l][0]] -= x $
因为 u 和 v 两条路径在 l 处汇合,所以需要在 \(l\)\(f[l][0]\) 处均进行减的操作
如果说想求节点的权值,利用 dfs 先将所有子节点的权值处理完,再加上即可得到当前节点的权值

例题

模板题

倍增法

//>>>Qiansui
#include<map>
#include<set>
#include<list>
#include<stack>
#include<cmath>
#include<queue>
#include<deque>
#include<cstdio>
#include<string>
#include<vector>
#include<utility>
#include<iomanip>
#include<cstring>
#include<cstdlib>
#include<iostream>
#include<algorithm>
#include<functional>
#define ll long long
#define ull unsigned long long
#define mem(x,y) memset(x,y,sizeof(x))
#define debug(x) cout << #x << " = " << x << endl
#define debug2(x,y) cout << #x << " = " << x << " " << #y << " = "<< y << endl
//#define int long long

inline ll read()
{
	ll x=0,f=1;char ch=getchar();
	while (ch<'0'||ch>'9'){if (ch=='-') f=-1;ch=getchar();}
	while (ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+ch-48;ch=getchar();}
	return x*f;
}

using namespace std;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
typedef pair<ull,ull> pull;
typedef pair<double,double> pdd;
/*
最近公共祖先模板题
利用深搜构造lca数组
*/
const int maxm=5e5+5,inf=0x3f3f3f3f,mod=998244353,N=20;
int n,m,s,cnt=1,head[maxm],dep[maxm],fa[maxm][N+1];

struct node{
	int to,next;
}p[maxm<<1];

void add_edge(int a,int b){
	p[cnt].to=b;
	p[cnt].next=head[a];
	head[a]=cnt++;
	return ;
}

void dfs(int x,int f){
	dep[x]=dep[f]+1;
	fa[x][0]=f;
	for(int i=1;i<=N;++i){
		fa[x][i]=fa[fa[x][i-1]][i-1];
	}
	for(int i=head[x];i;i=p[i].next){
		if(p[i].to!=f) dfs(p[i].to,x);
	}
	return;
}

int lca(int u,int v){
	if(dep[u]<dep[v]) swap(u,v);
	for(int i=N;i>=0;--i){
		if(dep[fa[u][i]]>=dep[v]){
			u=fa[u][i];
		}
	}
	if(u==v) return v;
	for(int i=N;i>=0;--i){
		if(fa[u][i]!=fa[v][i]){
			u=fa[u][i];
			v=fa[v][i];
		}
	}
	return fa[v][0];
}

void solve(){
	cin>>n>>m>>s;
	int a,b;
	for(int i=1;i<n;++i){
		cin>>a>>b;
		add_edge(a,b);
		add_edge(b,a);
	}
	dfs(s,0);
	while(m--){
		cin>>a>>b;
		cout<<lca(a,b)<<'\n';
	}
	return ;
}

signed main(){
	ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
	int _=1;
//	cin>>_;
	while(_--){
		solve();
	}
	return 0;
}

树链剖分法

//>>>Qiansui
#include<bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define mem(x,y) memset(x, y, sizeof(x))
#define debug(x) cout << #x << " = " << x << '\n'
#define debug2(x,y) cout << #x << " = " << x << " " << #y << " = "<< y << '\n'
//#define int long long

using namespace std;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
typedef pair<ull, ull> pull;
typedef pair<double, double> pdd;
/*
树链剖分
*/
const int maxm = 5e5 + 5, inf = 0x3f3f3f3f, mod = 998244353;
vector<int> e[maxm];
int dep[maxm], fa[maxm], son[maxm], sz[maxm];//深度、父节点、重儿子、子树节点数
int top[maxm];//链头
int n, m, s;

void dfs1(int u, int f){//预处理dep、sz、fa、son
	dep[u] = dep[f] + 1;
	sz[u] = 1;
	fa[u] = f;
	for(auto v : e[u]){
		if(v == f) continue;
		dfs1(v, u);
		sz[u] += sz[v];
		if(sz[v] > sz[son[u]]) son[u] = v;
	}
	return ;
}

void dfs2(int u, int t){//预处理top
	top[u] = t;
	if(!son[u]) return ;
	dfs2(son[u], t);
	for(auto v : e[u]){
		if(v == fa[u] || v == son[u]) continue;
		dfs2(v, v);
	}
	return ;
}

int lca(int u,int v){
	while(top[u] != top[v]){
		if(dep[top[u]] < dep[top[v]]) swap(u, v);
		u = fa[top[u]];
	}
	return dep[u] < dep[v] ? u : v;
}

void solve(){
	cin >> n >> m >> s;
	int a, b;
	for(int i = 1; i < n; ++ i){
		cin >> a >> b;
		e[a].push_back(b);
		e[b].push_back(a);
	}
	dfs1(s, 0);
	dfs2(s, s);
	while(m --){
		cin >> a >> b;
		cout << lca(a, b) << '\n';
	}
	return ;
}

signed main(){
	ios::sync_with_stdio(false), cin.tie(nullptr), cout.tie(nullptr);
	int _ = 1;
	// cin >> _;
	while(_ --){
		solve();
	}
	return 0;
}

根据题意抽象为树,再利用倍增优化。具体代码与lca模板题巨像
代码:
链式前向星 qiansui_code
邻接表 qiansui_code

树上两点间的最近距离

法一:
利用式子:\(dist(x,y) = dis[x] + dis[y] - dis[lca(x,y)] * 2\)
法二:
在 lca 预处理最近公共祖先的同时存下到祖先的距离,之后询问时将距离划分为两段:$u -> lca(u, v) + v -> lca(u, v) $即可

下为代码
法一

//>>>Qiansui
#include<bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define mem(x,y) memset(x, y, sizeof(x))
#define debug(x) cout << #x << " = " << x << '\n'
#define debug2(x,y) cout << #x << " = " << x << " " << #y << " = "<< y << '\n'
//#define int long long

using namespace std;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
typedef pair<ull, ull> pull;
typedef pair<double, double> pdd;
/*
lca求树上两点距离
*/
const int maxm = 2e5 + 5, inf = 0x3f3f3f3f, mod = 998244353, N = 20;
int n, m;
vector<ll> dep, dis;
vector<vector<ll>> f;
vector<vector<pll>> e;

void dfs(int u, int fa){
	dep[u] = dep[fa] + 1;
	f[u][0] = fa;
	for(int i = 1; i <= N; ++ i){
		f[u][i] = f[f[u][i - 1]][i - 1];
	}
	for(auto v : e[u]){
		if(v.first == fa) continue;
		dis[v.first] = dis[u] + v.second;
		dfs(v.first, u);
	}
	return ;
}

int lca(int u,int v){
	if(dep[u] < dep[v]) swap(u, v);
	for(int i = N; i >= 0; -- i){
		if(dep[f[u][i]] >= dep[v]){
			u = f[u][i];
		}
		if(u == v) return u;
	}
	for(int i = N; i >= 0; -- i){
		if(f[u][i] != f[v][i]){
			u = f[u][i]; v = f[v][i];
		}
	}
	return f[u][0];
}

void solve(){
	cin >> n >> m;
	e = vector<vector<pll>> (n + 1, vector<pll>());
	dep = vector<ll>(n + 1, 0);
	dis = vector<ll>(n + 1, 0);
	f = vector<vector<ll>> (n + 1, vector<ll>(N + 1, 0));
	int x, y, z;
	for(int i = 1; i < n; ++ i){
		cin >> x >> y >> z;
		e[x].push_back({y, z});
		e[y].push_back({x, z});
	}
	dfs(1, 0);
	for(int i = 0; i < m; ++ i){
		cin >> x >> y;
		cout << dis[x] + dis[y] - 2 * dis[lca(x, y)] << '\n';
	}
	return ;
}

signed main(){
	ios::sync_with_stdio(false), cin.tie(nullptr), cout.tie(nullptr);
	int _ = 1;
	cin >> _;
	while(_ --){
		solve();
	}
	return 0;
}

法二

//>>>Qiansui
#include<bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define mem(x,y) memset(x, y, sizeof(x))
#define debug(x) cout << #x << " = " << x << '\n'
#define debug2(x,y) cout << #x << " = " << x << " " << #y << " = "<< y << '\n'
//#define int long long
 
using namespace std;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
typedef pair<ull, ull> pull;
typedef pair<double, double> pdd;
/*
lca求树上两点距离
*/
const int maxm = 2e5 + 5, inf = 0x3f3f3f3f, mod = 998244353, N = 20;
int n, m;
vector<int> dep;
vector<vector<ll>> f, g;
vector<vector<pll>> e;

void dfs(int u, int fa, int len){
	dep[u] = dep[fa] + 1;
	f[u][0] = fa;
	if(fa) g[u][0] = len;
	for(int i = 1; i <= N; ++ i){
		f[u][i] = f[f[u][i - 1]][i - 1];
		g[u][i] = g[u][i - 1] + g[f[u][i - 1]][i - 1];
	}
	for(auto v : e[u]){
		if(v.first == fa) continue;
		dfs(v.first, u, v.second);
	}
	return ;
}

int lca(int u,int v){
	if(dep[u] < dep[v]) swap(u, v);
	for(int i = N; i >= 0; -- i){
		if(dep[f[u][i]] >= dep[v]){
			u = f[u][i];
		}
		if(u == v) return u;
	}
	for(int i = N; i >= 0; -- i){
		if(f[u][i] != f[v][i]){
			u = f[u][i]; v = f[v][i];
		}
	}
	return f[u][0];
}

ll calc(int u, int v){
	ll ans = 0;
	for(int i = N; i >= 0; -- i){
		if(dep[f[u][i]] >= dep[v]){
			ans += g[u][i];
			u = f[u][i];
		}
	}
	return ans;
}

void solve(){
	cin >> n >> m;
	e = vector<vector<pll>> (n + 1, vector<pll>());
	dep = vector<int>(n + 1, 0);
	f = vector<vector<ll>> (n + 1, vector<ll>(N + 1, 0));
	g = vector<vector<ll>> (n + 1, vector<ll>(N + 1, 0));
	int x, y, z;
	for(int i = 1; i < n; ++ i){
		cin >> x >> y >> z;
		e[x].push_back({y, z});
		e[y].push_back({x, z});
	}
	dfs(1, 0, 0);
	for(int i = 0; i < m; ++ i){
		cin >> x >> y;
		int l = lca(x, y);
		ll ans = calc(x, l) + calc(y, l);
		cout << ans << '\n';
	}
	return ;
}

signed main(){
	ios::sync_with_stdio(false), cin.tie(nullptr), cout.tie(nullptr);
	int _ = 1;
	cin >> _;
	while(_ --){
		solve();
	}
	return 0;
}

  • hdu 2874 Connections between cities
    求树上两点距离。此题是在森林里面求,所以还需要套一个并查集判断是否在一棵树上
    写的时候 WA 疯了,因为 lca 的板子写错了。。。。。。
//>>>Qiansui
#include<bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define mem(x,y) memset(x, y, sizeof(x))
#define debug(x) cout << #x << " = " << x << '\n'
#define debug2(x,y) cout << #x << " = " << x << " " << #y << " = "<< y << '\n'
//#define int long long

using namespace std;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
typedef pair<ull, ull> pull;
typedef pair<double, double> pdd;
/*

*/
const int maxm = 1e4 + 5, inf = 0x3f3f3f3f, mod = 998244353, N = 20;
ll n, m, c, dep[maxm], f[maxm][N + 1], dis[maxm];
vector<pll> e[maxm];

struct dsu{
	int num;
	vector<int> fa;
	dsu(int x=maxm):num(x),fa(x+1){
		for(int i=0;i<=x;++i) fa[i]=i;
	}
	int findfa(int x){ return fa[x]==x? x:fa[x]=findfa(fa[x]); }
	void merge(int u,int v){
		fa[findfa(u)]=findfa(v); return ;
	}
};

void dfs(int u, int fa){
	dep[u] = dep[fa] + 1;
	f[u][0] = fa;
	for(int i = 0; i < N; ++ i){
		f[u][i + 1] = f[f[u][i]][i];
	}
	for(auto v : e[u]){
		if(v.first == fa) continue;
		dis[v.first] = dis[u] + v.second;
		dfs(v.first, u);
	}
	return ;
}

int lca(int u, int v){
	if(dep[u] < dep[v]) swap(u, v);
	for(int i = N; i >= 0; -- i){
		if(dep[f[u][i]] >= dep[v]){
			u = f[u][i];
		}
		if(u == v) return u;
	}
	for(int i = N; i >= 0; -- i){
		if(f[u][i] != f[v][i]){
			u = f[u][i]; v = f[v][i];
		}
	}
	return f[u][0];
}

void solve(){
	while(cin >> n >> m >> c){
		dsu ds(n + 1);
		mem(dep, 0); mem(f, 0); mem(dis, 0);
		ll u, v, w;
		for(int i = 0; i < m; ++ i){
			cin >> u >> v >> w;
			e[u].push_back({v, w});
			e[v].push_back({u, w});
			ds.merge(u, v);
		}
		for(int i = 1; i <= n; ++ i){
			if(ds.findfa(i) == i) dfs(i, 0);
		}
		for(int i = 0; i < c; ++ i){
			cin >> u >> v;
			if(ds.findfa(u) != ds.findfa(v))
				cout << "Not connected\n";
			else{
				cout << dis[u] + dis[v] - 2 * dis[lca(u, v)] << '\n';
			}
		}
		for(int i = 1; i <= n; ++ i){
			e[i].clear();
		}
	}
	return ;
}

signed main(){
	ios::sync_with_stdio(false), cin.tie(nullptr), cout.tie(nullptr);
	int _ = 1;
	// cin >> _;
	while(_ --){
		solve();
	}
	return 0;
}


树上差分

  • easy 题 洛谷 P3128 [USACO15DEC] Max Flow P
    简单的树上差分,最后 dfs 求所有节点的权值最大值。关键部分即为差分的部分,可以看代码或者上面的介绍部分
//>>>Qiansui
#include<bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define mem(x,y) memset(x, y, sizeof(x))
#define debug(x) cout << #x << " = " << x << '\n'
#define debug2(x,y) cout << #x << " = " << x << " " << #y << " = "<< y << '\n'
//#define int long long

using namespace std;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
typedef pair<ull, ull> pull;
typedef pair<double, double> pdd;
/*
树上差分
*/
const int maxm = 5e4 + 5, inf = 0x3f3f3f3f, mod = 998244353, N = 18;
int n, k, f[maxm][N + 1], dep[maxm], cnt[maxm], d[maxm], ans = 0;
vector<int> e[maxm];

void dfs(int u, int fa){
	dep[u] = dep[fa] + 1;
	f[u][0] = fa;
	for(int i = 1; i <= N; ++ i){
		f[u][i] = f[f[u][i - 1]][i - 1];
	}
	for(auto v : e[u]){
		if(v != fa) dfs(v, u);
	}
	return ;
}

int lca(int x, int y){
	if(dep[x] < dep[y]) swap(x,y);
	for(int i = N; i >= 0; -- i){
		if(dep[f[x][i]] >= dep[y]){
			x = f[x][i];
		}
		if(x == y) return x;
	}
	for(int i = N; i >= 0; -- i){
		if(f[x][i] != f[y][i]){
			x = f[x][i];
			y = f[y][i];
		}
	}
	return f[x][0];
}

void ddfs(int u, int fa){
	for(auto v : e[u]){
		if(v == fa) continue;
		ddfs(v, u);
		d[u] += d[v];
	}
	ans = max(ans, d[u]);
	return ;
}

void solve(){
	cin >> n >> k;
	int x, y;
	for(int i = 1; i < n; ++ i){
		cin >> x >> y;
		e[x].push_back(y);
		e[y].push_back(x);
	}
	dfs(1, 0);
	for(int i = 0; i < k; ++ i){// 差分处理
		cin >> x >> y;
		int l = lca(x, y);
		++ d[x]; ++ d[y];
		-- d[l];
		-- d[f[l][0]];
	}
	ddfs(1, 0);
	cout << ans << '\n';
	return ;
}

signed main(){
	ios::sync_with_stdio(false), cin.tie(nullptr), cout.tie(nullptr);
	int _ = 1;
	// cin >> _;
	while(_ --){
		solve();
	}
	return 0;
}


综合题

相关资料

整体介绍
https://oi-wiki.org/graph/lca/

模板及视频讲解:
倍增法
树链剖分法

posted on 2023-08-02 15:50  Qiansui  阅读(18)  评论(0编辑  收藏  举报