import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.Stack;
/**
* 最小生成树算法:K算法
*
* 假设连通网G=(V,E),令最小生成树的初始状态为只有n个顶点而无边的非连通图T=(V,{}),
* 图中每个顶点自成一个连通分量。在E中选择代价最小的边,若该边依附的顶点分别在T中不同的连通分量上,则将此边加入到T中;
* 否则,舍去此边而选择下一条代价最小的边。
* 依此类推,直至T中所有顶点构成一个连通分量为止
*/
public class Kruskal {
public static Set<Edge> kruskalMST(Graph graph) {
UnionFind unionFind = new UnionFind();
unionFind.makeSets(graph.nodes.values());
PriorityQueue<Edge> priorityQueue = new PriorityQueue<>(new EdgeComparator());
priorityQueue.addAll(graph.edges);
Set<Edge> result = new HashSet<>();
while (!priorityQueue.isEmpty()) { // M 条边
Edge edge = priorityQueue.poll(); // O(logM)
if (!unionFind.isSameSet(edge.from, edge.to)) { // O(1)
result.add(edge);
unionFind.union(edge.from, edge.to);
}
}
return result;
}
// Union-Find Set
public static class UnionFind {
// key 某一个节点, value key节点往上的节点
private final HashMap<Node, Node> fatherMap;
// key 某一个集合的代表节点, value key所在集合的节点个数
private final HashMap<Node, Integer> sizeMap;
public UnionFind() {
fatherMap = new HashMap<>();
sizeMap = new HashMap<>();
}
public void makeSets(Collection<Node> nodes) {
fatherMap.clear();
sizeMap.clear();
for (Node node : nodes) {
fatherMap.put(node, node);
sizeMap.put(node, 1);
}
}
private Node findFather(Node n) {
Stack<Node> path = new Stack<>();
while (n != fatherMap.get(n)) {
path.add(n);
n = fatherMap.get(n);
}
while (!path.isEmpty()) {
fatherMap.put(path.pop(), n);
}
return n;
}
public boolean isSameSet(Node a, Node b) {
return findFather(a) == findFather(b);
}
public void union(Node a, Node b) {
if (a == null || b == null) {
return;
}
Node aParent = findFather(a);
Node bParent = findFather(b);
if (aParent != bParent) {
int aSetSize = sizeMap.get(aParent);
int bSetSize = sizeMap.get(bParent);
if (aSetSize <= bSetSize) {
fatherMap.put(aParent, bParent);
sizeMap.put(bParent, aSetSize + bSetSize);
sizeMap.remove(aParent);
} else {
fatherMap.put(bParent, aParent);
sizeMap.put(aParent, aSetSize + bSetSize);
sizeMap.remove(bParent);
}
}
}
}
public static class EdgeComparator implements Comparator<Edge> {
@Override
public int compare(Edge o1, Edge o2) {
return o1.weight - o2.weight;
}
}
class Graph {
public HashMap<Integer, Node> nodes;
public HashSet<Edge> edges;
public Graph() {
nodes = new HashMap<>();
edges = new HashSet<>();
}
}
class Node {
public int value;
public int in;
public int out;
public ArrayList<Node> nexts;
public ArrayList<Edge> edges;
public Node(int value) {
this.value = value;
nexts = new ArrayList<>();
edges = new ArrayList<>();
}
}
class Edge {
// 权重
public int weight;
public Node from;
public Node to;
public Edge(int weight, Node from, Node to) {
this.weight = weight;
this.from = from;
this.to = to;
}
}
}
/* 如有意见或建议,欢迎评论区留言;如发现代码有误,欢迎批评指正 */