【学习笔记】最小生成树

提示:

  • 文中代码是按照洛谷题目 P3366【模板】最小生成树 编写的。
  • 讲的有可能不全部正确,请指出。
  • 伪代码并不标准,但能看。

MST 介绍

MST(最小生成树,全称 Minimum Spanning Tree)是指一张有向连通图中边权之和最小的一棵树。

最小生成树的构造目前其实有三种算法,常用的 Kruskal、Prim 和不常用(我没见过)的 Boruvka 算法(可以看 OI Wiki - 最小生成树 - Boruvka 算法)。Boruvka 算法因为不常见和没学过这里不讲了。

Prim

介绍

Prim 是一种以贪心为主要思想的 MST 算法。

具体过程如下:

  1. 建立一个点集,点集中先包含一个点。
  2. 找到距离点集最近的点,将其加入。
  3. 重复 2.,直到整张图都进入点集。

这样就可以找到一个 MST,这是一个找点的过程。

伪代码

while not 所有点进点集:
	node tmp
	for node X in 所有点集:
		if not X in 点集 and 到 tmp 距离 > 到 X 距离:
			tmp <- x
	tmp 进 点集

代码

#include <bits/stdc++.h>
using namespace std;

const int INF = 0x3f3f3f3f;
const int N = 5005;
int q[N][N], dis[N];
bool vis[N];
int n, m;
long long ans;

void prim(int x){
    dis[x] = 0;
    for (int i = 1; i <= n; i ++){
        int cur = -1;
        for (int j = 1; j <= n; j ++){
            if (!vis[j] && (cur == -1 || dis[j] < dis[cur])){
                cur = j;
            }
        }

        if (dis[cur] >= INF){
            ans = INF;
            return ;
        }
        
        ans += dis[cur];
        vis[cur] = 1;

        for (int k = 1; k <= n; k ++){
            if (!vis[k]) dis[k] = min(dis[k], q[cur][k]);
        }
    }
}

int main(){
    memset(q, 0x3f, sizeof(q));
    memset(dis, 0x3f, sizeof(dis));
    cin >> n >> m;

    for (int i = 1, u, v, w; i <= m; i ++){
        cin >> u >> v >> w;
        q[u][v] = min(q[u][v], w);
        q[v][u] = min(q[v][u], w);
    }

    prim(1);
    if (ans >= INF) puts("orz");
    else cout << ans;
}

Kruskal

介绍

Kruskal 也是一种贪心思想的算法,但与 Prim 的不同之处在于 Prim 是加点,而 Kruskal 是加边。其过程如下:

  1. 只看点集,不看边,认为每一个点孤立。
  2. 找到最小的边,将他连接的点加入集合。
  3. 继续,知道所有点加入集合。

可以证明,Kruskal 进行到最后一定是最小生成树。

伪代码

while not 所有点入集合:
	for edge X in 所有未使用边(边权从小到大):
		X.u, X.v -> 集合
		X 标记使用

代码

为了更方便,Kruskal 算法巧妙地运用了并查集,他通过记录每个节点对应的集合的根节点,记录不同的集合。

#include <bits/stdc++.h>
using namespace std;

const int N = 2e5 + 5;
struct node{
	int x, y, z;
} edge[N];
int fa[N];

bool cmp(node a, node b){
	return a.z < b.z;
}

int find(int x){
	return x == fa[x] ? x : fa[x] = find(fa[x]);
}

int main(){
	int n, m;
	cin >> n >> m;
	
	for (int i = 1; i <= m; i ++){
		cin >> edge[i].x >> edge[i].y >> edge[i].z;
	}
	
	sort(edge, edge + 1 + m, cmp);
	
	for (int i = 1; i <= n; i ++){
		fa[i] = i;
	}
	
	long long sum = 0;
	for (int i = 1; i <= m; i ++){
		int x = find(edge[i].x);
		int y = find(edge[i].y);
		if (x != y){
			fa[y] = x;
			sum += edge[i].z;
		}
	}
	
	int ans = 0;
	for (int i = 1; i <= n; i ++){
		if (i == find(i)) ans ++;
	}
	
	if (ans > 1) puts("orz");
	else cout << sum;
}
posted @ 2024-07-26 15:55  lym12_ovo  阅读(39)  评论(1编辑  收藏  举报