dijkstra最短路代码模板更新

 

 本文参考了C++ ACM的dijkstra模板:

 

 

 

先说之前自己写错了的:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from collections import defaultdict
from heapq import heappush, heappop
 
 
def dijkstra(edges, start_node, end_node):
    graph = defaultdict(dict)
    for src, dst, distance in edges:
        graph[src][dst] = distance
 
    q = [(0, start_node, None)]
    found_min_dist_nodes = set()
    distances = {start_node: 0}
    back_paths = {}
 
    while q:
        cost, min_dist_node, src_node = heappop(q)
 
        #  下面这行代码非常关键,是为了去除优先级队列q里冗余的push,见后注释说明重复节点push问题
        #if min_dist_node in found_min_dist_nodes:
        #    continue
 
        found_min_dist_nodes.add(min_dist_node)
        back_paths[min_dist_node] = src_node
 
        if min_dist_node == end_node:
            return cost, back_paths
 
        for neibor_node, distance in graph[min_dist_node].items():
            if neibor_node in found_min_dist_nodes:
                continue
 
            prev_dist = distances.get(neibor_node, float('inf'))
            new_dist = cost + distance
            if new_dist < prev_dist:
                distances[neibor_node] = new_dist
                # 下面这个代码,正常情况下,应该是更新优先级队列里neibor_node的priority value
                # 但因为priority queue无原生更新api支持,所以下面代码是在没有remove neibor_node的情况直接push,会导致重复节点push
                heappush(q, (new_dist, neibor_node, min_dist_node))
 
    return float("inf"), back_paths
 
 
def find_path(back_paths, start_node, end_node):
    ans = [end_node]
    while end_node != start_node:
        end_node = back_paths[end_node]
        ans.append(end_node)
 
    return ans[::-1]
 
 
if __name__ == "__main__":
    edges = [
        ("A", "B", 5),
        ("A", "C", 10),
        ("C", "D", 10),
        ("B", "C", 2)]
 
    dist, backpaths = dijkstra(edges, "A", "D")
    print("dist: ", dist)
    print(find_path(backpaths, "A", "D"))

  

上面案例中的图示例:

A--------(5)--------->B

|                         /

(10)                / (2)

|                /

      C

      |(10)

     D

 

肉眼看,A->D的最短路距离是17,路径是ABCD。

 

如果没有下面的代码:

1
2
if min_dist_node in found_min_dist_nodes:
     continue<br><br>输出A到D的最短距离和路径为:

dist: 17
['A', 'C', 'D']

路径这个答案是错的!!!路径计算错了!但是距离计算是ok的!

为啥呢???因此C节点会重复push,debug下就可以看出来了:

 

 

 

因此,我们加上上述if判定进行去重,因为之前已经pop过了,再pop重复节点已经没有意义:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from collections import defaultdict
from heapq import heappush, heappop
 
 
def dijkstra(edges, start_node, end_node):
    graph = defaultdict(dict)
    for src, dst, distance in edges:
        graph[src][dst] = distance
 
    q = [(0, start_node, None)]
    found_min_dist_nodes = set()
    distances = {start_node: 0}
    back_paths = {}
 
    while q:
        cost, min_dist_node, src_node = heappop(q)
 
        #  下面这行代码非常关键,是为了去除优先级队列q里冗余的push,见后注释说明重复节点push问题
        if min_dist_node in found_min_dist_nodes:
            continue
 
        found_min_dist_nodes.add(min_dist_node)
        back_paths[min_dist_node] = src_node
 
        if min_dist_node == end_node:
            return cost, back_paths
 
        for neibor_node, distance in graph[min_dist_node].items():
            if neibor_node in found_min_dist_nodes:
                continue
 
            prev_dist = distances.get(neibor_node, float('inf'))
            new_dist = cost + distance
            if new_dist < prev_dist:
                distances[neibor_node] = new_dist
                # 下面这个代码,正常情况下,应该是更新优先级队列里neibor_node的priority value
                # 但因为priority queue无原生更新api支持,所以下面代码是在没有remove neibor_node的情况直接push,会导致重复节点push
                heappush(q, (new_dist, neibor_node, min_dist_node))
 
    return float("inf"), back_paths
 
 
def find_path(back_paths, start_node, end_node):
    ans = [end_node]
    while end_node != start_node:
        end_node = back_paths[end_node]
        ans.append(end_node)
 
    return ans[::-1]
 
 
if __name__ == "__main__":
    edges = [
        ("A", "B", 5),
        ("A", "C", 10),
        ("C", "D", 10),
        ("B", "C", 2)]
 
    dist, backpaths = dijkstra(edges, "A", "D")
    print("dist: ", dist)
    print(find_path(backpaths, "A", "D"))

  

  输出:

dist: 17
['A', 'B', 'C', 'D']

这下就对了!!!

 

因此对于dijkstra最短路代码,还是要加上if判定!

 

posted @   bonelee  阅读(59)  评论(3编辑  收藏  举报
相关博文:
阅读排行:
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 记一次.NET内存居高不下排查解决与启示
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
· DeepSeek 开源周回顾「GitHub 热点速览」
历史上的今天:
2016-12-17 318. Maximum Product of Word Lengths ——本质:英文单词中字符是否出现可以用26bit的整数表示
点击右上角即可分享
微信分享提示