Java 双向BFS 算法实现
记录一下双向广度优先搜索（双向 BFS）的代码实现。
所用的数据结构复用自《Java 迪杰斯特拉算法实现》一文中的结构。
/**
 * Bidirectional breadth-first search (BFS) between two nodes of a graph.
 *
 * <p>One search grows from the start node and another from the end node, one whole BFS level
 * at a time, until their frontiers share a node. Both searches follow
 * {@code GraphNode.getNearNodeValueTable()}, so the adjacency is assumed to be symmetric
 * (undirected graph). Only the neighbour ids are used; edge values are ignored, so the
 * resulting path is shortest by number of edges.
 *
 * <p>Search state is kept in instance fields, so one instance must not be used by several
 * threads at the same time.
 */
public class BiBFS {
    /** All nodes of the graph being searched. */
    private final List<GraphNode> graph;
    /** Id of the node the forward search starts from. */
    private final String startNodeId;
    /** Id of the node the reverse search starts from. */
    private final String endNodeId;
    /** Frontier (discovered, not yet expanded) of the forward search. */
    private final Queue<TraverseNode> forwardQueueNodes = new LinkedList<>();
    /** Nodes already expanded by the forward search. */
    private final List<TraverseNode> forwardVisitedNodes = new ArrayList<>();
    /** Frontier (discovered, not yet expanded) of the reverse search. */
    private final Queue<TraverseNode> reverseQueueNodes = new LinkedList<>();
    /** Nodes already expanded by the reverse search. */
    private final List<TraverseNode> reverseVisitedNodes = new ArrayList<>();

    /**
     * Creates a search over {@code graph} from {@code startNodeId} to {@code endNodeId}.
     *
     * @param graph       every node of the graph
     * @param startNodeId id of the start node
     * @param endNodeId   id of the end node
     */
    public BiBFS(List<GraphNode> graph, String startNodeId, String endNodeId) {
        this.graph = graph;
        this.startNodeId = startNodeId;
        this.endNodeId = endNodeId;
    }

    /**
     * Runs the bidirectional search and prints every node id of the result to standard output.
     *
     * @return the node ids of the path found, or an empty list when the end node cannot be
     *         reached from the start node
     * @throws NullPointerException if the start or end id does not belong to the graph, or a
     *                              neighbour id refers to a node missing from the graph
     */
    public List<String> calcShortPath() {
        init();
        return findPath();
    }

    /** Drops the state of any previous run so the instance can be reused. */
    private void init() {
        forwardQueueNodes.clear();
        forwardVisitedNodes.clear();
        reverseQueueNodes.clear();
        reverseVisitedNodes.clear();
    }

    /** Resolves the two endpoints, runs the search and prints the resulting path. */
    private List<String> findPath() {
        Map<String, GraphNode> graphNodeMap =
                graph.stream().collect(Collectors.toMap(GraphNode::getNodeId, node -> node));
        GraphNode startGraphNode = Objects.requireNonNull(
                graphNodeMap.get(startNodeId), "start node not found in graph: " + startNodeId);
        GraphNode endGraphNode = Objects.requireNonNull(
                graphNodeMap.get(endNodeId), "end node not found in graph: " + endNodeId);

        List<String> path = searchPath(graphNodeMap, startGraphNode, endGraphNode);
        path.forEach(System.out::println);
        return path;
    }

    /** Alternates level-wise expansion of both searches until they meet or one is exhausted. */
    private List<String> searchPath(Map<String, GraphNode> graphNodeMap,
                                    GraphNode startGraphNode,
                                    GraphNode endGraphNode) {
        if (Objects.equals(startNodeId, endNodeId)) {
            // Nothing to expand: the path is the single shared endpoint.
            List<String> singleNodePath = new ArrayList<>();
            singleNodePath.add(startNodeId);
            return singleNodePath;
        }

        forwardQueueNodes.add(convertGraphNode2TraverseNode(startGraphNode));
        reverseQueueNodes.add(convertGraphNode2TraverseNode(endGraphNode));

        TraverseNode forwardMeetingNode = null;
        boolean expandForward = true;
        while (forwardMeetingNode == null) {
            if (forwardQueueNodes.isEmpty() || reverseQueueNodes.isEmpty()) {
                // One side explored everything it can reach without meeting the other: no path.
                return new ArrayList<>();
            }
            traverseQueue(graphNodeMap, expandForward);
            // Check after EVERY expansion; checking only after forward steps lets the
            // frontiers pass through each other unnoticed.
            forwardMeetingNode = findForwardMeetingNode();
            expandForward = !expandForward;
        }

        // Path reconstruction reads the visited lists, so the meeting node's record of each
        // direction (carrying that direction's predecessor) has to be present in them.
        String meetingNodeId = forwardMeetingNode.getNodeId();
        forwardVisitedNodes.add(forwardMeetingNode);
        reverseQueueNodes.stream()
                .filter(node -> Objects.equals(node.getNodeId(), meetingNodeId))
                .findFirst()
                .ifPresent(reverseVisitedNodes::add);
        return tidyUpPath(forwardMeetingNode);
    }

    /**
     * Returns the first forward-frontier node whose id also appears in the reverse frontier,
     * or {@code null} when the frontiers are disjoint. Ids are compared explicitly so the
     * result does not depend on how {@code TraverseNode.equals} is defined.
     */
    private TraverseNode findForwardMeetingNode() {
        return forwardQueueNodes.stream()
                .filter(forwardNode -> reverseQueueNodes.stream().anyMatch(
                        reverseNode -> Objects.equals(reverseNode.getNodeId(), forwardNode.getNodeId())))
                .findFirst()
                .orElse(null);
    }

    /** Builds the root record of a search: the given node with weight 0 and no predecessor. */
    private TraverseNode convertGraphNode2TraverseNode(GraphNode graphNode) {
        TraverseNode traverseNode = new TraverseNode();
        traverseNode.setNodeId(graphNode.getNodeId());
        traverseNode.setWeight(0);
        return traverseNode;
    }

    /**
     * Expands one complete BFS level of the chosen direction: every node currently in that
     * direction's frontier is moved to its visited list and its undiscovered neighbours are
     * queued with the expanded node recorded as predecessor.
     *
     * @param graphNodeMap graph nodes keyed by id
     * @param isForward    {@code true} to expand the forward search, {@code false} the reverse
     */
    private void traverseQueue(Map<String, GraphNode> graphNodeMap, boolean isForward) {
        Queue<TraverseNode> curQueue = isForward ? forwardQueueNodes : reverseQueueNodes;
        List<TraverseNode> curVisits = isForward ? forwardVisitedNodes : reverseVisitedNodes;

        // Snapshot the size: nodes queued during this call belong to the next level.
        int levelSize = curQueue.size();
        for (int i = 0; i < levelSize; i++) {
            TraverseNode curTraverseNode = curQueue.poll();
            curVisits.add(curTraverseNode);
            GraphNode curGraphNode = Objects.requireNonNull(
                    graphNodeMap.get(curTraverseNode.getNodeId()),
                    "node not found in graph: " + curTraverseNode.getNodeId());
            curGraphNode.getNearNodeValueTable().keySet().forEach(nodeId -> {
                boolean alreadyQueued = curQueue.stream()
                        .anyMatch(queued -> Objects.equals(queued.getNodeId(), nodeId));
                boolean alreadyVisited = curVisits.stream()
                        .anyMatch(visited -> Objects.equals(visited.getNodeId(), nodeId));
                if (alreadyQueued || alreadyVisited) {
                    return;
                }
                TraverseNode traverseNode = new TraverseNode();
                traverseNode.setNodeId(nodeId);
                traverseNode.setPreNodeId(curTraverseNode.getNodeId());
                curQueue.add(traverseNode);
            });
        }
    }

    /**
     * Joins the two half paths produced by {@code Tools.tidyUpPath} for the meeting node.
     *
     * @param intersectTraverseNode the node on which both searches met
     * @return the forward half followed by the reverse half without its first element
     */
    private List<String> tidyUpPath(TraverseNode intersectTraverseNode) {
        List<String> forwardPath = new ArrayList<>();
        List<String> reversePath = new ArrayList<>();
        Tools.tidyUpPath(forwardPath, forwardVisitedNodes, intersectTraverseNode.getNodeId(), true);
        Tools.tidyUpPath(reversePath, reverseVisitedNodes, intersectTraverseNode.getNodeId(), false);
        // Index 0 is skipped; presumably both halves contain the meeting node and it must
        // not appear twice - TODO confirm against Tools.tidyUpPath.
        for (int i = 1; i < reversePath.size(); i++) {
            forwardPath.add(reversePath.get(i));
        }
        return forwardPath;
    }
}