考虑一个有权图,一个源点,dj算法可以求出图中的所有点到源点的最短路径。
为了实现dj算法,我们需要维护3个数据结构,一个用于记录未确定到源点最短路径的点集upss,一个用于记录源点到该点的最短路径长度dist,一个用于记录最短路径中该点的上一个点prev。
此处以leetcode 743.网络延迟时间为例,源点为2,初始数据结构如下。
| upss | dist | prev |
|---|---|---|
| 1 | inf | -1 |
| 2 | 0 | -1 |
| 3 | inf | -1 |
| 4 | inf | -1 |
初始当前节点为2.我们从upss中找到当前节点的一步可达节点1、3,如果经过当前节点到达1、3的距离小于现有源点到1、3的距离,就更新dist和prev,并将2从upss移除。
| upss | dist | prev |
|---|---|---|
| 1 | 1 | 2 |
| 0 | -1 | |
| 3 | 1 | 2 |
| 4 | inf | -1 |
找到upss中目前距离源点最近的节点1作为当前节点,从upss中找到当前节点的一步可达节点,没有找到,将1从upss移除。
| upss | dist | prev |
|---|---|---|
| 1 | 2 | |
| 0 | -1 | |
| 3 | 1 | 2 |
| 4 | inf | -1 |
找到upss中目前距离源点最近的节点3作为当前节点,从upss中找到当前节点的一步可达节点4,如果经过当前节点到达4的距离小于现有源点到4的距离,就更新dist和prev,并将3从upss移除。
| upss | dist | prev |
|---|---|---|
| 1 | 2 | |
| 0 | -1 | |
| 1 | 2 | |
| 4 | 2 | 3 |
找到upss中目前距离源点最近的节点4作为当前节点,从upss中找到当前节点的一步可达节点,没有找到,将4从upss移除。
| upss | dist | prev |
|---|---|---|
| 1 | 2 | |
| 0 | -1 | |
| 1 | 2 | |
| 2 | 3 |
upss为空,循环终止。dist记录了对应点到源点的最短路径长度,prev记录了最短路径经过的点。
- 2到1的最短路径为:2->1,长度为1。
- 2到3的最短路径为:2->3,长度为1。
- 2到4的最短路径为:2->3->4,长度为2。
最后回到题目,我们只需要返回dist中的最大值的即为本题答案。
def networkDelayTime(self, times: List[List[int]], N: int, K: int) -> int:
npss, dist, prev = [i+1 for i in range(N)], [float('inf')]*(N+1), [-1]*(N+1)
cur, dist[K] = K, 0
while len(npss):
for edge in times:
if edge[0] == cur and edge[1] in npss and dist[cur]+edge[2] < dist[edge[1]]:
dist[edge[1]] = dist[cur]+edge[2]
prev[edge[1]] = cur
if cur not in npss: return -1
npss.remove(cur)
minn = float('inf')
for pnt in npss:
if dist[pnt] < minn:
minn = dist[pnt]
cur = pnt
return max(dist[1:])
需要注意的是,题目并不能保证所有节点都收到信号。