【学习笔记】最小生成树
提示:
- 文中代码是按照洛谷题目 P3366【模板】最小生成树 编写的。
- 讲的有可能不全部正确,请指出。
- 伪代码并不标准,但能看。
MST 介绍
MST(最小生成树,全称 Minimum Spanning Tree)是指一张有向连通图中边权之和最小的一棵树。
最小生成树的构造目前其实有三种算法,常用的 Kruskal、Prim 和不常用(我没见过)的 Boruvka 算法(可以看 OI Wiki - 最小生成树 - Boruvka 算法)。Boruvka 算法因为不常见和没学过这里不讲了。
Prim
介绍
Prim 是一种以贪心为主要思想的 MST 算法。
具体过程如下:
- 建立一个点集,点集中先包含一个点。
- 找到距离点集最近的点,将其加入。
- 重复 2.,直到整张图都进入点集。
这样就可以找到一个 MST,这是一个找点的过程。
伪代码
while not 所有点进点集:
node tmp
for node X in 所有点集:
if not X in 点集 and 到 tmp 距离 > 到 X 距离:
tmp <- x
tmp 进 点集
代码
#include <bits/stdc++.h>
using namespace std;
const int INF = 0x3f3f3f3f;
const int N = 5005;
int q[N][N], dis[N];
bool vis[N];
int n, m;
long long ans;
void prim(int x){
dis[x] = 0;
for (int i = 1; i <= n; i ++){
int cur = -1;
for (int j = 1; j <= n; j ++){
if (!vis[j] && (cur == -1 || dis[j] < dis[cur])){
cur = j;
}
}
if (dis[cur] >= INF){
ans = INF;
return ;
}
ans += dis[cur];
vis[cur] = 1;
for (int k = 1; k <= n; k ++){
if (!vis[k]) dis[k] = min(dis[k], q[cur][k]);
}
}
}
int main(){
memset(q, 0x3f, sizeof(q));
memset(dis, 0x3f, sizeof(dis));
cin >> n >> m;
for (int i = 1, u, v, w; i <= m; i ++){
cin >> u >> v >> w;
q[u][v] = min(q[u][v], w);
q[v][u] = min(q[v][u], w);
}
prim(1);
if (ans >= INF) puts("orz");
else cout << ans;
}
Kruskal
介绍
Kruskal 也是一种贪心思想的算法,但与 Prim 的不同之处在于 Prim 是加点,而 Kruskal 是加边。其过程如下:
- 只看点集,不看边,认为每一个点孤立。
- 找到最小的边,将他连接的点加入集合。
- 继续,知道所有点加入集合。
可以证明,Kruskal 进行到最后一定是最小生成树。
伪代码
while not 所有点入集合:
for edge X in 所有未使用边(边权从小到大):
X.u, X.v -> 集合
X 标记使用
代码
为了更方便,Kruskal 算法巧妙地运用了并查集,他通过记录每个节点对应的集合的根节点,记录不同的集合。
#include <bits/stdc++.h>
using namespace std;
const int N = 2e5 + 5;
struct node{
int x, y, z;
} edge[N];
int fa[N];
bool cmp(node a, node b){
return a.z < b.z;
}
int find(int x){
return x == fa[x] ? x : fa[x] = find(fa[x]);
}
int main(){
int n, m;
cin >> n >> m;
for (int i = 1; i <= m; i ++){
cin >> edge[i].x >> edge[i].y >> edge[i].z;
}
sort(edge, edge + 1 + m, cmp);
for (int i = 1; i <= n; i ++){
fa[i] = i;
}
long long sum = 0;
for (int i = 1; i <= m; i ++){
int x = find(edge[i].x);
int y = find(edge[i].y);
if (x != y){
fa[y] = x;
sum += edge[i].z;
}
}
int ans = 0;
for (int i = 1; i <= n; i ++){
if (i == find(i)) ans ++;
}
if (ans > 1) puts("orz");
else cout << sum;
}
答应我,不要乱转载lym12_ovo的https://www.cnblogs.com/lym12/p/18325453/mst_study,好不好?