【模板】最近公共祖先(LCA)

posted on 2021-08-04 14:22:40 | under 学术 | source

LCA,Least Common Ancestors,最近公共祖先。

倍增。

首先预处理出数组 \(d_i\)\(f_{i,j}\)

  • \(d_i\) 表示第 \(i\) 个节点的深度。
    转移方程:\(d_{i}=d_{\text{fa}}+1\)
  • \(f_{i,j}\) 表示第 \(i\) 个节点的第 \(2^j\) 级祖先。
    转移方程:\(f_{i,0}=\text{fa},f_{i,j}=f_{f_{i,j-1},j-1}\)

接着是 LCA。分成以下几个步骤:

  1. \(x,y\) 跳到同一层。
  2. 使用倍增思想把 \(x,y\) 跳到离 LCA 最近的节点。
  3. 最终 \(f_{x,0}\)\(f_{y,0}\)\(x,y\) 的 LCA。
点击查看代码
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
template<int N,int M> struct Graph{
    int cnt,head[N+10];
    struct Edge{
        int s,e,w,nxt;
        Edge(int s=0,int e=0,int w=0,int nxt=0):
            s(s),e(e),w(w),nxt(nxt){}
    } a[(M<<1)+10];
    Graph():cnt(0){memset(head,0,sizeof head);}
    void add(int s,int e,int w=0){a[++cnt]=Edge(s,e,w,head[s]),head[s]=cnt;}
    void link(int s,int e,int w=0){add(s,e,w),add(e,s,w);}
};
template<int N> struct Math{
    int lg[N+10];
    Math(){
        lg[0]=-1;
        for(int i=1;i<=N;i++) lg[i]=lg[i>>1]+1;
    }
    int log(int x){return lg[x];}
    int pow(int x){return 1<<x;}
};
int n,m;
Graph<500010,500010> g;
Math<5000010> math;
int f[500010][21],d[500010];
void dfs(int root,int fa){
    f[root][0]=fa,d[root]=d[fa]+1;
    for(int i=1;i<=math.log(d[root]);i++){
        f[root][i]=f[f[root][i-1]][i-1];
    }
    for(int i=g.head[root];i;i=g.a[i].nxt){
        int to=g.a[i].e;
        if(to!=fa) dfs(to,root);
    }
}
int lca(int x,int y){
    if(d[x]<d[y]) swap(x,y);
    int jmp=d[x]-d[y];
    for(int i=math.log(jmp);i>=0;i--){
        if((jmp>>i)&1) x=f[x][i];
    }
    if(x==y) return x;
    for(int i=math.log(n);i>=0;i--){
        if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
    }
    return f[x][0];
}
int root;
int main(){
    cin>>n>>m>>root;
    for(int i=1;i<=n-1;i++){int x,y;
        cin>>x>>y;
        g.link(x,y);
    }
    dfs(root,root);
    for(int i=1;i<=m;i++){int x,y;
        cin>>x>>y;
        cout<<lca(x,y)<<endl;
    }
    return 0;
}

事实上 LCA 还有一百种求法,上面写的倍增太过古老(原发布时间 2021 年)。这里再说说几种做法,不细讲了。

  1. 倍增 \(O(n\log n)-O(\log n)\) 上面有
  2. 欧拉环游序 \(O(n\log n)-O(1)\),观察到两个点之间的 LCA 必然是它们路径上深度最小的点;欧拉环游序额外遍历的子树不影响答案;于是 ST 表进行 RMQ。
点击查看代码
#include <cstdio>
#include <cstring>
#include <utility>
#include <algorithm>
using namespace std;
typedef pair<int,int> node;
template<int N,int M,class T=int> struct graph{
    int head[N+10],nxt[M*2+10],cnt;
    struct edge{
        int u,v;T w;
        edge(int u=0,int v=0,T w=0):u(u),v(v),w(w){}
    } e[M*2+10];
    graph(){memset(head,cnt=0,sizeof head);}
    edge operator[](int i){return e[i];}
    void add(int u,int v,T w=0){e[++cnt]=edge(u,v,w),nxt[cnt]=head[u],head[u]=cnt;}
    void link(int u,int v,T w=0){add(u,v,w),add(v,u,w);}
};
template<int N,class T=int,int logN=21> struct STable{
    T f[logN+1][N+10];
    int cnt,lg[N+10];
    STable():cnt(0){lg[0]=lg[1]=0;for(int i=2;i<=N;i++) lg[i]=lg[i>>1]+1;}
    void insert(T x){
        f[0][++cnt]=x;
        for(int j=1;1<<j<=cnt;j++){
            int i=cnt-(1<<j)+1;
            f[j][i]=min(f[j-1][i],f[j-1][i+(1<<j-1)]);
        }
    }
    T query(int l,int r){
        if(l>r) swap(l,r);
        int k=lg[r-l+1];
        return min(f[k][l],f[k][r-(1<<k)+1]);
    }
};
int n,m,dep[500010],pos[500010],cnt,rt;
graph<500010,500010> g;
STable<500010*2,node> t;
void dfs(int u,int fa=0){
    dep[u]=dep[fa]+1,pos[u]=++cnt;
    t.insert(node(dep[u],u));
    for(int i=g.head[u];i;i=g.nxt[i]){
        int v=g[i].v;
        if(v==fa) continue;
        dfs(v,u);
        cnt++,t.insert(node(dep[u],u));
    }
}
int main(){
    scanf("%d%d%d",&n,&m,&rt);
    for(int i=1;i<n;i++){int u,v;
        scanf("%d%d",&u,&v);
        g.link(u,v);
    }
    dfs(rt);
    for(int i=1;i<=m;i++){int u,v;
        scanf("%d%d",&u,&v);
        printf("%d\n",t.query(pos[u],pos[v]).second);
    }
    return 0;
}
  1. 树剖 \(O(n)-O(\log n)\) 那这位更是经典。两点之间的路径拆成多条重链的前缀和一条重链的区间。找到最顶上的重链,取这个区间深度最小的。
点击查看代码
#include <cstdio>
#include <algorithm>
using namespace std;
const int N=5e5;
struct edge{
	int u,v,nxt;
} e[1000010];
int n,m,cnt,head[N+10];
void add(int u,int v){
	e[++cnt].u=u;
	e[cnt].v=v;
	e[cnt].nxt=head[u];
	head[u]=cnt;
}
int fa[N+10],dep[N+10],siz[N+10],son[N+10];
void dfs(int u,int f){
	fa[u]=f;
	dep[u]=dep[f]+1;
	siz[u]=1;
	for(int i=head[u];i;i=e[i].nxt){
		int v=e[i].v;
		if(v==f) continue;
		dfs(v,u);
		siz[u]+=siz[v];
		if(siz[son[u]]<siz[v]) son[u]=v;
	}
}
int dfn[N+10],rnk[N+10],top[N+10],tot;
void cut(int u,int topf){
	dfn[++tot]=u;
	rnk[u]=tot;
	top[u]=topf;
	if(son[u]) cut(son[u],topf);
	for(int i=head[u];i;i=e[i].nxt){
		int v=e[i].v;
		if(v==fa[u]||v==son[u]) continue;
		cut(v,v);
	}
}
int lca(int u,int v){
	for(;top[u]!=top[v];){
		if(dep[top[u]]<dep[top[v]]) swap(u,v);
		u=fa[top[u]];
	}
	if(dep[u]>dep[v]) swap(u,v);
	return u;
}
int root;
int main(){
	scanf("%d%d%d",&n,&m,&root);
	for(int i=1;i<n;i++){
		int u,v;
		scanf("%d%d",&u,&v);
		add(u,v),add(v,u);
	}
	dfs(root,0);
	cut(root,root);
	for(int i=1;i<=m;i++){
		int u,v;
		scanf("%d%d",&u,&v);
		printf("%d\n",lca(u,v));
	}
	return 0;
}
/*
14 13
1 2
1 3
1 4
2 5
2 6
3 7
4 8
4 9
4 10
6 11
6 12
13 9
13 14
*/
  1. LCT:return access(x),access(y) 均摊 \(O(n\log n)\)。Access 的返回值是打通最后一条实链进入时的节点。
点击查看代码

这个实现挺不好的,只需要 splay 和 access 两个操作就行了。其它都是多余。

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long LL;
template<int N> struct lctree{
	int val[N+10],sum[N+10],fa[N+10],ch[N+10][2],rev[N+10];
	bool getson(int p){return ch[fa[p]][1]==p;}
	bool isroot(int p){return !p||ch[fa[p]][getson(p)]!=p;}
	void maintain(int p){sum[p]=val[p]^sum[ch[p][0]]^sum[ch[p][1]];}
	void pushdown(int p){if(rev[p]) swap(ch[p][0],ch[p][1]),rev[ch[p][0]]^=1,rev[ch[p][1]]^=1,rev[p]^=1;}
	void update(int p){if(!isroot(p)) update(fa[p]); pushdown(p);}
	void connect(int p,int q,int r){fa[p]=q,ch[q][r]=p;}//p->q
	void rotate(int p){int f=fa[p],r=getson(p);if(fa[p]=fa[f],!isroot(f)) connect(p,fa[f],getson(f));connect(ch[p][r^1],f,r),connect(f,p,r^1),maintain(f),maintain(p);}
	void splay(int p){for(update(p);!isroot(p);rotate(p)) if(!isroot(fa[p])) rotate(getson(p)==getson(fa[p])?fa[p]:p);}
	int access(int p){int y=0;for(;p;p=fa[y=p]) splay(p),ch[p][1]=y,maintain(p);return y;}
	void makeroot(int p){access(p),splay(p),rev[p]^=1;}
	int findroot(int p){access(p),splay(p);while(ch[p][0]) p=ch[p][0];return p;}
	void split(int x,int y){makeroot(x),access(y),splay(y);}
	void link(int x,int y){makeroot(x),fa[x]=y;}
	void cut(int x,int y){split(x,y);if(fa[x]==y&&!ch[x][1]) fa[x]=ch[y][0]=0; maintain(y);}
	void modify(int x,int y){splay(x),val[x]=y,maintain(x);}
	int lca(int x,int y){return access(x),access(y);}
};
int n,m,s;
lctree<500010> t;
int main(){
//	#ifdef LOCAL
//	 	freopen("input.in","r",stdin);
//	#endif
	scanf("%d%d%d",&n,&m,&s);
	for(int i=1,u,v;i<n;i++) scanf("%d%d",&u,&v),t.link(u,v);
	t.makeroot(s);
	for(int i=1,u,v;i<=m;i++) scanf("%d%d",&u,&v),printf("%d\n",t.lca(u,v));
	return 0;
}


  1. 离线的 Tarjan 总共 \(O(\alpha\cdot (n+q))\)https://www.cnblogs.com/caijianhong/p/16882023.html
点击查看代码
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long LL;
template<int N,int M,class T=int> struct graph{
    int head[N+10],nxt[M*2+10],cnt;
    struct edge{
        int u,v;T w;
        edge(int u=0,int v=0,T w=0):u(u),v(v),w(w){}
    } e[M*2+10];
    graph(){memset(head,cnt=0,sizeof head);}
    edge&operator[](int i){return e[i];}
    void add(int u,int v,T w=0){e[++cnt]=edge(u,v,w),nxt[cnt]=head[u],head[u]=cnt;}
    void link(int u,int v,T w=0){add(u,v,w),add(v,u,w);}
};
template<int N> struct dsy{
	int fa[N+10];
	dsy(){for(int i=1;i<=N;i++) fa[i]=i;}
	int find(int x){return fa[x]==x?x:fa[x]=find(fa[x]);}
};
int n,m,root;
bool vis[500010];
graph<500010,500010> g,que;
dsy<500010> s;
void dfs(int u,int fa=0){
	vis[u]=1;
	for(int i=g.head[u];i;i=g.nxt[i]){
		int v=g[i].v; if(v==fa) continue;
		dfs(v,u),s.fa[v]=u;
	}
	for(int i=que.head[u];i;i=que.nxt[i]){
		int v=que[i].v;
		if(vis[v]) que[i].w=que[i^1].w=s.find(v);
	}
}
int main(){
	scanf("%d%d%d",&n,&m,&root);
	que.add(0,0);
	for(int i=1,u,v;i<n;i++) scanf("%d%d",&u,&v),g.link(u,v);
	for(int i=1,u,v;i<=m;i++) scanf("%d%d",&u,&v),que.link(u,v);
	dfs(root);
	for(int i=2;i<=que.cnt;i+=2) printf("%d\n",que[i].w);
	return 0;
}

  1. dfn 序求 lca。在欧拉环游序的基础上改动而成。(1) \(u, v\) 没有祖先关系,则查询 \([dfn_u, dfn_v]\) 中深度最小的节点的父亲。(2) \(u\)\(v\) 的祖先,为了判断这种情况,可以记录子树大小,或者将上一种情况的区间左端点 +1 以适配。(3) 最终只需要查询 \([dfn_u+1, dfn_v]\) 中深度最小(或 dfn 最小)的节点的父亲。ST 表解决。特判 \(u=v\)https://www.cnblogs.com/alex-wei/p/DFN_LCA.html
点击查看代码
template <class T, class Compare>
struct STable {
  Compare cmp;
  vector<T> f[21];
  STable(const Compare& cmp = Compare{}, const vector<T>& vec = vector<T>{})
      : cmp(cmp) {
    for (auto&& x : vec) insert(x);
  }
  int insert(const T& x) {
    f[0].push_back(x);
    int n = f[0].size();
    for (int j = 1; 1 << j <= n; j++) {
      int i = n - (1 << j);
      f[j].push_back(min(f[j - 1][i], f[j - 1][i + (1 << (j - 1))], cmp));
    }
    return n - 1;
  }
  T query(int l, int r) {  // [l, r]
    int k = 31 - __builtin_clz(r - l + 1);
    return min(f[k][l], f[k][r - (1 << k) + 1], cmp);
  }
};
int dfn[200010];
const auto cmp = [](int u, int v) { return dfn[u] < dfn[v]; };
STable<int, decltype(cmp)> ST{cmp};
basic_string<int> g[200010];
void dfs(int u, int fa) {
  dfn[u] = ST.insert(fa);
  for (int v : g[u]) dfs(v, u);
}
int lca(int u, int v) {
  if (u == v) return u;
  if (dfn[u] > dfn[v]) swap(u, v);
  return ST.query(dfn[u] + 1, dfn[v]);
}
dfs(1, 0);
posted @ 2023-07-24 20:05  caijianhong  阅读(30)  评论(0编辑  收藏  举报