并查集

并查集

  1. 一开始每个元素都以自己为一个集合
  2. find(i):查找 i 所在集合的代表元素,代表元素代表了 i 所在的集合
  3. isSameSet(a, b):判断 a、b 是否在同一个集合里
  4. union(a, b):将 a、b 所在的两个集合合并

并查集的实现

#include <iostream>
#include <vector>
#include <stack>

using namespace std;

int len;
// father[i] 为 i 所在集合的代表元素
vector<int> father;
// 记录代表元素所在集合的大小
vector<int> s1ze;
// 暂存 find 过程中经过的元素
stack<int> stk;

void build() {
    father.resize(len);
    s1ze.resize(len);
    for (int i = 0; i < len; ++i) {
        // 初始状态以自己为集合
        father[i] = i;
        s1ze[i] = 1;
    }
}

int find(int x) {
    while (x != father[x]) {
        // 往上找代表元素的过程中记录经过的节点
        stk.emplace(x);
        x = father[x];
    }
    // 路径压缩
    while (!empty(stk)) {
        father[stk.top()] = x;
        stk.pop();
    }
    return x;
}

bool isSameSet(int a, int b) {
    return find(a) == find(b);
}

void un1on(int a, int b) {
    // 找到元素所在集合的代表元素
    int fatherOfA = find(a);
    int fatherOfB = find(b);
    // 已经在同一个集合就返回
    if (fatherOfA == fatherOfB) return;
    // 小挂大:集合元素少的把代表元素挂在集合元素多的代表元素上
    if (s1ze[fatherOfA] > s1ze[fatherOfB]) {
        father[fatherOfB] = fatherOfA;
        s1ze[fatherOfA] += s1ze[fatherOfB];
    } else {
        father[fatherOfA] = fatherOfB;
        s1ze[fatherOfB] += s1ze[fatherOfA];
    }
}

int main() {
    int N, M;
    cin >> N >> M;
    len = N + 1;

    build();

    for (int i = 0, opt, a, b; i < M; ++i) {
        cin >> opt >> a >> b;
        if (opt == 1) {
            if (isSameSet(a, b)) {
                cout << "Yes" << endl;
            } else {
                cout << "No" << endl;
            }
        } else {
            un1on(a, b);
        }
    }
}

P3367 【模板】并查集

#include <iostream>
#include <vector>

using namespace std;

int len;
// father[i] 为 i 所在集合的代表元素
vector<int> father;

void build() {
    father.resize(len);
    // 初始状态以自己为集合
    for (int i = 0; i < len; ++i)
        father[i] = i;
}

int find(int x) {
    // 路径压缩
    if (x != father[x])
        father[x] = find(father[x]);
    return father[x];
}

bool isSameSet(int a, int b) {
    return find(a) == find(b);
}

void un1on(int a, int b) {
    // a 所在集合并入 b 所在集合
    father[find(a)] = find(b);
}

int main() {
    int N, M;
    cin >> N >> M;
    len = N + 1;

    build();

    for (int i = 0, opt, a, b; i < M; ++i) {
        cin >> opt >> a >> b;
        if (opt == 2) {
            if (isSameSet(a, b)) {
                cout << "Y" << endl;
            } else {
                cout << "N" << endl;
            }
        } else {
            un1on(a, b);
        }
    }
}

765. 情侣牵手

#include <iostream>
#include <vector>

using namespace std;

class Solution {
public:
    // 下标为情侣组号,组号为情侣编号/2
    vector<int> father;
    // 记录集合总数
    int sets;

    void build(int len) {
        father.resize(len);
        for (int i = 0; i < len; ++i) father[i] = i;
        sets = len;
    }

    int find(int i) {
        if (i != father[i])
            father[i] = find(father[i]);
        return father[i];
    }

    bool isSameSet(int a, int b) {
        return find(a) == find(b);
    }

    void un1on(int a, int b) {
        int fa = find(a);
        int fb = find(b);
        if (fa == fb) return;
        father[fa] = fb;
        // 集合总数减一
        sets--;
    }

    int minSwapsCouples(vector<int> &row) {
        int n = row.size();
        build(n / 2);

        // 所在的情侣组合并
        for (int i = 0; i < n; i += 2)
            un1on(row[i] / 2, row[i + 1] / 2);
        // 每个情侣组集合有 a 组,则这个集合至少需要 a - 1 次交换
        // 一共有 n / 2 个情侣组(每组两个人),总的交换次数就是 n / 2 - 集合总数
        return n / 2 - sets;
    }
};

839. 相似字符串组

#include <iostream>
#include <vector>

using namespace std;

class Solution {
public:
    vector<int> father;
    // 集合总数
    int sets;

    void build(int len) {
        father.resize(len);
        for (int i = 0; i < len; ++i)
            father[i] = i;
        sets = len;
    }

    int find(int i) {
        if (i != father[i])
            father[i] = find(father[i]);
        return father[i];
    }

    bool isSameSet(int a, int b) {
        return find(a) == find(b);
    }

    void un1on(int a, int b) {
        int fa = find(a);
        int fb = find(b);
        if (fa == fb) return;
        father[fa] = fb;
        sets--;
    }

    bool isSimilar(string s1, string s2) {
        int len = s1.size();
        int diff = 0;
        // 判断相同位置元素不同的地方有几个,为 0 或者 2 才是相似的异构词
        for (int i = 0; i < len && diff < 3; ++i)
            if (s1[i] != s2[i]) diff++;
        return diff == 0 || diff == 2;
    }

    int numSimilarGroups(vector<string> &strs) {
        int len = strs.size();
        build(len);

        for (int i = 0; i < len; ++i)
            for (int j = i + 1; j < len; ++j)
                // 相似就合并
                if (isSimilar(strs[i], strs[j]))
                    un1on(i, j);
        return sets;
    }
};

200. 岛屿数量

#include <iostream>
#include <vector>

using namespace std;

class Solution {
public:
    int rows;
    int columns;
    vector<int> father;
    int sets = 0;

    // 坐标映射为一维下标
    int getIndex(int a, int b) {
        return a * columns + b;
    }

    void build(vector<vector<char>> &grid) {
        father.resize(rows * columns + 1);
        for (int i = 0; i < rows; ++i) {
            for (int j = 0; j < columns; ++j) {
                // 只对陆地初始化
                if (grid[i][j] == '1') {
                    int index = getIndex(i, j);
                    father[index] = index;
                    sets++;
                }
            }
        }
    }

    int find(int i) {
        if (i != father[i])
            father[i] = find(father[i]);
        return father[i];
    }

    bool isSameSet(int a, int b) {
        return find(a) == find(b);
    }

    void un1on(int a, int b) {
        int fa = find(a);
        int fb = find(b);
        if (fa == fb) return;
        father[fa] = fb;
        sets--;
    }

    int numIslands(vector<vector<char>> &grid) {
        rows = grid.size();
        columns = grid[0].size();

        build(grid);

        for (int i = 0; i < rows; ++i) {
            for (int j = 0; j < columns; ++j) {
                // 当前位置不是陆地就跳过,不需要和其他集合合并
                if (grid[i][j] != '1') continue;
                int index = getIndex(i, j);
                // 上方是陆地
                if (i > 0 && grid[i - 1][j] == '1')
                    un1on(index, getIndex(i - 1, j));
                // 左侧是陆地
                if (j > 0 && grid[i][j - 1] == '1')
                    un1on(index, getIndex(i, j - 1));
            }
        }
        return sets;
    }
};
// todo 洪水填充

947. 移除最多的同行或同列石头

#include <iostream>
#include <vector>
#include <unordered_map>

using namespace std;

class Solution {
public:

    // 记录第一次遇到的石头的编号,以行列为 key,石头编号为 value
    unordered_map<int, int> rowFirst;
    unordered_map<int, int> colFirst;
    vector<int> father;
    int sets;

    void build(int len) {
        rowFirst.clear();
        colFirst.clear();
        father.resize(len);
        for (int i = 0; i < len; ++i)
            father[i] = i;
        sets = len;
    }

    int find(int i) {
        if (i != father[i])
            father[i] = find(father[i]);
        return father[i];
    }

    void un1on(int a, int b) {
        int fa = find(a);
        int fb = find(b);
        if (fa == fb) return;
        father[fa] = fb;
        sets--;
    }

    int removeStones(vector<vector<int>> &stones) {
        // 石头数量
        int len = stones.size();
        build(len);

        for (int i = 0; i < len; ++i) {
            int row = stones[i][0];
            int col = stones[i][1];
            // 所有在行列上有关联的都合并到一起
            if (rowFirst.find(row) == rowFirst.end()) {
                // 该行第一次出现
                rowFirst[row] = i;
            } else {
                // 和之前在这行上第一次出现的石头合并到一个集合
                un1on(i, rowFirst[row]);
            }
            if (colFirst.find(col) == colFirst.end()) {
                // 该行第一次出现
                colFirst[col] = i;
            } else {
                // 和之前在这行上第一次出现的石头合并到一个集合
                un1on(i, colFirst[col]);
            }
        }
        // 最少剩 sets 个石头,最多移除 len - sets 个石头
        return len - sets;
    }
};

2092. 找出知晓秘密的所有专家

#include <iostream>
#include <vector>
#include <stack>
#include <algorithm>

using namespace std;

class Solution {
public:
    vector<int> father;
    // 标记代表元素所在集合是否知道秘密
    vector<bool> secret;
    stack<int> stk;

    void build(int n, int firstPerson) {
        father.resize(n);
        secret.resize(n, false);
        for (int i = 0; i < n; ++i)
            father[i] = i;
        // 初始知道秘密的两个人
        secret[0] = true;
        secret[firstPerson] = true;
        // 并入一个集合
        father[firstPerson] = 0;
    }

    int find(int i) {
        if (i != father[i])
            father[i] = find(father[i]);
        return father[i];
    }

    void un1on(int a, int b) {
        int fa = find(a);
        int fb = find(b);
        if (fa == fb) return;
        father[fa] = fb;
        // 传递是否知道秘密
        secret[fb] = secret[fb] | secret[fa];
    }

    // 判断 i 是否知道秘密
    bool knowSecret(int i) {
        return secret[find(i)];
    }

    static bool cmp(vector<int> arr1, vector<int> arr2) {
        return arr1[2] < arr2[2];
    }

    vector<int> findAllPeople(int n, vector<vector<int>> &meetings, int firstPerson) {
        // 按照会议时间排序
        sort(begin(meetings), end(meetings), cmp);

        // 初始化
        build(n, firstPerson);

        // 会议总数
        int len = meetings.size();
        int i = 0;
        int time = meetings[0][2];
        while (i < len) {
            // 同一个时间开会的,关联到一个集合
            while (i < len && meetings[i][2] == time) {
                // 这一时间开会的所有人入栈,为后续分理出不知道秘密的人做准备
                stk.emplace(meetings[i][0]);
                stk.emplace(meetings[i][1]);
                // 相关联,并入一个集合
                un1on(meetings[i][0], meetings[i][1]);
                i++;
            }
            while (!stk.empty()) {
                // 不知道秘密人的要从相互关联却不知道秘密的集合中分离出来
                if (!knowSecret(stk.top())) father[stk.top()] = stk.top();
                stk.pop();
            }
            if (i < len) time = meetings[i][2];
        }

        vector<int> res;
        // 加入所有知道秘密的人
        for (int j = 0; j < n; ++j)
            if (knowSecret(j)) res.emplace_back(j);
        return res;
    }
};

2421. 好路径的数目

#include <iostream>
#include <vector>
#include <stack>
#include <algorithm>

using namespace std;


class Solution {
public:
    // 集合中的最大值作为代表节点
    vector<int> father;
    // 记录顶点 i 所在集合的最大值和最大值的个数
    vector<int> cnt;

    void build(int n, vector<int> &vals) {
        father.resize(n);
        cnt.resize(n);
        for (int i = 0; i < n; ++i) {
            father[i] = i;
            cnt[i] = 1;
        }
    }

    int find(int i) {
        if (i != father[i])
            father[i] = find(father[i]);
        return father[i];
    }

    // 返回产生好路径的条数
    int un1on(int a, int b, vector<int> &vals) {
        int fa = find(a);
        int fb = find(b);
        int path = 0;
        // 更新标签信息
        if (vals[fa] == vals[fb]) {
            // 两个集合的最大值一样才会出现好路径
            path = cnt[fa] * cnt[fb];
            cnt[fb] += cnt[fa];
            father[fa] = fb;
        } else if (vals[fa] > vals[fb]) {
            father[fb] = fa;
        } else if (vals[fa] < vals[fb]) {
            father[fa] = fb;
        }
        return path;
    }

    int numberOfGoodPaths(vector<int> &vals, vector<vector<int>> &edges) {
        int res = vals.size();
        // vals[i] 表示顶点 i 的值,edges[i][0]、edges[i][1] 为边的两个顶点
        // 根据边的两个顶点的较大值进行排序
        sort(begin(edges), end(edges),
             [&vals](vector<int> v1, vector<int> v2) {
                 return max(vals[v1[0]], vals[v1[1]]) < max(vals[v2[0]], vals[v2[1]]);
             });

        build(res, vals);

        for (auto &edge: edges)
            res += un1on(edge[0], edge[1], vals);

        return res;
    }
};

928. 尽量减少恶意软件的传播 II

#include <iostream>
#include <vector>
#include <stack>
#include <algorithm>

using namespace std;

class Solution {
public:
    vector<bool> virus;
    // 每个病毒节点删除后能就回来的节点总数
    vector<int> counts;
    // infect[a] 为 -1 表示没有病毒节点,大于等于 0 表示病毒节点就是 infect[a]
    // 为 -2 表示有多个病毒节点,删掉一个病毒节点并不能拯救这个集合
    vector<int> infect;
    vector<int> father;
    // 记录集合的大小,集合只存普通节点,不存放病毒节点
    vector<int> size;

    void build(int n, vector<int> &initial) {
        // 标记病毒节点
        virus.resize(n, false);
        for (auto &i: initial)
            virus[i] = true;

        counts.resize(n, 0);
        infect.resize(n, -1);
        size.resize(n, 1);
        father.resize(n);
        for (int i = 0; i < n; ++i)
            father[i] = i;
    }

    int find(int i) {
        if (i != father[i])
            father[i] = find(father[i]);
        return father[i];
    }

    void un1on(int a, int b) {
        int fa = find(a);
        int fb = find(b);
        if (fa == fb) return;
        father[fa] = fb;
        size[fb] += size[fa];
    }

    int minMalwareSpread(vector<vector<int>> &graph, vector<int> &initial) {
        int len = graph.size();
        build(len, initial);

        // 把有连接的普通节点合并
        for (int i = 0; i < len; ++i)
            for (int j = 0; j < len; ++j)
                if (graph[i][j] == 1 && !virus[i] && !virus[j])
                    un1on(i, j);
        // 把病毒四周的直接感染的节点所在的集合标记上病毒编号
        for (auto &sick: initial) {
            for (int neighbour = 0; neighbour < len; ++neighbour) {
                // 遍历病毒的所有邻居
                if (sick != neighbour && !virus[neighbour] && graph[sick][neighbour] == 1) {
                    // 找到邻居所在集合
                    int fn = find(neighbour);
                    if (infect[fn] == -1) {
                        infect[fn] = sick;
                    } else if (infect[fn] != -2 && infect[fn] != sick) {
                        // 已有其他病毒感染过这个集合,多个病毒感染的集合没得救
                        infect[fn] = -2;
                    }
                }
            }
        }

        // 对所有只被一个病毒感染的集合进行统计,累加病毒节点的感染节点总数
        for (int i = 0; i < len; ++i)
            if (i == find(i) && infect[i] >= 0)
                counts[infect[i]] += size[i];

        // 按照病毒序号排序,病毒感染总数一样时返回编号小的
        sort(begin(initial), end(initial));
        int res = initial[0];
        for (auto &sick: initial) {
            if (counts[sick] > counts[res])
                res = sick;
        }
        return res;
    }
};
posted @ 2024-09-07 21:37  n1ce2cv  阅读(7)  评论(0编辑  收藏  举报