并查集

并查集被很多OIer认为是最简洁而优雅的数据结构之一,主要用于解决一些元素分组的问题。它管理一系列不相交的集合,并支持两种操作:

  • 合并(Union):把两个不相交的集合合并为一个集合。

  • 查询(Find):查询两个元素是否在同一个集合中。

并查集的引入

并查集的重要思想在于,用集合中的一个元素代表集合。

初始化

int fa[MAXN];
inline void init(int n)
{
    for (int i = 1; i <= n; ++i)
        fa[i] = i;
}

假如有编号为1, 2, 3, ..., n的n个元素,我们用一个数组fa[ ]来存储每个元素的父节点(因为每个元素有且只有一个父节点,所以这是可行的)。一开始,我们先将它们的父节点设为自己。

查询

int find(int x)
{
    if(fa[x] == x)
        return x;
    else
        return find(fa[x]);
}

我们用递归的写法实现对代表元素的查询:一层一层访问父节点,直至根节点(根节点的标志就是父节点是本身)。要判断两个元素是否属于同一个集合,只需要看它们的根节点是否相同即可。

合并

inline void merge(int i, int j)
{
    fa[find(i)] = find(j);
}

合并操作也是很简单的,先找到两个集合的代表元素,然后将前者的父节点设为后者即可。当然也可以将后者的父节点设为前者,这里暂时不重要。本文末尾会给出一个更合理的比较方法。

路径压缩

最简单的并查集效率是比较低的。

合并(路径压缩)

int find(int x)
{
    if(x == fa[x])
        return x;
    else{
        fa[x] = find(fa[x]);  //父节点设为根节点
        return fa[x];         //返回父节点
    }
}

以上代码常常简写为一行:

int find(int x)
{
    return x == fa[x] ? x : (fa[x] = find(fa[x]));
}

注意赋值运算符=的优先级没有三元运算符?:高,这里要加括号。

路径压缩优化后,并查集的时间复杂度已经比较低了,绝大多数不相交集合的合并查询问题都能够解决。然而,对于某些时间卡得很紧的题目,我们还可以进一步优化。

按秩合并

有些人可能有一个误解,以为路径压缩优化后,并查集始终都是一个菊花图(只有两层的树的俗称)。但其实,由于路径压缩只在查询时进行,也只压缩一条路径,所以并查集最终的结构仍然可能是比较复杂的。我们应该把简单的树往复杂的树上合并,而不是相反。因为这样合并后,到根节点距离变长的节点个数比较少。

我们用一个数组rank[]记录每个根节点对应的树的深度(如果不是根节点,其rank相当于以它作为根节点的子树的深度)。一开始,把所有元素的rank(秩)设为1。合并时比较两个根节点,把rank较小者往较大者上合并。

路径压缩和按秩合并如果一起使用,时间复杂度接近 [公式] ,但是很可能会破坏rank的准确性。

初始化(按秩合并)

inline void init(int n)
{
    for (int i = 1; i <= n; ++i)
    {
        fa[i] = i;
        rank[i] = 1;
    }
}

合并(按秩合并)

inline void merge(int i, int j)
{
    int x = find(i), y = find(j);    //先找到两个根节点
    if (rank[x] <= rank[y])
        fa[x] = y;
    else
        fa[y] = x;
    if (rank[x] == rank[y] && x != y)
        rank[y]++;                   //如果深度相同且根节点不同,则新的根节点的深度+1
}

例题 1:[NOI2015]程序自动分析

在实现程序自动分析的过程中,常常需要判定一些约束条件是否能被同时满足。

考虑一个约束满足问题的简化版本:假设 \(x_1,x_2,x_3,…\) 代表程序中出现的变量,给定 \(n\) 个形如 \(x_i=x_j\)\(x_i≠x_j\) 的变量相等/不等的约束条件,请判定是否可以分别为每一个变量赋予恰当的值,使得上述所有约束条件同时被满足。

例如,一个问题中的约束条件为:\(x_1=x_2,x_2=x_3,x_3=x_4,x_1≠x_4\),这些约束条件显然是不可能同时被满足的,因此这个问题应判定为不可被满足。

现在给出一些约束满足问题,请分别对它们进行判定。

输入格式
输入文件的第 \(1\) 行包含 \(1\) 个正整数 \(t\),表示需要判定的问题个数,注意这些问题之间是相互独立的。

对于每个问题,包含若干行:

\(1\) 行包含 \(1\) 个正整数 \(n\),表示该问题中需要被满足的约束条件个数。

接下来 \(n\) 行,每行包括 \(3\) 个整数 \(i,j,e\),描述 \(1\) 个相等/不等的约束条件,相邻整数之间用单个空格隔开。若 \(e=1\),则该约束条件为 \(x_i=x_j\);若 \(e=0\),则该约束条件为 \(x_i≠x_j\)

输出格式
输出文件包括 \(t\) 行。

输出文件的第 \(k\) 行输出一个字符串 YES 或者 NO,YES 表示输入中的第 \(k\) 个问题判定为可以被满足,NO 表示不可被满足。

数据范围
\(1≤n≤105\)
\(1≤i,j≤109\)

输入样例:
2
2
1 2 1
1 2 0
2
1 2 1
2 1 1

输出样例:
NO
YES

代码

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <unordered_map>

using namespace std;

const int N = 2000010;

int n, m;
int p[N];
unordered_map<int, int> S;

struct Query
{
    int x, y, e;
}query[N];

int get(int x)
{
    if (S.count(x) == 0) S[x] = ++ n;
    return S[x];
}

int find(int x)
{
    if (p[x] != x) p[x] = find(p[x]);
    return p[x];
}

int main()
{
    int T;
    scanf("%d", &T);
    while (T -- )
    {
        n = 0;
        S.clear();
        scanf("%d", &m);
        for (int i = 0; i < m; i ++ )
        {
            int x, y, e;
            scanf("%d%d%d", &x, &y, &e);
            query[i] = {get(x), get(y), e};
        }

        for (int i = 1; i <= n; i ++ ) p[i] = i;

        // 合并所有相等约束条件
        for (int i = 0; i < m; i ++ )
            if (query[i].e == 1)
            {
                int pa = find(query[i].x), pb = find(query[i].y);
                p[pa] = pb;
            }

        // 检查所有不等条件
        bool has_conflict = false;
        for (int i = 0; i < m; i ++ )
            if (query[i].e == 0)
            {
                int pa = find(query[i].x), pb = find(query[i].y);
                if (pa == pb)
                {
                    has_conflict = true;
                    break;
                }
            }

        if (has_conflict) puts("NO");
        else puts("YES");
    }

    return 0;
}

例题 2:[POJ1733]奇偶游戏

小 A 和小 B 在玩一个游戏。

首先,小 A 写了一个由 \(0\)\(1\) 组成的序列 \(S\),长度为 \(N\)

然后,小 B 向小 A 提出了 \(M\) 个问题。

在每个问题中,小 B 指定两个数 \(l\)\(r\),小 A 回答 \(S[l∼r]\) 中有奇数个 \(1\) 还是偶数个 \(1\)

机智的小 B 发现小 A 有可能在撒谎。

例如,小 A 曾经回答过 \(S[1∼3]\) 中有奇数个 \(1\)\(S[4∼6]\) 中有偶数个 \(1\),现在又回答 \(S[1∼6]\) 中有偶数个 \(1\),显然这是自相矛盾的。

请你帮助小 B 检查这 \(M\) 个答案,并指出在至少多少个回答之后可以确定小 A 一定在撒谎。

即求出一个最小的 \(k\),使得 01 序列 \(S\) 满足第 \(1∼k\) 个回答,但不满足第 \(1∼k+1\) 个回答。

输入格式
第一行包含一个整数 \(N\),表示 01 序列长度。

第二行包含一个整数 \(M\),表示问题数量。

接下来 \(M\) 行,每行包含一组问答:两个整数 \(l\)\(r\),以及回答 even 或 odd,用以描述 \(S[l∼r]\) 中有偶数个 \(1\) 还是奇数个 \(1\)

输出格式
输出一个整数 \(k\),表示 01 序列满足第 \(1∼k\) 个回答,但不满足第 \(1∼k+1\) 个回答,如果 01 序列满足所有回答,则输出问题总数量。

数据范围
\(N≤109,M≤5000\)

输入样例:
10
5
1 2 even
3 4 odd
5 6 even
1 6 even
7 10 odd

输出样例:
3

代码

带边权写法

#include <cstring>
#include <iostream>
#include <algorithm>
#include <unordered_map>

using namespace std;

const int N = 20010;

int n, m;
int p[N], d[N];
unordered_map<int, int> S;

int get(int x)
{
    if (S.count(x) == 0) S[x] = ++ n;
    return S[x];
}

int find(int x)
{
    if (p[x] != x)
    {
        int root = find(p[x]);
        d[x] += d[p[x]];
        p[x] = root;
    }
    return p[x];
}

int main()
{
    cin >> n >> m;
    n = 0;

    for (int i = 0; i < N; i ++ ) p[i] = i;

    int res = m;
    for (int i = 1; i <= m; i ++ )
    {
        int a, b;
        string type;
        cin >> a >> b >> type;
        a = get(a - 1), b = get(b);

        int t = 0;
        if (type == "odd") t = 1;

        int pa = find(a), pb = find(b);
        if (pa == pb)
        {
            if (((d[a] + d[b]) % 2 + 2) % 2 != t)
            {
                res = i - 1;
                break;
            }
        }
        else
        {
            p[pa] = pb;
            d[pa] = d[a] ^ d[b] ^ t;
        }
    }

    cout << res << endl;

    return 0;
}

扩展域写法

#include <cstring>
#include <iostream>
#include <algorithm>
#include <unordered_map>

using namespace std;

const int N = 40010, Base = N / 2;

int n, m;
int p[N];
unordered_map<int, int> S;

int get(int x)
{
    if (S.count(x) == 0) S[x] = ++ n;
    return S[x];
}

int find(int x)
{
    if (p[x] != x) p[x] = find(p[x]);
    return p[x];
}

int main()
{
    cin >> n >> m;
    n = 0;

    for (int i = 0; i < N; i ++ ) p[i] = i;

    int res = m;
    for (int i = 1; i <= m; i ++ )
    {
        int a, b;
        string type;
        cin >> a >> b >> type;
        a = get(a - 1), b = get(b);

        if (type == "even")
        {
            if (find(a + Base) == find(b))
            {
                res = i - 1;
                break;
            }
            p[find(a)] = find(b);
            p[find(a + Base)] = find(b + Base);
        }
        else
        {
            if (find(a) == find(b))
            {
                res = i - 1;
                break;
            }

            p[find(a + Base)] = find(b);
            p[find(a)] = find(b + Base);
        }
    }

    cout << res << endl;

    return 0;
}

例题 3:[NOI2002]银河英雄传说

有一个划分为 \(N\) 列的星际战场,各列依次编号为 \(1,2,…,N\)

\(N\) 艘战舰,也依次编号为 \(1,2,…,N\),其中第 \(i\) 号战舰处于第 \(i\) 列。

\(T\) 条指令,每条指令格式为以下两种之一:

1.M i j,表示让第 \(i\) 号战舰所在列的全部战舰保持原有顺序,接在第 \(j\) 号战舰所在列的尾部。
2.C i j,表示询问第 \(i\) 号战舰与第 \(j\) 号战舰当前是否处于同一列中,如果在同一列中,它们之间间隔了多少艘战舰。
现在需要你编写一个程序,处理一系列的指令。

输入格式
第一行包含整数 \(T\),表示共有 \(T\) 条指令。

接下来 \(T\) 行,每行一个指令,指令有两种形式:M i jC i j

其中 \(M\)\(C\) 为大写字母表示指令类型,\(i\)\(j\) 为整数,表示指令涉及的战舰编号。

输出格式
你的程序应当依次对输入的每一条指令进行分析和处理:

如果是 M i j 形式,则表示舰队排列发生了变化,你的程序要注意到这一点,但是不要输出任何信息;

如果是 C i j 形式,你的程序要输出一行,仅包含一个整数,表示在同一列上,第 \(i\) 号战舰与第 \(j\) 号战舰之间布置的战舰数目,如果第 \(i\) 号战舰与第 \(j\) 号战舰当前不在同一列上,则输出 −1。

数据范围
\(N≤30000,T≤500000\)

输入样例:
4
M 2 3
C 1 2
M 2 4
C 4 2

输出样例:
-1
1

代码

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>

using namespace std;

const int N = 30010;

int m;
int p[N], size[N], d[N];

int find(int x)
{
    if (p[x] != x)
    {
        int root = find(p[x]);
        d[x] += d[p[x]];
        p[x] = root;
    }
    return p[x];
}

int main()
{
    scanf("%d", &m);

    for (int i = 1; i < N; i ++ )
    {
        p[i] = i;
        size[i] = 1;
    }

    while (m -- )
    {
        char op[2];
        int a, b;
        scanf("%s%d%d", op, &a, &b);
        if (op[0] == 'M')
        {
            int pa = find(a), pb = find(b);
            d[pa] = size[pb];
            size[pb] += size[pa];
            p[pa] = pb;
        }
        else
        {
            int pa = find(a), pb = find(b);
            if (pa != pb) puts("-1");
            else printf("%d\n", max(0, abs(d[a] - d[b]) - 1));
        }
    }

    return 0;
}

posted @ 2022-03-06 13:37  PassName  阅读(33)  评论(0编辑  收藏  举报