带权并查集

1.dis[x]

  维护自己到父亲节点的权值,在路径压缩(find)时更新即可。

 

2.find( )函数

  递推找到该节点的祖先节点,在回归过程中实现路径压缩。

  路径压缩可以理解成将该节点直接连到祖先节点上,当它儿子节点,同时权值dis也更新成与祖先节点的关系。
  路径上的每一个节点先更新该节点的父亲节点ff( = fa[x])与祖先节点的关系(递归),然后该节点和祖先节点的关系就可以表示为dis[x] + dis[ff];
 
复制代码
int find(int x) {
    if (fa[x] == x) return x;
    else {
        int ff = fa[x];
        fa[x] = find(fa[x]);//先保证前继节点的值更新好才能进行后续更新 
        dis[x] = (dis[x] + dis[ff]) % 3;//更新该节点与(祖先节点)的关系
    }
    return fa[x];
}
复制代码
 
 

3.merge( )函数

  作用:合并两个没有联系的独立集合。
 
  两个集合本来没有联系,每个集合内部的点存在联系,通过联系两个不同集合中独立的两个点使得两个集合建立关系。
  合并两个集合,实质就是建立两个集合中祖先节点之间的关系。合并后集合内部的点之间的关系通过各自和祖先节点之间的关系得到。详见dist函数。
 
  思路:先通过find实现路径压缩,将x和y节点直接连在各自祖先节点上,得出和祖先节点的关系dis[x], dis[y],合并祖先节点fax,fay,并通过向量得到fax和fay关系。
  
  如何确定fax,fay的关系:如果两个顶点x,y不属于同一个祖先节点,即两个集合之前没有建立联系,那么通过图中存在的向量关系, (w + dis[y] - dis[x] + k) % k 建立fax, fay联系;如果一开始同属于一个祖先结点,即一开始就有联系,那么判断新建立的关系和原关系是否冲突,用到dist函数。
 
如图:
 
void merge(int x, int y, int w) {
    int fax = find(x), fay = find(y);
    if(fax == fay) return ;
    fa[fax] = fay;
    dis[fax] = (w + dis[y] - dis[x] + k) % k;
}
 
 

4.dist( )函数

  判断两个节点是否存在关系,以及什么关系的函数。

  先find找各自祖先,如果重合,那么就是有关系,如果不重合,则无关,返回 -1。

  重合的情况下,二者已通过find路径压缩得到与共同祖先的关系,分别是dis[x],dis[y]。那么x和y之间的关系通过向量的关系可以得到 (dis[x] - dis[y] + k) % k

 

 
int dist(int x, int y) {
    int fax = find(x), fay = find(y);
    if (fax != fay) return -1;
    return (dis[x] - dis[y] + 3) % 3;
}

 

例题

  • 食物链

    题解:因为一共有三个物种,而且这三个物种的食物链关系形成了闭环,进而dis可以进行取模操作,得出该节点和根节点的食物链关系。
    复制代码
    #include<iostream>
    using namespace std;
    const int maxn = 5e4 + 20;
    int n, m;
    int tot;
    int fa[maxn], dis[maxn];
    
    inline int read()
    {
        int x=0,f=1;char ch=getchar();
        while (ch<'0'||ch>'9'){if (ch=='-') f=-1;ch=getchar();}
        while (ch>='0'&&ch<='9'){x=x*10+ch-48;ch=getchar();}
        return x*f;
    }
    
    void init() {
        for(int i = 1; i <= n; i ++) {
            fa[i] = i;
        }
    }
    
    int find(int x) {
        if (fa[x] == x) return x;
        else {
            int ff = fa[x];
            fa[x] = find(fa[x]);//先保证前继节点的值更新好才能进行后续更新 
            dis[x] = (dis[x] + dis[ff]) % 3;
        }
        return fa[x];
    }
    
    
    void merge(int x, int y, int w) {
        int fax = find(x), fay = find(y);
        if (fax == fay) return ;
        fa[fax] = fay;
        dis[fax] = (dis[y] + w - dis[x] + 3) % 3;
    }
    
    
    int dist(int x, int y) {
        int fax = find(x), fay = find(y);
        if (fax != fay) return -1;
        return (dis[x] - dis[y] + 3) % 3;
    }
    
    int main() {
        n = read(), m = read();
        init();
        for (int i = 1; i <= m; i ++) {
            int op, x, y;
            op = read(), x = read(), y = read();
            if (x > n || y > n) {
                tot ++;
                continue;
            }
            int fax = find(x), fay = find(y);
            if (op == 1) {
                if (fax != fay) merge(x, y, 0);
                else if ((dist(x, y) != 0)) tot ++;
            } else {
                if (x == y) {
                    tot ++;
                    continue;
                }
                if (fax != fay) merge(x, y, 1);
            else if ((dist(x, y) != 1)) tot ++;
            }
        }
        cout << tot << "\n";
    
        return 0;
    }
    View Code
    复制代码
     

    关押罪犯

    复制代码
    #include<bits/stdc++.h>
    using namespace std;
    const int maxn = 1e5 + 20;
    int n, m;
    int fa[maxn];
    int dis[maxn];
    struct node {
        int a, b;
        int w;
    } t[maxn];
    
    bool cmp(node x, node y) {
        return x.w > y.w;
    }
    
    void init() {
        for (int i = 1; i <= n; i ++) {
            fa[i] = i;
            dis[i] = 0;//enemy
        }
    }
    
    int find(int x) {
        if (fa[x] == x) {
            return x;
        }
        
        
        int ff = fa[x];
        fa[x] = find(fa[x]);
        dis[x] = (dis[x] + dis[ff] + 2) % 2;
        return fa[x];
    }
    
    void merge(int x, int y, int w) {
        int fax = find(x), fay = find(y);
        if (fax == fay) return ;
        fa[fax] = fay;
        dis[fax] = (w + dis[y] - dis[x] + 2) % 2;
        
    }
    
    
    int dist(int x, int y) {
        int fax = find(x);
        int fay = find(y);
        
        if (fax != fay) return -1;
    //     cout << "x y " << x << " " << y << "\n";
        return (dis[x] - dis[y] + 2) % 2;
    }
    
    int main() {
        cin >> n >> m;
        init();
        for (int i = 1; i <= m; i ++) {
            cin >> t[i].a >> t[i].b >> t[i].w;
        }
        sort(t + 1, t + m + 1, cmp);
        
        for (int i = 1; i <= m; i ++) {
            int a = t[i].a, b = t[i].b, w = t[i].w;
            int tmp = dist(a, b);
            if (tmp == -1) {
                merge(a, b, 1);
            } else {
                if (tmp == 0) {
                    cout << w << "\n";
                    return 0;
                }
            }
        }
        cout << "0\n";
        return 0;
    }
    View Code
    复制代码

     

    AcWing 238. 银河英雄传说

复制代码
   
#include <bits/stdc++.h>
using namespace std;
const int maxn = 5e5 + 20;
int fa[maxn], dis[maxn], node[maxn];
int n;

void init() {
    for (int i = 1; i <= n; i ++) {
        fa[i] = i;
        dis[i] = 0;
        node[i] = 1;
    }
}

int find(int x) {
    if (fa[x] == x) return x;
    int ff = fa[x];
    fa[x] = find(fa[x]);
    dis[x] = dis[x] + dis[ff];
    return fa[x];
}

void merge(int x, int y, int w) {
    int fax = find(x);
    int fay = find(y);
    if (fax == fay) return;
    fa[fax] = fay;
    dis[fax] = w + dis[y] - dis[x];
    node[y] += node[x];
}

int dist(int x, int y) {
    int fax = find(x);
    int fay = find(y);
    if(x == y) return 1;
    if (fax != fay) return -1;
    return abs(dis[x] - dis[y]);
}

int main() {
    cin >> n;
    init();
    for (int i = 1; i <= n; i ++) {
        char c;
        int x, y;
        cin >> c >> x >> y;
        if (c == 'M') {
            int fax = find(x), fay = find(y);
            if (fax == fay) continue;
            merge(fax, fay, node[fay]);
        } else {
            int tt = dist(x, y);
            if (tt == -1) cout << "-1\n";
            else cout << tt - 1 << "\n";
        }

    }
    return  0;
}
View Code
复制代码

 

posted @   Y2ZH  阅读(128)  评论(3编辑  收藏  举报
相关博文:
阅读排行:
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· 单线程的Redis速度为什么快?
· SQL Server 2025 AI相关能力初探
· AI编程工具终极对决:字节Trae VS Cursor,谁才是开发者新宠?
· 展开说说关于C#中ORM框架的用法!
点击右上角即可分享
微信分享提示