Python in practice: finding the shortest path (Dijkstra with heap optimization)
Using the Codewars 3 kyu kata "Path Finder #3: the Alpinist" as an example.
Problem statement:
You are at start location [0, 0] in mountain area of NxN and you can only move in one of the four cardinal directions (i.e. North, East, South, West). Return minimal number of climb rounds to target location [N-1, N-1]. Number of climb rounds between adjacent locations is defined as difference of location altitudes (ascending or descending).
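To make the cost definition concrete, here is a minimal sketch that adds up the climb rounds along a given route. The helper name `climb_rounds` and the 3x3 sample grid are illustrative assumptions and are not part of the kata.

```python
def climb_rounds(area, cells):
    """Total climb rounds along a route given as a list of (row, col) cells.

    Hypothetical helper for illustration; it assumes consecutive cells
    in `cells` are adjacent in the grid.
    """
    grid = area.splitlines()
    return sum(
        abs(int(grid[r1][c1]) - int(grid[r2][c2]))
        for (r1, c1), (r2, c2) in zip(cells, cells[1:])
    )


area = "\n".join(["010", "010", "010"])
# Along the top row the altitudes are 0 -> 1 -> 0: one round up, one round down.
print(climb_rounds(area, [(0, 0), (0, 1), (0, 2)]))  # 2
```

Moving between cells of equal altitude costs nothing, so the cheapest route is often not the geometrically shortest one.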
Main code (building the graph as a dict, then heap-optimized Dijkstra):

    import heapq


    def get_distance(a, b):
        # Climb rounds between two adjacent cells: absolute altitude difference
        return abs(int(a) - int(b))


    def get_G(area):
        # Build the graph as a dict of dicts: G[node][neighbor] = edge weight.
        # Cell (i, j) of the N x N grid is numbered N * i + j.
        G = {}
        area = area.splitlines()
        N = len(area)
        for i in range(N):
            for j in range(N):
                n = N * i + j
                G[n] = {}
                if i < N-1:
                    G[n][n+N] = get_distance(area[i][j], area[i+1][j])
                if j < N-1:
                    G[n][n+1] = get_distance(area[i][j], area[i][j+1])
                if i > 0:
                    G[n][n-N] = get_distance(area[i][j], area[i-1][j])
                if j > 0:
                    G[n][n-1] = get_distance(area[i][j], area[i][j-1])
        return G


    def dijkstra(G, start):
        dis = dict((key, float('inf')) for key in G)  # distance from start to each node
        dis[start] = 0
        vis = dict((key, False) for key in G)  # whether each node has been visited

        # Heap optimization
        pq = []
        heapq.heappush(pq, [dis[start], start])

        path = dict((key, [start]) for key in G)  # path recorded to each node
        while len(pq) > 0:
            v_dis, v = heapq.heappop(pq)  # closest unvisited node
            if vis[v] is True:
                continue
            vis[v] = True
            p = path[v].copy()  # shortest path to v
            for node in G[v]:  # nodes directly connected to v
                new_dis = dis[v] + G[v][node]  # candidate distance to the neighbor
                if new_dis < dis[node] and (not vis[node]):  # shorter: update the records
                    dis[node] = new_dis
                    heapq.heappush(pq, [dis[node], node])
                    temp = p.copy()
                    temp.append(node)
                    path[node] = temp
        return dis, path

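The `path` dict above copies a whole list every time a distance improves, which can get expensive on large grids. A common alternative, sketched here as an assumption rather than as the author's method, stores only each node's predecessor and rebuilds the route on demand. The names `dijkstra_prev` and `rebuild_path` and the tiny hand-built graph are hypothetical; the graph uses the same dict-of-dicts shape that `get_G` returns.

```python
import heapq


def dijkstra_prev(G, start):
    """Dijkstra over a dict-of-dicts graph, recording one predecessor per node."""
    dis = {v: float('inf') for v in G}
    prev = {v: None for v in G}  # predecessor on the best known route
    dis[start] = 0
    pq = [(0, start)]
    while pq:
        d, v = heapq.heappop(pq)
        if d > dis[v]:
            continue  # stale heap entry: a shorter route to v was already found
        for node, w in G[v].items():
            nd = d + w
            if nd < dis[node]:
                dis[node] = nd
                prev[node] = v
                heapq.heappush(pq, (nd, node))
    return dis, prev


def rebuild_path(prev, target):
    """Walk predecessors back from target, then reverse to get start -> target."""
    path = []
    while target is not None:
        path.append(target)
        target = prev[target]
    return path[::-1]


# Tiny hand-built graph: going 0 -> 1 -> 2 (cost 2) beats the direct edge 0 -> 2 (cost 5).
G = {0: {1: 1, 2: 5}, 1: {0: 1, 2: 1}, 2: {0: 5, 1: 1}}
dis, prev = dijkstra_prev(G, 0)
print(dis[2], rebuild_path(prev, 2))  # 2 [0, 1, 2]
```

This keeps one entry per node instead of one list per node, and the stale-entry check plays the role of the `vis` dict.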
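Test:

    g = "\n".join([
        "000000",
        "000000",
        "000000",
        "000010",
        "000109",
        "001010"
    ])

    def path_finder(area):
        # Use the argument (not the global g) and return the result
        G = get_G(area)
        dis, path = dijkstra(G, 0)
        return dis[len(G) - 1]  # node N*N - 1 is the target cell [N-1, N-1]

    print(path_finder(g))  # 4
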
Another author's solution:

    def path_finder(maze):
        grid = maze.splitlines()
        end = h, w = len(grid) - 1, len(grid[0]) - 1
        bag, seen = {(0, 0): 0}, set()
        while bag:
            x, y = min(bag, key=bag.get)
            rounds = bag.pop((x, y))
            seen.add((x, y))
            if (x, y) == end: return rounds
            for u, v in (-1, 0), (0, 1), (1, 0), (0, -1):
                m, n = x + u, y + v
                if (m, n) in seen or not (0 <= m <= h and 0 <= n <= w): continue
                new_rounds = rounds + abs(int(grid[x][y]) - int(grid[m][n]))
                if new_rounds < bag.get((m, n), float('inf')): bag[m, n] = new_rounds

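The solution above picks the next cell with `min(bag, key=bag.get)`, which scans every pending cell on each iteration. The sketch below shows one way to keep its early exit and grid-only bookkeeping while letting `heapq` do the selection. The function name `min_climb_rounds` is hypothetical, and the code assumes a non-empty square grid of single-digit altitudes.

```python
import heapq


def min_climb_rounds(area):
    """Grid Dijkstra with a binary heap and an early exit at the target cell."""
    grid = [[int(c) for c in row] for row in area.splitlines()]
    n = len(grid)  # assumption: the grid is N x N with N >= 1
    dist = [[float('inf')] * n for _ in range(n)]
    dist[0][0] = 0
    pq = [(0, 0, 0)]  # (climb rounds so far, row, column)
    while pq:
        d, i, j = heapq.heappop(pq)
        if d > dist[i][j]:
            continue  # stale entry left over from an earlier, worse route
        if (i, j) == (n - 1, n - 1):
            return d  # the first time the target is popped, d is minimal
        for di, dj in ((1, 0), (0, 1), (-1, 0), (0, -1)):
            a, b = i + di, j + dj
            if 0 <= a < n and 0 <= b < n:
                nd = d + abs(grid[i][j] - grid[a][b])
                if nd < dist[a][b]:
                    dist[a][b] = nd
                    heapq.heappush(pq, (nd, a, b))
    return dist[n - 1][n - 1]


sample = "\n".join(["010", "101", "010"])
print(min_climb_rounds(sample))  # 4
```

Because no graph dict is built first, neighbors are computed on the fly from the grid, at the cost of a few stale entries lingering in the heap.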