CF1242B. 0-1 MST 题解 并查集

题目链接:https://codeforces.com/problemset/problem/1242/B

题目大意:

给定一个包含 \(n(1 \le n \le 10^5)\) 个节点的完全图,图中有 \(m\) 条长度为 \(1\) 的边(\(0 \le m \le \min(\frac{n(n-1)}{2}, 10^5)\)),其余的边长度均为 \(0\)。求最小生成树长度。

解题思路(参考自 官方题解 ):

首先,如果只考虑图中长度为 \(0\) 的边的话:假设所有节点和长度为 \(0\) 的边构成的子图中有 \(k\) 个连通块,则最小生成树的长度为 \(k-1\)(因为这 \(k\) 个连通块之间需要 \(k-1\) 条长度为 \(1\) 的边将其连通)。因此,我们的主要目的就是去寻找存在多少个这样的连通块。

下面描述的是一个 \(O(m + n \log n)\) 的解法。

我们使用并查集维护零权值连通块,同时存储每个连通块的大小。然后从 \(1\)\(n\) 依次遍历每个节点 \(v\),并将节点 \(v\) 放到一个大小为 \(1\) 的连通块内。然后我们遍历与节点 \(v\) 邻接的所有权值为 \(1\) 的边 \(\{ u, v \}\)(要求 \(u \lt v\))。对于每一个零权值连通块,我们计算从这个连通块(即 \(u\) 所处的连通块)到节点 \(v\) 存在多少条边。如果边的数量小于 \(u\) 所处的连通块的大小,我们需要合并节点 \(u\)\(v\)(因为在连通块和 \(v\) 之间至少有一条权值为 \(0\) 的边)。否则,我们不能将这个连通块和节点 \(v\) 合并。最后,我们得到了零权值连通块的个数。

这个算法的时间复杂度是多少呢?整体来说,有 \(n\) 个连通块被创造出来了(因为每处理一个节点就会新建一个大小为 \(1\) 的连通块)。当我们将某个老的连通块同节点 \(v\) 进行合并的时候,连通块的数量减少了一个。所以合并操作最多进行 \(O(n)\) 次,而我们是使用并查集来进行合并的,所以这部分的时间复杂度为 \(O(n \log n)\)

如果一个旧的连通块与节点 \(v\) 不进行合并,则我们能证明这个旧的连通块节节点 \(v\) 之间至少存在一条边,所以这种情况等于边的数量 —— \(m\)。所以这部分的时间复杂度为 \(O(m)\)

所以总的时间复杂度为 \(O(n \log n + m)\)

示例程序:

#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e5 + 5;
int n, m, f[maxn], sz[maxn], ans;
vector<int> g[maxn], vec;
map<int, int> cnt;

void init() {
    for (int i = 1; i <= n; i++)
        f[i] = i, sz[i] = 1;
}

int find(int x) {
    return x == f[x] ? x : f[x] = find(f[x]);
}

void merge(int x, int y) {
    int a = find(x), b = find(y);
    if (a > b) swap(a, b);
    if (a != b) f[b] = a, sz[a] += sz[b];
}

int main() {
    cin >> n >> m;
    init();
    while (m--) {
        int u, v;
        cin >> u >> v;
        g[u].push_back(v);
        g[v].push_back(u);
    }
    for (int u = 1; u <= n; u++) {
        cnt.clear();    // cnt[i]记录u与第i个集合有多少相连的边
        for (auto v : g[u])
            if (v < u)
                cnt[find(v)]++;
        for (auto v : vec) {
            int p = find(v);
            if (sz[p] > cnt[p])
                merge(p, u);
        }
        if (find(u) == u) vec.push_back(u);
    }
    for (int i = 1; i <= n; i++)
        if (find(i) == i)
            ans++;
    cout << ans - 1 << endl;
    return 0;
}
posted @ 2022-04-27 10:06  quanjun  阅读(124)  评论(0编辑  收藏  举报