LCA算法实现

参考问题:洛谷P3379 【模板】最近公共祖先(LCA):https://www.luogu.com.cn/problem/P3379

暴力解法,dfs一下,求得所有点的深度,然后每当我们要求 \(x\)\(y\) 的 LCA 的时候,我们就循环的去判断:

  • 如果当前 \(x\) 的深度大于 \(y\) 的深度,则令 \(x\) 变成 \(x\) 的父亲;
  • 否则,如果当前 \(y\) 的深度大于 \(x\) 的深度,则令 \(y\) 变成 \(y\) 的父亲;
  • 否则(说明当前 \(x\)\(y\) 处在同一深度),如果 \(x \ne y\),则令 \(x\) 变成 \(x\) 的父亲(或者令 \(y\) 变成 \(y\)的父亲,均可);
  • 否则(说明 \(x=y\)),返回 \(x\)(或者 \(y\),均可)。

这种暴力的解法的时间复杂度达到了 \(O(n \cdot m)\),实现代码如下:

#include <bits/stdc++.h>
using namespace std;
const int maxn = 500050;
int n, m, rt, pa[maxn], dep[maxn];
vector<int> g[maxn];
void dfs(int u, int d) {
    dep[u] = d;
    int sz = g[u].size();
    for (int i = 0; i < sz; i ++) {
        int v = g[u][i];
        if (v == pa[u]) continue;
        pa[v] = u;
        dfs(v, d+1);
    }
}
int query(int x, int y) {
    while (x != y) {
        if (dep[x] < dep[y]) y = pa[y];
        else x = pa[x];
    }
    return x;
}
int main() {
    cin >> n >> m >> rt;
    for (int i = 1; i < n; i ++) {
        int u, v;
        cin >> u >> v;
        g[u].push_back(v);
        g[v].push_back(u);
    }
    dfs(rt, 1);
    while (m --) {
        int x, y;
        cin >> x >> y;
        cout << query(x, y) << endl;
    }
    return 0;
}

借助于倍增思想可以将上述算法优化到 \(O(m \cdot log n)\)
下面的实现中:
\(pa[u][i]\) 表示深度比 \(u\)\(2^i\) 的祖先节点的编号,如果这个节点不存在,则该值对应为 \(rt\)(根节点)。

实现代码如下:

#include <bits/stdc++.h>
using namespace std;
const int maxn = 500050;
int n, m, rt, pa[maxn][21], dep[maxn];
vector<int> g[maxn];
void dfs(int u, int p) {
    dep[u] = dep[p] + 1;
    pa[u][0] = p;
    for (int i = 1; (1<<i) <= dep[u]; i ++)
        pa[u][i] = pa[ pa[u][i-1] ][i-1];
    int sz = g[u].size();
    for (int i = 0; i < sz; i ++) {
        int v = g[u][i];
        if (v == p) continue;
        dfs(v, u);
    }
}
int lca(int x, int y) {
    if (dep[x] < dep[y]) swap(x, y);
    for (int i = 20; i >= 0; i --) {
        if (dep[ pa[x][i] ] >= dep[y]) x = pa[x][i];
        if (x == y) return x;
    }
    for (int i = 20; i >= 0; i --) {
        if (pa[x][i] != pa[y][i]) {
            x = pa[x][i];
            y = pa[y][i];
        }
    }
    return pa[x][0];
}
int main() {
    cin >> n >> m >> rt;
    for (int i = 1; i < n; i ++) {
        int u, v;
        cin >> u >> v;
        g[u].push_back(v);
        g[v].push_back(u);
    }
    dfs(rt, 0);
    while (m --) {
        int x, y;
        cin >> x >> y;
        cout << lca(x, y) << endl;
    }
    return 0;
}

然而不知为何在这道题仍然是70分超时。(可能是由于对数影响)

Tarjan求LCA

参考资料:https://riteme.site/blog/2016-2-1/lca.html

注:写的非常好。

实现代码如下:

#include <bits/stdc++.h>
using namespace std;
const int maxn = 500050, maxm = maxn*2;
struct Edge {
    int v, nxt;
    Edge() {};
    Edge(int _v, int _nxt) { v = _v; nxt = _nxt; }
} edge[maxm];
int n, m, rt, head[maxn], ecnt;
int f[maxn];
void init() {
    // 图部分
    memset(head, -1, sizeof(int)*(n+1));
    ecnt = 0;
    // 并查集部分
    for (int i = 1; i <= n; i ++) f[i] = i;
}
void addedge(int u, int v) {
    edge[ecnt] = Edge(v, head[u]); head[u] = ecnt ++;
    edge[ecnt] = Edge(u, head[v]); head[v] = ecnt ++;
}
struct Query {
    int id, v;  // id表示问题编号,v表示另一个点编号
    Query() {};
    Query(int _id, int _v) { id = _id; v = _v; }
};
vector<Query> query[maxn];
void addquery(int x, int y, int id) {
    query[x].push_back(Query(id, y));
    query[y].push_back(Query(id, x));
}
int Find(int x) {
    return f[x] == x ? x : f[x] = Find(f[x]);
}
void Union(int x, int y) {  // 这里是将y集合所在的子树合并为x所在根节点的子树
    int a = Find(x), b = Find(y);
    if (a != b) f[b] = a;
}
int vis[maxn], ans[maxn];
void tarjan(int u) {
    vis[u] = 1;
    for (int i = head[u]; i != -1; i = edge[i].nxt) {
        int v = edge[i].v;
        if (vis[v]) continue;
        tarjan(v);
        f[v] = u;   // Union(u, v);
    }
    int sz = query[u].size();
    for (int i = 0; i < sz; i ++) {
        int id = query[u][i].id, v = query[u][i].v;
        if (vis[v] == 2) ans[id] = Find(v);
    }
    vis[u] = 2;
}
int main() {
    scanf("%d%d%d", &n, &m, &rt);
    init();
    for (int i = 1; i < n; i ++) {
        int u, v;
        scanf("%d%d", &u, &v);
        addedge(u, v);
    }
    for (int i = 1; i <= n; i ++) {
        int x, y;
        scanf("%d%d", &x, &y);
        if (x == y) ans[i] = x;
        else addquery(x, y, i);
    }
    tarjan(rt);
    for (int i = 1; i <= n; i ++) printf("%d\n", ans[i]);
    return 0;
}
posted @ 2020-05-21 00:13  quanjun  阅读(197)  评论(0编辑  收藏  举报