「AcWing学习记录」并查集

并查集
1.将两个集合合并
2.询问两个元素是否在一个集合当中
时间复杂度近乎O(1)

基本原理
每个集合用一棵树来表示。树根的编号就是整个集合的编号,每个节点存储它的父节点,p[x]表示x的父节点

问题1:如何判断树根:if(p[x] == x)
问题2:如何求x的集合编号: while(p[x] != x) x = p[x];

AcWing 836. 合并集合

原题链接

#include<iostream>
#include<algorithm>

using namespace std;

const int N = 100010;

int n, m;
int p[N];

int find(int x) //返回x的祖宗节点 + 路径压缩
{
    if(p[x] != x) p[x] = find(p[x]);
    return p[x];
}

int main()
{
    scanf("%d%d", &n, &m);

    for(int i = 1; i <= n; i++) p[i] = i;

    while(m--)
    {
        char op[2];
        int a, b;
        scanf("%s%d%d", op, &a, &b); //C++的一个技巧,scanf读取字符串会自动忽略空格和换行,因为不知道出题人会不会在行末加一个空格,所以用scanf的时候,不论是读字符还是字符串都用字符串

        if(op[0] == 'M') p[find(a)] = find(b);
        else
        {
            if(find(a) == find(b)) puts("Yes");
            else puts("No");
        }
    }

    return 0;
}

AcWing 837. 连通块中点的数量

原题链接

#include<iostream>
#include<algorithm>

using namespace std;

const int N = 100010;

int n, m;
int p[N], s[N];

int find(int x)
{
    if(p[x] != x) p[x] = find(p[x]);
    return p[x];
}

int main()
{
    scanf("%d%d", &n, &m);

    for(int i = 1; i <= n; i++)
    {
        p[i] = i;
        s[i] = 1;
    }

    while(m--)
    {
        char op[5];
        int a, b;
        scanf("%s", op);

        if(op[0] == 'C')
        {
            scanf("%d%d", &a, &b);
            if(find(a) == find(b)) continue;
            s[find(b)] += s[find(a)];
            p[find(a)] = find(b);
        }
        else if(op[1] == '1')
        {
            scanf("%d%d", &a, &b);
            if(find(a) == find(b)) puts("Yes");
            else puts("No");
        }
        else
        {
            scanf("%d", &a);
            printf("%d\n", s[find(a)]);
        }
    }

    return 0;
}

AcWing 240. 食物链

原题链接

#include <iostream>
#include <cstring>
#include <algorithm>

using namespace std;

const int N = 50010;

int n, m;
int p[N], d[N];

int find(int x)
{
    if(p[x] != x)
    {
        int t = find(p[x]);
        d[x] += d[p[x]];
        p[x] = t;
    }
    return p[x];
}

int main()
{
    cin >> n >> m;

    for(int i = 1; i <= n; i++) p[i] = i;

    int res = 0;
    while(m--)
    {
        int t, x, y;
        scanf("%d%d%d", &t, &x, &y);

        if(x > n || y > n) res++;
        else
        {
            int px = find(x), py = find(y);
            if(t == 1)
            {
                if(px == py && (d[x] - d[y]) % 3) res++;
                else if(px != py)
                {
                    p[px] = py;
                    d[px] = d[y] - d[x];
                }
            }
            else
            {
                if(px == py && (d[x] - d[y] - 1) % 3) res++;
                else if(px != py)
                {
                    p[px] = py;
                    d[px] = d[y] - d[x] + 1;
                }
            }
        }
    }

    cout << res << endl;

    return 0;
}
posted @ 2023-02-11 21:12  恺雯  阅读(16)  评论(0编辑  收藏  举报