LCA
LCA,全称为Lowest Common Ancestor, 即最近公共祖先。这是对于有根树而言的,两个节点u, v的公共祖先中距离最近的那个被称为最近公共祖先(这解释。。真通俗。。。)我们来看个图:
4和7的LCA是2,5和6的LCA是1,2和5的LCA是2。
最笨的实现方法就是;
对于同一深度的点,一个个往上走知道走到相同的点为止;不同深度的点转化为同一深度的点再往上走!
复杂度为O(dep[u] + dep[v] - 2 * dep[c])(u, v表示要查询的点,c为u, v的LCA)
代码如下:
#include <cstdio> #include <vector> #include <algorithm> using namespace std; const int N = 100000 + 5; vector<int> G[N]; int rt;//表示根,即root int dsu[N], dep[N]; //dsu数组表示父亲,dep表示深度 void dfs(int v, int p, int d){ dsu[v] = p; dep[v] = d; for(int i = 0; i < G[v].size(); i ++) if(G[v][i] != p) dfs(G[v][i], v, d + 1); //G[v][i] != p的原因是读入的时候是无向,所以存在反向边,例在处理节点v的时候,假设dsu[v] = u, 则存在i,使得G[v][i] = u,这样会导致在两个点中一直循环,导致陷入死循环。 } void init(){ dfs(rt, -1, 0); } int lca(int u, int v){ if(dep[u] > dep[v]) swap(u, v); while(dep[v] > dep[u]) v = dsu[v]; //让u, v处于同一深度。 while(u != v){ u = dsu[u]; v = dsu[v]; } return u; } int n, q, a, b; int main(){ //我们以n个点n - 1条边,q次询问为例 while(scanf("%d%d", &n, &q) == 2){ for(int i = 1; i <= n; i ++)G[i].clear(); for(int i = 1; i < n; i ++){ scanf("%d%d", &a, &b); G[a].push_back(b); G[b].push_back(a); } rt = 1;//以1为根 init(); while(q --){ scanf("%d%d", &a, &b); printf("%d\n", lca(a, b)); } } return 0; }
我们发现如果有n个点最坏的情况下有O(n)的复杂度,如果多次查询复杂度肯定会爆掉,所以我们必须要有高效的算法。
实现LCA的高效算法有二种,分别是倍增法和RMQ法。
一.倍增法
我们首先这们想如果相同深度的两个节点u, v当往上走k步的时候走到同一节点,那么往上走k + 1步还是同一节点,k + 2步也是, k + k即2k步也是。我们把上面的dsu数组变成2维数组dsu[k][v]表示节点v往上走2k步所走到的节点。那么dsu[k + 1][v] = dsu[k][dsu[k][v]];这样我们就可以通过二分来查找他们的LCA了,每次查询复杂度为O(logn), 预处理为O(nlogn)。
#include <cstdio> #include <cstring> #include <vector> #include <algorithm> using namespace std; const int N = 10000 + 5; vector<int> G[N]; int rt, a, b, c, q, n, m, T; int dsu[40][N], dep[N]; void dfs(int v, int p, int d){ dsu[0][v] = p; dep[v] = d; for(int i = 0; i < G[v].size(); i ++) if(G[v][i] != p) dfs(G[v][i], v, d + 1); } void init(){ dfs(rt, -1, 0); for(int k = 0; k + 1 < 32; k ++){ for(int v = 1; v <= n; v ++){ if(dsu[k][v] < 0)dsu[k + 1][v] = -1;//根节点的父亲设为-1 else dsu[k + 1][v] = dsu[k][dsu[k][v]]; } } } int lca(int u, int v){ if(dep[u] > dep[v]) swap(u, v); for(int k = 0; k < 32; k ++){ if((dep[v] - dep[u]) >> k & 1) v = dsu[k][v]; } if(u == v)return u; for(int k = 31; k >= 0; k --){ if(dsu[k][u] != dsu[k][v]){ u = dsu[k][u]; v = dsu[k][v]; } } return dsu[0][u]; } int main(){ while(scanf("%d%d", &n, &q) == 2){ for(int i = 1; i <= N; i ++)G[i].clear(); for(int i = 1; i < n; i ++){ scanf("%d%d", &a, &b); G[a].push_back(b); G[b].push_back(a); } rt = 1; init(); while(q --){ scanf("%d%d", &a, &b); printf("%d\n", lca(a, b)); } } return 0; }
二.RMQ法
其实这里还涉及到另外一个东西,叫做dfs序。是指你用dfs遍历一棵树时,每个节点会按照遍历到的先后顺序得到一个序号。然后你用这些序号,可以把整个遍历过程表示出来。如下图:
如上图所示,则整个遍历过程为1 2 4 2 5 7 5 8 5 2 1 3 6 3 1
我们将他保存在一个vs数组,并开个id数组记录第一次在vs中出现的下标,例如id[1] = 1, id[4] = 3;
并用dep数组储存vs数组中每个数的深度,例如dep[2] = dep[4] = 1(vs数组中第2个和第4个都是2,2的深度为2)。
而LCA(u, v)就是第一次访问u之后到第一次访问v之前所经过顶点中离根最近的那个。假设id[u] <= id[v],那么LCA(u, v) = vs[t] t为id[u]与id[v]中dep最小的那一个。
而这个不就相当于求区间的RMQ吗?
附上代码:
#include <cstdio> #include <cstring> #include <vector> #include <algorithm> using namespace std; const int N = 100000 + 5; vector<int>G[N]; int rt, n, m, a, b, c, q; int vs[N * 2 - 1], dep[N * 2 - 1], id[N], sum[N]; int dp[N][40]; struct RMQ{ int log2[N]; void init(int n){ log2[0] = -1; for(int i = 1; i <= n; i ++)log2[i] = log2[i >> 1] + 1; for(int i = 1; i <= n; i ++)dp[i][0] = vs[i]; for(int j = 1; j <= log2[n]; j ++){ for(int i = 1; i + (1 << j) <= n + 1; i ++){ int ca = dp[i][j - 1]; int cb = dp[i + (1 << (j - 1))][j - 1]; if(vs[ca] < vs[cb]) dp[i][j] = ca; else dp[i][j] = cb; } } } int query(int ls, int rs){ int k = log2[rs - ls + 1]; int ca = dp[ls][k]; int cb = dp[rs - (1 << k) + 1][k]; if(vs[ca] < vs[cb]) return ca; else return cb; } }rmq; void dfs(int v, int p, int d, int& k){ //printf("v = %d\n", v); id[v] = k; vs[k] = v; dep[k ++] = d; for(int i = 0; i < G[v].size(); i ++){ if(G[v][i] != p){ dfs(G[v][i], v, d + 1, k); vs[k] = v; dep[k ++] = d; } } } void init(int V){ int k = 1; dfs(rt, -1, 0, k); rmq.init(V * 2 - 1); } int lca(int u, int v){ return vs[rmq.query(min(id[u], id[v]), max(id[u], id[v]))]; } void print(){ for(int i = 0; i < 2 * n; i ++) printf("vs[%d] = %d\n", i, vs[i]); for(int i = 1; i <= n; i ++) printf("id[%d] = %d\n", i, id[i]); for(int i = 1; i <= n; i ++) printf("dep[%d] = %d\n", i, dep[i]); for(int i = 1; i <= n; i ++) printf("sum[%d] = %d\n", i, sum[i]); } int main(){ while(scanf("%d%d", &n, &q) == 2){ for(int i = 1; i <= n; i ++)G[i].clear(); for(int i = 1; i < n; i ++){ scanf("%d%d", &a, &b); G[a].push_back(b); G[b].push_back(a); } rt = 1; init(n); //print(); while(q --){ scanf("%d%d", &a, &b); printf("%d\n", lca(a, b)); } } return 0; }
LCA的应用
LCA可以用来求树上的两个顶点之间的权值和,让任意一个点作为根节点,设sum[u]为顶点rt到u的权值和,那么u到v的权值和就是sum[u] - sum[lca(u, v)] + sum[v] - sum[lca(u, v)]。
来看一道题: 传送门
附上两种方法的代码:
1.倍增法:
#include <cstdio> #include <cstring> #include <vector> #include <algorithm> using namespace std; const int N = 40000 + 5; vector<int> G[N]; vector<int> E[N]; int rt, a, b, c, q, n, m, T; int dsu[40][N], dep[N], sum[N]; void dfs(int v, int p, int d){ dsu[0][v] = p; dep[v] = d; for(int i = 0; i < G[v].size(); i ++) if(G[v][i] != p){ sum[G[v][i]] = sum[v] + E[v][i]; dfs(G[v][i], v, d + 1); } } void init(){ dfs(rt, -1, 0); for(int k = 0; k + 1 < 32; k ++){ for(int v = 1; v <= n; v ++){ if(dsu[k][v] < 0)dsu[k + 1][v] = -1; else dsu[k + 1][v] = dsu[k][dsu[k][v]]; } } } int lca(int u, int v){ if(dep[u] > dep[v]) swap(u, v); for(int k = 0; k < 32; k ++){ if((dep[v] - dep[u]) >> k & 1) v = dsu[k][v]; } if(u == v)return u; for(int k = 31; k >= 0; k --){ if(dsu[k][u] != dsu[k][v]){ u = dsu[k][u]; v = dsu[k][v]; } } return dsu[0][u]; } int main(){ scanf("%d", &T); while(T--){ scanf("%d%d", &n, &q); for(int i = 1; i <= N; i ++)G[i].clear(), E[i].clear(); for(int i = 1; i < n; i ++){ scanf("%d%d%d", &a, &b, &c); G[a].push_back(b); G[b].push_back(a); E[a].push_back(c); E[b].push_back(c); } rt = 1; init(); while(q --){ scanf("%d%d", &a, &b); c = lca(a, b); printf("%d\n", sum[a] + sum[b] - 2 * sum[c]); } } return 0; }
2RMQ + dfs序:
#include <cstdio> #include <cstring> #include <vector> #include <algorithm> using namespace std; const int N = 100000 + 5; vector<int>G[N]; vector<int>E[N]; int rt, n, m, a, b, c, q; int vs[N * 2 - 1], dep[N * 2 - 1], id[N], sum[N]; int dp[N][40]; struct RMQ{ int log2[N]; void init(int n){ log2[0] = -1; for(int i = 1; i <= n; i ++)log2[i] = log2[i >> 1] + 1; for(int i = 1; i <= n; i ++)dp[i][0] = vs[i]; for(int j = 1; j <= log2[n]; j ++){ for(int i = 1; i + (1 << j) <= n + 1; i ++){ int ca = dp[i][j - 1]; int cb = dp[i + (1 << j)][j - 1]; if(vs[ca] < vs[cb]) dp[i][j] = ca; else dp[i][j] = cb; } } } int query(int ls, int rs){ int k = log2[rs - ls + 1]; int ca = dp[ls][k]; int cb = dp[rs - (1 << k) + 1][k]; if(vs[ca] < vs[cb]) return ca; else return cb; } }rmq; void dfs(int v, int p, int d, int& k){ //printf("v = %d\n", v); id[v] = k; vs[k] = v; dep[k ++] = d; for(int i = 0; i < G[v].size(); i ++){ if(G[v][i] != p){ sum[G[v][i]] = sum[v] + E[v][i]; dfs(G[v][i], v, d + 1, k); vs[k] = v; dep[k ++] = d; } } } void init(int V){ int k = 1; sum[0] = sum[1] = 0; dfs(rt, -1, 0, k); rmq.init(V * 2 - 1); } int lca(int u, int v){ return vs[rmq.query(min(id[u], id[v]), max(id[u], id[v]))]; } void print(){ for(int i = 0; i < 2 * n; i ++) printf("vs[%d] = %d\n", i, vs[i]); for(int i = 1; i <= n; i ++) printf("id[%d] = %d\n", i, id[i]); for(int i = 1; i <= n; i ++) printf("dep[%d] = %d\n", i, dep[i]); for(int i = 1; i <= n; i ++) printf("sum[%d] = %d\n", i, sum[i]); } int main(){ int T; scanf("%d", &T); while(T--){ scanf("%d%d", &n, &q); for(int i = 1; i <= n; i ++)G[i].clear(), E[i].clear(); for(int i = 1; i < n; i ++){ scanf("%d%d%d", &a, &b, &c); G[a].push_back(b); G[b].push_back(a); E[a].push_back(c); E[b].push_back(c); } rt = 1; init(n); //print(); while(q --){ scanf("%d%d", &a, &b); c = lca(a, b); //printf("a = %d, b = %d, c = %d\n", a, b, c); printf("%d\n", sum[a] + sum[b] - 2 * sum[c]); } } return 0; }