python | 算法-网络延迟时间-Dijkstra算法应用
写在前面:
我自己用python练习算法与数据结构的典型算法汇总在这里:汇总-算法与数据结构-python版,欢迎翻阅!
2️⃣ 所用例子 :
代码详情
# LeetCode 743题, 用这道题来练习Dijkstra算法
# 参考: https://github.com/algorithmzuo/algorithmbasic2020/blob/master/src/class16/Code06_NetworkDelayTime.java
# python中优先级队列的实现参考:https://geek-docs.com/python/python-examples/python-priority-queue.html
from queue import PriorityQueue
class NetworkDelayTime:
    """LeetCode 743 (Network Delay Time) solved with Dijkstra's algorithm.

    Both public methods share the same contract:

    * ``times`` -- directed edges ``[u, v, w]``: a signal sent from ``u``
      reaches ``v`` after delay ``w``.
    * ``n`` -- number of nodes that must receive the signal.
    * ``k`` -- source node.

    They return the largest shortest-path delay from ``k`` (the time at which
    the last node receives the signal), or -1 if fewer than ``n`` nodes are
    reached.
    """

    @staticmethod
    def _build_graph(times):
        """Return the adjacency map ``{u: [(v, w), ...]}`` for ``[u, v, w]`` edges.

        Only nodes with outgoing edges appear as keys; edge order is preserved.
        """
        graph = {}
        for src, dst, weight in times:
            graph.setdefault(src, []).append((dst, weight))
        return graph

    # Method 1: plain priority queue + skip nodes whose delay is already final.
    def net_delay_time1(self, times, n, k):
        """Dijkstra with ``queue.PriorityQueue`` and lazy deletion.

        A node may be queued several times; only the first entry popped for it
        (the smallest delay) is used, later ones are stale and skipped.
        """
        graph = self._build_graph(times)
        # PriorityQueue hands out the smallest tuple first, so (delay, node)
        # entries make it behave as a min-heap keyed on delay. See
        # https://docs.python.org/3/library/queue.html#queue.PriorityQueue
        heap = PriorityQueue()
        heap.put((0, k))
        visited = set()  # nodes whose shortest delay has been finalized
        result = 0       # largest finalized delay so far -- the answer
        while not heap.empty() and len(visited) < n:
            delay, cur = heap.get()
            if cur in visited:
                continue  # stale entry: a smaller delay was already popped
            visited.add(cur)
            result = max(result, delay)
            for nxt, weight in graph.get(cur, ()):
                if nxt not in visited:
                    heap.put((delay + weight, nxt))
        return -1 if len(visited) < n else result

    # Method 2: indexed ("enhanced") heap that supports decrease-key.
    def net_delay_time2(self, times, n, k):
        """Dijkstra with the indexed ``Heap`` class defined in this module.

        Expects ``Heap.add`` to ignore nodes that were already popped and to
        keep only the smaller delay for a node that is still queued, so each
        node is popped at most once, with its final delay.
        """
        graph = self._build_graph(times)
        heap = Heap()
        heap.add(k, 0)
        popped = 0     # number of nodes reached so far
        max_delay = 0  # largest finalized delay so far -- the answer
        while not heap.empty():
            node, delay = heap.pop()
            popped += 1
            max_delay = max(max_delay, delay)
            for nxt, weight in graph.get(node, ()):
                heap.add(nxt, delay + weight)
        return -1 if popped < n else max_delay
class Heap:
    """Indexed binary min-heap of ``[node, delay]`` records, used for Dijkstra.

    Besides the array-backed heap it keeps each queued node's position, so the
    delay of a node that is still queued can be lowered in place
    (decrease-key), and the set of nodes already popped, which are ignored by
    ``add`` from then on.
    """

    def __init__(self):
        self.heap = []      # [node, delay] records, a binary min-heap on delay
        self.index = {}     # node -> current position in self.heap (queued nodes only)
        self.used = set()   # nodes already popped; their delay is final
        self.size = 0       # number of queued records (always len(self.heap))

    def empty(self):
        """Return True when no node is queued."""
        return self.size == 0

    def add(self, node, delay):
        """Queue ``node`` with ``delay``, or lower its queued delay.

        Nodes that were already popped are ignored. For a node that is still
        queued, only a strictly smaller delay replaces the current one.
        """
        if node in self.used:
            return
        if node in self.index:
            # Decrease-key: keep the smaller delay and restore order upward.
            pos = self.index[node]
            if delay < self.heap[pos][1]:
                self.heap[pos][1] = delay
                self.heap_insert(pos)
            return
        self.index[node] = self.size
        self.heap.append([node, delay])
        self.heap_insert(self.size)
        self.size += 1

    def heap_insert(self, index):
        """Sift the record at ``index`` up while it is smaller than its parent."""
        while index > 0:
            parent = (index - 1) // 2
            if self.heap[index][1] >= self.heap[parent][1]:
                break
            self.swap(index, parent)
            index = parent

    def swap(self, index1, index2):
        """Swap two records and keep ``self.index`` in sync with their new positions."""
        self.heap[index1], self.heap[index2] = self.heap[index2], self.heap[index1]
        self.index[self.heap[index1][0]] = index1
        self.index[self.heap[index2][0]] = index2

    def pop(self):
        """Remove and return the ``[node, delay]`` record with the smallest delay.

        The node is marked as used, so later ``add`` calls for it are ignored.
        Raises IndexError if the heap is empty.
        """
        out = self.heap[0]
        self.size -= 1
        # Move the minimum to the end, drop it, then repair from the root.
        self.swap(0, self.size)
        self.heap.pop()
        del self.index[out[0]]
        self.used.add(out[0])
        self.heapify(0)
        return out

    def heapify(self, index):
        """Sift the record at ``index`` down, always toward the smaller child."""
        left = index * 2 + 1
        while left < self.size:
            right = left + 1
            if right < self.size and self.heap[right][1] < self.heap[left][1]:
                smallest = right
            else:
                smallest = left
            if self.heap[smallest][1] >= self.heap[index][1]:
                break
            self.swap(index, smallest)
            index = smallest
            left = index * 2 + 1
# Tests
times1 = [[2, 1, 1], [2, 3, 1], [3, 4, 1]]
n1, k1 = 4, 2
times2 = [[1, 2, 1]]
n2, k2 = 2, 1
times3 = [[1, 2, 1]]
n3, k3 = 2, 2
# A shorter path to node 2 (1 -> 3 -> 2, delay 2) is found after the direct
# edge (delay 5), so the delay already queued for node 2 must be lowered.
times4 = [[1, 2, 5], [1, 3, 1], [3, 2, 1]]
n4, k4 = 3, 1
# Node 3 must not be finalized at delay 4 before node 4 (delay 2) is popped,
# otherwise the shorter path 1 -> 4 -> 3 is missed and node 6 gets 14, not 13.
times5 = [[1, 2, 1], [1, 3, 4], [1, 4, 2], [1, 5, 6], [4, 3, 1], [3, 6, 10]]
n5, k5 = 6, 1
solution = NetworkDelayTime()
# Test method 1
result1 = solution.net_delay_time1(times1, n1, k1)
result2 = solution.net_delay_time1(times2, n2, k2)
result3 = solution.net_delay_time1(times3, n3, k3)
result4 = solution.net_delay_time1(times4, n4, k4)
result5 = solution.net_delay_time1(times5, n5, k5)
print(result1 == 2 and result2 == 1 and result3 == -1 and result4 == 2 and result5 == 13)
# Expected: True
# Test method 2
result1 = solution.net_delay_time2(times1, n1, k1)
result2 = solution.net_delay_time2(times2, n2, k2)
result3 = solution.net_delay_time2(times3, n3, k3)
result4 = solution.net_delay_time2(times4, n4, k4)
result5 = solution.net_delay_time2(times5, n5, k5)
print(result1 == 2 and result2 == 1 and result3 == -1 and result4 == 2 and result5 == 13)
# Expected: True