Java 实现迪杰斯特拉(Dijkstra)算法
在这里记录下自己写的迪杰斯特拉代码。
思路
本质是贪心算法(要求所有边的权重非负):
- 开始时设定两个集合:S,T;S存入已经遍历的点,T存所有未遍历的点;
- 首先将起点放入S中,并初始化T中所有节点的权重(与起点直接连通的节点权重设为对应边的权重,其他节点权重设为无穷大);
- 在T中寻找权重最低的点(假设是M点),将M点放入S中,同时更新T里所有节点的权重(判断是否要进行松弛操作,即节点的权重是否大于M点权重+M到此节点的权重,如果是,则将权重更新为M点权重+M到此节点的权重,并记录此节点的前驱节点为M,用于最后回溯路径);
- 重复上一步,直到从T里取出的是终点(或T中剩余节点的权重都是无穷大,即终点不可达),结束;
代码
首先定义节点结构、用于遍历的节点结构、约束条件和输出结果结构;
节点结构
import lombok.Getter;
import lombok.Setter;
import java.util.HashMap;
import java.util.Map;
/**
 * A graph vertex together with its outgoing edges.
 *
 * @author ljmine on 2023/7/25
 **/
@Getter
@Setter
public class GraphNode {
    // Vertex id.
    String nodeId;

    // Adjacent vertices and edge weights: key = neighbour node id, value = weight of the edge from
    // this node to that neighbour. This map is how the graph's links are represented.
    Map<String, Integer> nearNodeValueTable = new HashMap<>();

    /**
     * Returns a copy of this node that owns a fresh adjacency map, so removing or changing edges
     * on the copy leaves this node untouched. Keys and values are String/Integer, so copying the
     * map itself is sufficient.
     *
     * @return the copied node
     */
    public GraphNode copy() {
        GraphNode duplicate = new GraphNode();
        duplicate.nodeId = nodeId;
        duplicate.nearNodeValueTable = new HashMap<>(nearNodeValueTable);
        return duplicate;
    }
}
用于遍历的节点结构
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
/**
 * Per-node bookkeeping used while running Dijkstra. Equality and hash code are based on
 * {@code nodeId} only.
 *
 * @author ljmine on 2023/7/25
 **/
@Setter
@Getter
@EqualsAndHashCode(of = "nodeId")
public class TraverseNode {
    // Id of the graph node this entry tracks.
    String nodeId;

    // Best known distance from the start node; defaults to Integer.MAX_VALUE ("not reached").
    Integer weight = Integer.MAX_VALUE;

    // Predecessor on the best known path, used to walk the final path backwards.
    String preNodeId;
}
约束条件
import javafx.util.Pair;
import lombok.Getter;
import lombok.Setter;
import java.util.ArrayList;
import java.util.List;
/**
 * Routing constraints for a shortest-path calculation.
 *
 * <p>All lists start out empty (never {@code null}), so callers can add to them directly.
 *
 * @author ljmine on 2023/7/25
 **/
@Getter
@Setter
public class Constraint {
    // Ids of nodes the path must not pass through.
    List<String> excludeNode = new ArrayList<>();

    // Links, as pairs of node ids, the path must not use. Initialised like the other lists;
    // previously it was null, so getExcludeLink().add(...) threw a NullPointerException.
    List<Pair<String, String>> excludeLink = new ArrayList<>();

    // Ids of nodes the path must pass through, in list order.
    List<String> includeNode = new ArrayList<>();
}
输出结果结构
import javafx.util.Pair;
import lombok.Getter;
import lombok.Setter;
import java.util.ArrayList;
import java.util.List;
/**
 * Output of a shortest-path calculation.
 */
@Setter
@Getter
public class CalcResult {
    // Total weight of resultPath; defaults to Integer.MAX_VALUE, which means "no path".
    Integer weight = Integer.MAX_VALUE;

    // Node ids of the resulting path.
    List<String> resultPath = new ArrayList<>();

    // Alternative paths, each paired with its total weight (used when several paths are computed).
    List<Pair<List<String>, Integer>> otherPaths = new ArrayList<>();
}
功能实现
以下是实现代码,支持必排节点、必排链路约束以及有序必经节点约束;对于无序必经约束,迪杰斯特拉只能暴力穷举所有经过顺序,不太实用,这里就不写了:
import com.example.CalcResult;
import com.example.Constraint;
import com.example.GraphNode;
import com.example.TraverseNode;
import javafx.util.Pair;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
/**
 * Dijkstra shortest-path search over a graph of {@link GraphNode}s.
 *
 * <p>Supports excluded nodes, excluded links (removed in both directions) and an ordered list of
 * must-pass nodes, all taken from {@link Constraint}. Unordered must-pass constraints are not
 * supported.
 *
 * <p>Edge weights are expected to be non-negative, as Dijkstra's algorithm requires. The graph
 * passed to the constructor is copied, so applying constraints never mutates the caller's graph.
 *
 * @author ljmine on 2023/7/24
 **/
public class Dijkstra {
    private final List<GraphNode> graph;
    private final String startNodeId;
    private final String endNodeId;
    private final Constraint constraint;
    // Set S: nodes whose shortest distance from the start is final.
    private final List<TraverseNode> traversedNodes = new ArrayList<>();
    // Set T: nodes whose shortest distance is not final yet.
    private final List<TraverseNode> noTraverseNodes = new ArrayList<>();
    // nodeId -> graph node, rebuilt by init() so neighbour lookups are O(1) instead of a list scan.
    private final Map<String, GraphNode> nodeIndex = new HashMap<>();

    /**
     * @param graph       the graph; every node is copied and the argument is never modified
     * @param startNodeId id of the start node
     * @param endNodeId   id of the destination node
     * @param constraint  optional routing constraints, may be {@code null}; its lists are only read
     */
    public Dijkstra(List<GraphNode> graph, String startNodeId, String endNodeId, Constraint constraint) {
        List<GraphNode> copyGraph = new ArrayList<>(graph.size());
        graph.forEach(graphNode -> copyGraph.add(graphNode.copy()));
        this.graph = copyGraph;
        this.startNodeId = startNodeId;
        this.endNodeId = endNodeId;
        this.constraint = constraint;
    }

    /**
     * Computes the shortest path from the start node to the end node.
     *
     * @param needPrint whether to print the path (one node id per line) and its weight to stdout
     * @return the path from start to end and its total weight; when the end node is unreachable
     *     the path is empty and the weight is {@link Integer#MAX_VALUE}
     */
    public CalcResult calcShortPath(boolean needPrint) {
        init();
        CalcResult result = findPath();
        if (needPrint) {
            result.getResultPath().forEach(System.out::println);
            System.out.println("weight:" + result.getWeight());
        }
        return result;
    }

    /**
     * Applies the exclusion constraints to the graph copy and builds the S and T sets: S starts
     * empty, and T holds every remaining node with weight 0 for the start node and
     * {@link Integer#MAX_VALUE} ("infinity") for all others.
     */
    private void init() {
        noTraverseNodes.clear();
        traversedNodes.clear();
        nodeIndex.clear();
        if (constraint != null) {
            // Exclusion constraints: strip the excluded nodes and links from the graph copy.
            deleteNodeByExcludeConstraint();
            deleteNearNodeByExcludeConstraint();
        }
        for (GraphNode graphNode : graph) {
            // putIfAbsent keeps the first node for a duplicated id, like a findFirst() lookup would.
            nodeIndex.putIfAbsent(graphNode.getNodeId(), graphNode);
            TraverseNode traverseNode = new TraverseNode();
            traverseNode.setNodeId(graphNode.getNodeId());
            traverseNode.setWeight(startNodeId.equals(graphNode.getNodeId()) ? 0 : Integer.MAX_VALUE);
            noTraverseNodes.add(traverseNode);
        }
    }

    /**
     * Runs the search (segment-wise when there is a must-pass constraint) and assembles the result.
     */
    private CalcResult findPath() {
        if (constraint != null) {
            // Must-pass constraint: must-pass nodes are visited in list order. Unordered must-pass
            // nodes are not handled here.
            CalcResult result = segmentDijkstraByIncludeConstraint();
            if (result != null) {
                return result;
            }
        }
        // Settle nodes until the destination is settled or nothing reachable is left in T.
        while (!noTraverseNodes.isEmpty()) {
            TraverseNode settled = findNextNode();
            if (settled == null || Objects.equals(settled.getNodeId(), endNodeId)) {
                break;
            }
        }
        List<String> path = new ArrayList<>();
        tidyUpPath(path, traversedNodes, endNodeId, true);
        CalcResult calcResult = new CalcResult();
        calcResult.setResultPath(path);
        traversedNodes.stream()
                .filter(node -> Objects.equals(node.getNodeId(), endNodeId))
                .findFirst()
                .ifPresent(endNode -> calcResult.setWeight(endNode.getWeight()));
        return calcResult;
    }

    /**
     * Handles the ordered must-pass constraint by chaining plain Dijkstra runs:
     * start -> include[0] -> include[1] -> ... -> end.
     *
     * <p>Nodes already on the path are excluded from later segments, so the result never revisits
     * a node. Each segment is solved greedily. The combined path is therefore not guaranteed to be
     * the shortest path that satisfies the constraint. It can also fail when an earlier segment
     * happens to pass through a later must-pass node.
     *
     * @return {@code null} when there is no must-pass constraint; otherwise the combined result,
     *     which has an empty path and weight {@link Integer#MAX_VALUE} if any segment is unreachable
     */
    private CalcResult segmentDijkstraByIncludeConstraint() {
        List<String> includeNodeIds = constraint.getIncludeNode();
        if (includeNodeIds == null || includeNodeIds.isEmpty()) {
            return null;
        }
        // Build the waypoint list on a copy. Inserting start/end into the caller's list made
        // every further call add them again, and failed on unmodifiable lists.
        List<String> waypoints = new ArrayList<>(includeNodeIds.size() + 2);
        waypoints.add(startNodeId);
        waypoints.addAll(includeNodeIds);
        waypoints.add(endNodeId);

        List<String> path = new ArrayList<>();
        long weight = 0;
        Constraint segmentConstraint = new Constraint();
        for (int i = 0; i < waypoints.size() - 1; i++) {
            Dijkstra dijkstra = new Dijkstra(this.graph, waypoints.get(i), waypoints.get(i + 1), segmentConstraint);
            CalcResult segmentResult = dijkstra.calcShortPath(false);
            List<String> segmentPath = segmentResult.getResultPath();
            if (segmentPath.isEmpty()) {
                // This segment is unreachable, so the whole constrained route is too.
                return new CalcResult();
            }
            // Each segment starts where the previous one ended; skip that shared node.
            path.addAll(path.isEmpty() ? segmentPath : segmentPath.subList(1, segmentPath.size()));
            weight += segmentResult.getWeight();
            // Nodes already on the path are excluded for the next segment, otherwise the route
            // could loop back through them. The last node is the next segment's start and stays.
            segmentConstraint.setExcludeNode(new ArrayList<>(path.subList(0, path.size() - 1)));
        }
        CalcResult calcResult = new CalcResult();
        calcResult.setResultPath(path);
        // Fail loudly instead of silently wrapping if the summed weight exceeds the int range.
        calcResult.setWeight(Math.toIntExact(weight));
        return calcResult;
    }

    /**
     * Removes every excluded node from the graph copy, together with every edge pointing at it
     * (including one-way edges from nodes the excluded node does not list as neighbours).
     */
    private void deleteNodeByExcludeConstraint() {
        List<String> excludeNodeIds = constraint.getExcludeNode();
        if (excludeNodeIds == null || excludeNodeIds.isEmpty()) {
            return;
        }
        graph.removeIf(graphNode -> excludeNodeIds.contains(graphNode.getNodeId()));
        graph.forEach(graphNode -> graphNode.getNearNodeValueTable().keySet().removeAll(excludeNodeIds));
    }

    /**
     * Removes every excluded link from the graph copy, in both directions.
     */
    private void deleteNearNodeByExcludeConstraint() {
        List<Pair<String, String>> excludeLinkIds = constraint.getExcludeLink();
        if (excludeLinkIds == null || excludeLinkIds.isEmpty()) {
            return;
        }
        for (Pair<String, String> excludeLinkId : excludeLinkIds) {
            String firstGraphNodeId = excludeLinkId.getKey();
            String secondGraphNodeId = excludeLinkId.getValue();
            this.graph.stream()
                    .filter(node -> Objects.equals(node.getNodeId(), firstGraphNodeId))
                    .findFirst()
                    .ifPresent(firstNode -> firstNode.getNearNodeValueTable().remove(secondGraphNodeId));
            this.graph.stream()
                    .filter(node -> Objects.equals(node.getNodeId(), secondGraphNodeId))
                    .findFirst()
                    .ifPresent(secondNode -> secondNode.getNearNodeValueTable().remove(firstGraphNodeId));
        }
    }

    /**
     * One Dijkstra step: moves the lowest-weight node from T to S and relaxes its edges to the
     * nodes still in T.
     *
     * @return the node just settled, or {@code null} if T is empty or only contains nodes that are
     *     unreachable from the start (weight still {@link Integer#MAX_VALUE})
     */
    private TraverseNode findNextNode() {
        TraverseNode minNode = noTraverseNodes.stream()
                .min(Comparator.comparing(TraverseNode::getWeight))
                .orElse(null);
        // A minimum of "infinity" means everything left is unreachable. Relaxing through such a
        // node would compute MAX_VALUE + w, which overflows to a negative weight and fabricates a
        // path, so stop instead.
        if (minNode == null || minNode.getWeight() == Integer.MAX_VALUE) {
            return null;
        }
        // Move the node to S first, so T always shrinks and the caller's loop cannot spin forever.
        noTraverseNodes.remove(minNode);
        traversedNodes.add(minNode);
        GraphNode curNode = nodeIndex.get(minNode.getNodeId());
        if (curNode == null) {
            System.out.println("error! node is null");
            return minNode;
        }
        long curWeight = minNode.getWeight();
        Map<String, Integer> nearNodes = curNode.getNearNodeValueTable();
        for (TraverseNode candidate : noTraverseNodes) {
            Integer edgeWeight = nearNodes.get(candidate.getNodeId());
            if (edgeWeight == null) {
                continue;
            }
            // long arithmetic: a large edge weight must not wrap around to a negative distance.
            long newWeight = curWeight + edgeWeight;
            if (newWeight < candidate.getWeight()) {
                // Relaxation: found a shorter way to reach the candidate through curNode.
                candidate.setWeight((int) newWeight);
                candidate.setPreNodeId(curNode.getNodeId());
            }
        }
        return minNode;
    }

    /**
     * Rebuilds a path by following {@code preNodeId} links backwards from {@code endNodeId}.
     *
     * <p>Node ids are appended to {@code path} from the end node back to a node without a
     * predecessor. If that node is reached and {@code needRev} is true, the whole list is reversed
     * so it reads start-to-end. If the chain leads to a node missing from {@code traversedNodes},
     * the ids collected so far are kept in backward order.
     *
     * @param path           output list that node ids are appended to
     * @param traversedNodes settled nodes carrying predecessor links
     * @param endNodeId      node to start walking back from; {@code null} means "nothing to add"
     * @param needRev        whether to reverse {@code path} once the walk is complete
     * @throws IllegalStateException if the predecessor links form a cycle
     */
    public static void tidyUpPath(List<String> path, List<TraverseNode> traversedNodes, String endNodeId, boolean needRev) {
        Map<String, TraverseNode> byId = new HashMap<>();
        for (TraverseNode node : traversedNodes) {
            byId.putIfAbsent(node.getNodeId(), node);
        }
        int steps = 0;
        String cursor = endNodeId;
        while (cursor != null) {
            TraverseNode node = byId.get(cursor);
            if (node == null) {
                return;
            }
            // A simple predecessor chain visits each node at most once.
            if (++steps > byId.size()) {
                throw new IllegalStateException("cycle in predecessor chain at node " + cursor);
            }
            path.add(cursor);
            cursor = node.getPreNodeId();
        }
        if (needRev) {
            Collections.reverse(path);
        }
    }
}