LCA模板
具体讲解可看:https://www.cnblogs.com/zhouzhendong/p/7256007.html
LCA_Tarjan
Tarjan 算法求 LCA 的时间复杂度为 O((n+q)α(n))（α 为反阿克曼函数），它是一种离线算法，需要借助并查集实现。
// Offline LCA with Tarjan's algorithm (DFS + union-find), used to answer
// weighted tree-distance queries.
//
// Input : T test cases. Each case is "n m", then n-1 lines "from to dist"
//         (stored as directed edges from -> to; the root is the first node
//         with in-degree 0), then m query lines "u v".
// Output: for every query, in input order, the weighted distance between
//         u and v, one per line.
#include <cstdio>
#include <cstring>

const int N = 40000 + 5;

// Forward-star adjacency list: edge k points to y[k] and carries z[k];
// fst[v] is the first edge leaving v and nxt[k] the next one (0 terminates).
// NOTE(review): each list holds at most N - 1 entries; the query list stores
// 2*m entries, so this assumes 2*m < N -- confirm against the problem limits.
struct Edge {
    int cnt, y[N], z[N], nxt[N], fst[N];

    // Empty the list. Only cnt and the list heads need resetting: the y/z/nxt
    // slots are always written by add() before anything reads them.
    void set() {
        cnt = 0;
        memset(fst, 0, sizeof fst);
    }

    // Append the edge a -> b with payload c.
    void add(int a, int b, int c) {
        y[++cnt] = b;
        z[cnt] = c;
        nxt[cnt] = fst[a];
        fst[a] = cnt;
    }
} e, q;  // e: tree edges (payload = weight); q: queries (payload = query id)

int T, n, m, from, to, dist, in[N], rt, fa[N];
// 64-bit on purpose: dis[u] + dis[v] can exceed INT_MAX even when the final
// distance between u and v still fits in an int.
long long dis[N], ans[N];
bool vis[N];

// Fill dis[] for every node below u with its weighted distance from the root
// (dis[u] must already be set).
void dfs(int u) {
    for (int i = e.fst[u]; i; i = e.nxt[i]) {
        dis[e.y[i]] = dis[u] + e.z[i];
        dfs(e.y[i]);
    }
}

// Union-find lookup with path compression.
int getf(int k) {
    return fa[k] == k ? k : fa[k] = getf(fa[k]);
}

// Tarjan's offline LCA. After a child's subtree is finished it is merged into
// u, so for any already-finished node v, getf(v) is the lowest ancestor of v
// that is still being processed -- which is LCA(u, v) when u finishes.
void LCA(int u) {
    for (int i = e.fst[u]; i; i = e.nxt[i]) {
        LCA(e.y[i]);
        fa[getf(e.y[i])] = u;
    }
    vis[u] = true;
    for (int i = q.fst[u]; i; i = q.nxt[i]) {
        int v = q.y[i], id = q.z[i];
        // Each query is stored in both directions; it is answered at whichever
        // endpoint finishes second (only then is the other endpoint visited).
        if (vis[v] && !ans[id]) {
            int l = getf(v);
            ans[id] = (dis[u] - dis[l]) + (dis[v] - dis[l]);
        }
    }
}

int main() {
    scanf("%d", &T);
    while (T--) {
        q.set();
        e.set();
        memset(in, 0, sizeof in);
        memset(vis, 0, sizeof vis);
        memset(ans, 0, sizeof ans);
        scanf("%d%d", &n, &m);
        for (int i = 1; i < n; i++) {
            scanf("%d%d%d", &from, &to, &dist);
            e.add(from, to, dist);
            in[to]++;
        }
        for (int i = 1; i <= m; i++) {
            scanf("%d%d", &from, &to);
            q.add(from, to, i);
            q.add(to, from, i);
        }
        rt = 0;
        for (int i = 1; i <= n && rt == 0; i++)
            if (in[i] == 0) rt = i;
        dis[rt] = 0;
        dfs(rt);
        for (int i = 1; i <= n; i++) fa[i] = i;
        LCA(rt);
        for (int i = 1; i <= m; i++) printf("%lld\n", ans[i]);
    }
    return 0;
}
倍增
我们可以用倍增在线求 LCA：预处理 O(n log n)，每次查询 O(log n)，因此总时间复杂度为 O((n+q) log n)，空间复杂度为 O(n log n)。
// LCA by binary lifting: O(n log n) preprocessing, O(log n) per query.
//
// Input : T test cases. Each case is n, then n-1 lines "a b" (b is a child of
//         a; the root is the first node with in-degree 0), then one query
//         "a b".
// Output: the LCA of the query pair, one line per test case.
#include <cstdio>
#include <cstring>
#include <utility>
#include <vector>

const int N = 10000 + 5;
const int LOG = 14;  // 2^14 = 16384 > N, so any depth gap fits in LOG bits

std::vector<int> son[N];
int T, n, depth[N], in[N], a, b;
// anc[v][k] = the 2^k-th ancestor of v. Node 0 is a sentinel above the root:
// it is never written, so anc[0][k] stays 0 and jumps past the root land on 0.
int anc[N][LOG];

// Record depth and the ancestor table for rt (whose parent is prev), then
// recurse into its children. Ancestors are visited first, so their rows of
// anc[][] are already complete when rt's row is built.
void dfs(int prev, int rt) {
    depth[rt] = depth[prev] + 1;
    anc[rt][0] = prev;
    for (int k = 1; k < LOG; k++)
        anc[rt][k] = anc[anc[rt][k - 1]][k - 1];
    for (int child : son[rt]) dfs(rt, child);
}

// Lowest common ancestor of a and b.
int LCA(int a, int b) {
    if (depth[a] > depth[b]) std::swap(a, b);
    // Lift b to a's depth by the binary decomposition of the depth gap.
    for (int k = LOG - 1; k >= 0; k--)
        if (depth[b] - (1 << k) >= depth[a]) b = anc[b][k];
    if (a == b) return a;
    // Lift both as far as possible while they stay distinct; their parents
    // are then the LCA.
    for (int k = LOG - 1; k >= 0; k--)
        if (anc[a][k] != anc[b][k]) {
            a = anc[a][k];
            b = anc[b][k];
        }
    return anc[a][0];
}

int main() {
    scanf("%d", &T);
    while (T--) {
        scanf("%d", &n);
        for (int i = 1; i <= n; i++) son[i].clear();
        memset(in, 0, sizeof in);
        for (int i = 1; i < n; i++) {
            scanf("%d%d", &a, &b);
            son[a].push_back(b);
            in[b]++;
        }
        depth[0] = -1;  // sentinel parent, so the root gets depth 0
        int rt = 0;
        for (int i = 1; i <= n && rt == 0; i++)
            if (in[i] == 0) rt = i;
        dfs(0, rt);
        scanf("%d%d", &a, &b);
        printf("%d\n", LCA(a, b));
    }
    return 0;
}
RMQ
利用欧拉序可以把 LCA 转化为 RMQ 问题，再用 ST 表求解：O(n log n) 预处理，O(1) 在线回答每次查询。
// LCA via Euler tour + sparse table (RMQ): O(n log n) preprocessing and O(1)
// per query, used to answer weighted tree-distance queries.
//
// Input : n, then n-1 undirected edges "a b c" (0-indexed nodes, weight c),
//         then m, then m query lines "x y".
// Output: for every query, the weighted distance between x and y.
#include <cstdio>
#include <cstring>
#include <utility>

const int N = 50005;
const int LOG = 20;  // 2^19 > 2*N, enough levels for the Euler tour length

// Forward-star adjacency list; each undirected edge is stored twice.
struct Graph {
    int cnt, y[N * 2], z[N * 2], nxt[N * 2], fst[N];
    void clear() {
        cnt = 0;
        memset(fst, 0, sizeof fst);
    }
    void add(int a, int b, int c) {
        y[++cnt] = b, z[cnt] = c, nxt[cnt] = fst[a], fst[a] = cnt;
    }
} g;

int n, m, in[N], tick;  // in[v] = first position of v in the Euler tour
int level[N];           // unweighted depth (root = 0); used to pick the LCA
long long dist[N];      // weighted distance from the root (node 1)
int ST[N * 2][LOG];     // ST[i][j] = shallowest node in tour[i, i + 2^j)
int lg[N * 2];          // lg[k] = floor(log2(k)), integer-exact

// Build the Euler tour into ST[.][0] (each node is appended on entry and again
// after every child returns: 2n-1 entries in total) and fill level/dist.
void dfs(int x, int pre) {
    in[x] = ++tick;
    ST[tick][0] = x;
    for (int i = g.fst[x]; i; i = g.nxt[i]) {
        int y = g.y[i];
        if (y == pre) continue;
        level[y] = level[x] + 1;
        dist[y] = dist[x] + g.z[i];
        dfs(y, x);
        ST[++tick][0] = x;
    }
}

// Build the log table and the sparse table over tour positions 1..len.
// Nodes are compared by unweighted level: the LCA is the unique shallowest
// node in the tour segment between two first occurrences. Comparing weighted
// distance instead is only correct when no edge weight is negative.
void Get_ST(int len) {
    lg[1] = 0;
    for (int i = 2; i <= len; i++) lg[i] = lg[i >> 1] + 1;
    for (int j = 1; j < LOG; j++)
        for (int i = 1; i + (1 << j) - 1 <= len; i++) {
            int a = ST[i][j - 1], b = ST[i + (1 << (j - 1))][j - 1];
            ST[i][j] = level[a] < level[b] ? a : b;
        }
}

// Shallowest node in tour positions [L, R] (requires 1 <= L <= R <= len):
// two overlapping power-of-two windows cover the range.
int RMQ(int L, int R) {
    int k = lg[R - L + 1];
    int x = ST[L][k], y = ST[R - (1 << k) + 1][k];
    return level[x] < level[y] ? x : y;
}

int main() {
    scanf("%d", &n);
    for (int i = 1, a, b, c; i < n; i++) {
        scanf("%d%d%d", &a, &b, &c);
        a++, b++;  // shift to 1-indexed nodes; 0 is unused
        g.add(a, b, c);
        g.add(b, a, c);
    }
    tick = 0;
    dfs(1, 0);
    Get_ST(tick);
    scanf("%d", &m);
    while (m--) {
        int x, y;
        scanf("%d%d", &x, &y);
        x++, y++;
        if (in[x] > in[y]) std::swap(x, y);
        int lca = RMQ(in[x], in[y]);
        printf("%lld\n", dist[x] + dist[y] - 2 * dist[lca]);
    }
    return 0;
}