Loading

Java 迪杰斯特拉 算法实现

在这里记录下自己写的迪杰斯特拉代码。

思路

本质是贪心算法:

  • 开始时设定两个集合:S,T;S存入已经遍历的点,T存所有未遍历的点;
  • 首先将起点放入S中,更新T中所有节点的权重(与起点直接相连的节点更新为对应链路的权重,其他节点权重设为无穷大);
  • 在T中寻找权重最低的点(假设是M点),将M点放入S中,同时更新T里所有节点的权重(判断是否要进行松弛操作,即节点的权重是否大于M点权重+M到此节点的权重,如果是,那将权重更新为M点权重+M到此节点权重);
  • 重复上一步,直到从T里取出的是终点,结束;

代码

首先是定义了节点结构、用于遍历的节点结构和约束条件;

节点结构

import lombok.Getter;
import lombok.Setter;

import java.util.HashMap;
import java.util.Map;

/**
 * A graph vertex together with the links that leave it.
 *
 * @author ljmine on 2023/7/25
 **/
@Getter
@Setter
public class GraphNode {
    /** Identifier of this node. */
    String nodeId;

    /**
     * Adjacency table that stands in for the graph's links: the key is the id
     * of a neighbouring node, the value is the weight of the link from this
     * node to that neighbour.
     */
    Map<String, Integer> nearNodeValueTable = new HashMap<>();

    /**
     * Builds a new node carrying the same id and a separate map instance
     * holding the same adjacency entries, so that entries added to or removed
     * from the copy's table do not show up in this node's table.
     *
     * @return the newly created node
     */
    public GraphNode copy() {
        GraphNode duplicate = new GraphNode();
        duplicate.setNearNodeValueTable(new HashMap<>(nearNodeValueTable));
        duplicate.setNodeId(getNodeId());
        return duplicate;
    }
}

用于遍历的节点结构

import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;

/**
 * Per-node bookkeeping record used while the search is running.
 * Equality and hash code are based on {@code nodeId} only.
 *
 * @author ljmine on 2023/7/25
 **/
@Getter
@Setter
@EqualsAndHashCode(of = "nodeId")
public class TraverseNode {
    /** Id of the graph node this record belongs to. */
    String nodeId;

    /** Id of the node this one was reached from; used to backtrack the path at the end. */
    String preNodeId;

    /** Current weight of the node; starts at the largest int value. */
    Integer weight = Integer.MAX_VALUE;
}

约束条件

import javafx.util.Pair;
import lombok.Getter;
import lombok.Setter;

import java.util.ArrayList;
import java.util.List;

/**
 * Constraints applied to a path computation.
 *
 * <p>All three lists start out empty, so callers can add to them without a
 * null check.
 *
 * @author ljmine on 2023/7/25
 **/
@Getter
@Setter
public class Constraint {

    /** Ids of nodes the path must not pass through. */
    List<String> excludeNode = new ArrayList<>();

    /**
     * Links the path must not use, each given as a pair of node ids.
     * Initialised to an empty list like the other constraints (it used to
     * default to {@code null}, unlike its siblings).
     */
    List<Pair<String, String>> excludeLink = new ArrayList<>();

    /** Ids of nodes the path must pass through. */
    List<String> includeNode = new ArrayList<>();
}

输出结果结构

import javafx.util.Pair;
import lombok.Getter;
import lombok.Setter;

import java.util.ArrayList;
import java.util.List;

/**
 * Outcome of a path computation.
 */
@Getter
@Setter
public class CalcResult {
    /** Total weight of {@code resultPath}; starts at the largest int value. */
    Integer weight = Integer.MAX_VALUE;

    /** The computed path, as a list of node ids. */
    List<String> resultPath = new ArrayList<>();

    /** Used when several paths are computed: each entry pairs a path with its weight. */
    List<Pair<List<String>, Integer>> otherPaths = new ArrayList<>();
}

功能实现

以下是实现代码,可以实现必排约束(必排节点、必排链路)和顺序必经约束;对于无序必经约束,迪杰斯特拉只能暴力穷举,不太实用,这里就不写了:

import com.example.CalcResult;
import com.example.Constraint;
import com.example.GraphNode;
import com.example.TraverseNode;
import javafx.util.Pair;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Dijkstra shortest-path search over a graph given as a list of
 * {@link GraphNode} adjacency tables.
 *
 * <p>Supported constraints (see {@link Constraint}): excluded nodes, excluded
 * links, and an <em>ordered</em> list of nodes the path must visit. Unordered
 * must-visit nodes are not supported.
 *
 * <p>Link weights are assumed to be non-negative, as Dijkstra's algorithm
 * requires. When no path exists the result has an empty path and a weight of
 * {@link Integer#MAX_VALUE}.
 *
 * @author ljmine on 2023/7/24
 **/
public class Dijkstra {

    /** Private copy of the caller's graph; constraints are applied to this copy. */
    private final List<GraphNode> graph;

    private final String startNodeId;

    private final String endNodeId;

    /** Optional constraints; may be {@code null}. */
    private final Constraint constraint;

    /** Set S: nodes whose final weight is known. */
    private final List<TraverseNode> traversedNodes = new ArrayList<>();

    /** Set T: nodes not settled yet. */
    private final List<TraverseNode> noTraverseNodes = new ArrayList<>();

    /**
     * Creates a search from {@code startNodeId} to {@code endNodeId}.
     *
     * @param graph       the graph; every node is copied, so the caller's nodes are never modified
     * @param startNodeId id of the start node
     * @param endNodeId   id of the end node
     * @param constraint  constraints to honour, or {@code null} for none
     */
    public Dijkstra(List<GraphNode> graph, String startNodeId, String endNodeId, Constraint constraint) {
        List<GraphNode> copyGraph = new ArrayList<>(graph.size());
        for (GraphNode graphNode : graph) {
            copyGraph.add(graphNode.copy());
        }
        this.graph = copyGraph;
        this.startNodeId = startNodeId;
        this.endNodeId = endNodeId;
        this.constraint = constraint;
    }

    /**
     * Computes the shortest path from the start node to the end node.
     *
     * @param needPrint when {@code true}, prints each node id of the path and the total weight to stdout
     * @return the path and its weight; an empty path with weight
     *         {@link Integer#MAX_VALUE} when the end node cannot be reached
     * @throws ArithmeticException if the summed weight of a multi-segment
     *                             (must-visit) path does not fit in an {@code int}
     */
    public CalcResult calcShortPath(boolean needPrint) {
        init();
        CalcResult result = findPath();
        if (needPrint) {
            result.getResultPath().forEach(System.out::println);
            System.out.println("weight:" + result.getWeight());
        }
        return result;
    }

    /**
     * Resets the search state: applies the exclusion constraints to the graph
     * copy and fills T with one record per remaining node (weight 0 for the
     * start node, {@link Integer#MAX_VALUE} for all others).
     */
    private void init() {
        noTraverseNodes.clear();
        traversedNodes.clear();
        if (constraint != null) {
            deleteNodeByExcludeConstraint();
            deleteNearNodeByExcludeConstraint();
        }
        for (GraphNode graphNode : graph) {
            TraverseNode traverseNode = new TraverseNode();
            traverseNode.setNodeId(graphNode.getNodeId());
            traverseNode.setWeight(startNodeId.equals(graphNode.getNodeId()) ? 0 : Integer.MAX_VALUE);
            noTraverseNodes.add(traverseNode);
        }
    }

    /**
     * Runs the search. With a non-empty must-visit list the work is delegated
     * to {@link #segmentDijkstraByIncludeConstraint()}; otherwise a plain
     * Dijkstra run is performed.
     */
    private CalcResult findPath() {
        if (constraint != null) {
            // Must-visit nodes have to be ordered; unordered ones are not handled here.
            CalcResult result = segmentDijkstraByIncludeConstraint();
            if (result != null) {
                return result;
            }
        }
        while (!noTraverseNodes.isEmpty()) {
            TraverseNode settled = findNextNode();
            if (settled == null) {
                break;
            }
            // Stop once the end node is settled, or once the cheapest remaining
            // node is unreachable (then every remaining node is unreachable too).
            if (Objects.equals(settled.getNodeId(), endNodeId)
                    || settled.getWeight() == Integer.MAX_VALUE) {
                break;
            }
        }

        CalcResult calcResult = new CalcResult();
        TraverseNode endNode = traversedNodes.stream()
                .filter(node -> Objects.equals(node.getNodeId(), endNodeId))
                .findFirst()
                .orElse(null);
        if (endNode == null || endNode.getWeight() == Integer.MAX_VALUE) {
            // End node missing or never relaxed: no path exists. Keep the
            // defaults of CalcResult (empty path, Integer.MAX_VALUE weight).
            return calcResult;
        }
        List<String> path = new ArrayList<>();
        tidyUpPath(path, traversedNodes, endNodeId, true);
        calcResult.setResultPath(path);
        calcResult.setWeight(endNode.getWeight());
        return calcResult;
    }

    /**
     * Handles the ordered must-visit constraint by chaining one search per
     * consecutive pair of start, must-visit nodes and end.
     *
     * @return the combined result; an "unreachable" result (empty path,
     *         {@link Integer#MAX_VALUE}) if any segment has no path; or
     *         {@code null} when there are no must-visit nodes
     */
    private CalcResult segmentDijkstraByIncludeConstraint() {
        List<String> includeNodeIds = constraint.getIncludeNode();
        if (includeNodeIds == null || includeNodeIds.isEmpty()) {
            return null;
        }
        // Build the waypoint sequence in a private list so the caller's
        // constraint object is left untouched.
        List<String> waypoints = new ArrayList<>(includeNodeIds.size() + 2);
        waypoints.add(startNodeId);
        waypoints.addAll(includeNodeIds);
        waypoints.add(endNodeId);

        List<String> path = new ArrayList<>();
        int weight = 0;
        Constraint segmentConstraint = new Constraint();
        for (int i = 0; i < waypoints.size() - 1; i++) {
            Dijkstra dijkstra =
                    new Dijkstra(this.graph, waypoints.get(i), waypoints.get(i + 1), segmentConstraint);
            CalcResult segmentResult = dijkstra.calcShortPath(false);
            List<String> segmentPath = segmentResult.getResultPath();
            if (segmentPath.isEmpty() || segmentResult.getWeight() == Integer.MAX_VALUE) {
                // One segment is unreachable, so the whole constrained path is.
                return new CalcResult();
            }
            if (path.isEmpty()) {
                path.addAll(segmentPath);
            } else {
                // The segment starts at the node the previous one ended on; skip the duplicate.
                path.addAll(segmentPath.subList(1, segmentPath.size()));
            }
            weight = Math.addExact(weight, segmentResult.getWeight());
            // Nodes already on the path are excluded from later segments,
            // otherwise the path could loop back. The last node stays usable:
            // it is the start of the next segment.
            segmentConstraint.setExcludeNode(new ArrayList<>(path.subList(0, path.size() - 1)));
        }
        CalcResult calcResult = new CalcResult();
        calcResult.setResultPath(path);
        calcResult.setWeight(weight);
        return calcResult;
    }

    /**
     * Removes every excluded node from the graph copy, together with all
     * adjacency entries that point at an excluded node.
     */
    private void deleteNodeByExcludeConstraint() {
        List<String> excludeNodeIds = constraint.getExcludeNode();
        if (excludeNodeIds == null || excludeNodeIds.isEmpty()) {
            return;
        }
        Set<String> excluded = new HashSet<>(excludeNodeIds);
        graph.removeIf(graphNode -> excluded.contains(graphNode.getNodeId()));
        for (GraphNode graphNode : graph) {
            graphNode.getNearNodeValueTable().keySet().removeAll(excluded);
        }
    }

    /**
     * Removes every excluded link from the graph copy, in both directions.
     */
    private void deleteNearNodeByExcludeConstraint() {
        List<Pair<String, String>> excludeLinkIds = constraint.getExcludeLink();
        if (excludeLinkIds == null || excludeLinkIds.isEmpty()) {
            return;
        }
        for (Pair<String, String> excludeLinkId : excludeLinkIds) {
            removeLink(excludeLinkId.getKey(), excludeLinkId.getValue());
            removeLink(excludeLinkId.getValue(), excludeLinkId.getKey());
        }
    }

    /** Drops the adjacency entry {@code fromNodeId -> toNodeId}, if the source node exists. */
    private void removeLink(String fromNodeId, String toNodeId) {
        GraphNode fromNode = findGraphNode(fromNodeId);
        if (fromNode != null) {
            fromNode.getNearNodeValueTable().remove(toNodeId);
        }
    }

    /** Returns the first graph node with the given id, or {@code null} if there is none. */
    private GraphNode findGraphNode(String nodeId) {
        for (GraphNode node : graph) {
            if (Objects.equals(node.getNodeId(), nodeId)) {
                return node;
            }
        }
        return null;
    }

    /**
     * One step of the algorithm: moves the lowest-weight node from T to S and
     * relaxes the weights of its neighbours that are still in T.
     *
     * @return the node that was settled, or {@code null} if T was empty
     */
    private TraverseNode findNextNode() {
        TraverseNode minNode = noTraverseNodes.stream()
                .min(Comparator.comparing(TraverseNode::getWeight))
                .orElse(null);
        if (minNode == null) {
            return null;
        }
        // Settle the node first so that every exit path below makes progress
        // and the caller's loop always terminates.
        noTraverseNodes.remove(minNode);
        traversedNodes.add(minNode);

        int curWeight = minNode.getWeight();
        if (curWeight == Integer.MAX_VALUE) {
            // Unreachable node: relaxing from it would overflow int and
            // fabricate negative weights for its neighbours.
            return minNode;
        }
        GraphNode curNode = findGraphNode(minNode.getNodeId());
        if (curNode == null) {
            System.out.println("error! node is null");
            return minNode;
        }
        Map<String, Integer> nearNodes = curNode.getNearNodeValueTable();
        for (TraverseNode noTraverseNode : noTraverseNodes) {
            Integer linkWeight = nearNodes.get(noTraverseNode.getNodeId());
            if (linkWeight == null) {
                continue;
            }
            // Add in long so a large sum cannot wrap around.
            long candidate = (long) curWeight + linkWeight;
            if (candidate < noTraverseNode.getWeight()) {
                noTraverseNode.setWeight((int) candidate);
                noTraverseNode.setPreNodeId(curNode.getNodeId());
            }
        }
        return minNode;
    }

    /**
     * Backtracks from {@code endNodeId} along the {@code preNodeId} links and
     * appends the visited node ids to {@code path}.
     *
     * <p>When the walk reaches a node without a predecessor, {@code path} is
     * reversed if {@code needRev} is set. When the walk hits an id that is not
     * present in {@code traversedNodes}, it stops and {@code path} is left
     * as it is, without reversing.
     *
     * @param path           list the node ids are appended to
     * @param traversedNodes settled nodes holding the predecessor links
     * @param endNodeId      id to start backtracking from; {@code null} means there is nothing to walk
     * @param needRev        whether to reverse {@code path} once the walk has completed
     */
    public static void tidyUpPath(List<String> path, List<TraverseNode> traversedNodes, String endNodeId, boolean needRev) {
        String cursorId = endNodeId;
        while (cursorId != null) {
            final String lookupId = cursorId;
            TraverseNode curNode = traversedNodes.stream()
                    .filter(node -> lookupId.equals(node.getNodeId()))
                    .findFirst()
                    .orElse(null);
            if (curNode == null) {
                return;
            }
            path.add(cursorId);
            cursorId = curNode.getPreNodeId();
        }
        if (needRev) {
            Collections.reverse(path);
        }
    }
}

posted @ 2023-09-01 15:49  星流残阳  阅读(254)  评论(0编辑  收藏  举报