[图算法系列] —— dijkstra

691 阅读2分钟

考虑一个有权图,一个源点,dj算法可以求出图中的所有点到源点的最短路径。

为了实现dj算法,我们需要维护3个数据结构,一个用于记录未确定到源点最短路径的点集upss,一个用于记录源点到该点的最短路径长度dist,一个用于记录最短路径中该点的上一个点prev。

此处以leetcode 743.网络延迟时间为例,源点为2,初始数据结构如下。

upssdistprev
1inf-1
20-1
3inf-1
4inf-1

初始当前节点为2.我们从upss中找到当前节点的一步可达节点1、3,如果经过当前节点到达1、3的距离小于现有源点到1、3的距离,就更新dist和prev,并将2从upss移除。

upssdistprev
112
20-1
312
4inf-1

找到upss中目前距离源点最近的节点1作为当前节点,从upss中找到当前节点的一步可达节点,没有找到,将1从upss移除。

upssdistprev
112
20-1
312
4inf-1

找到upss中目前距离源点最近的节点3作为当前节点,从upss中找到当前节点的一步可达节点4,如果经过当前节点到达4的距离小于现有源点到4的距离,就更新dist和prev,并将3从upss移除。

upssdistprev
112
20-1
312
423

找到upss中目前距离源点最近的节点4作为当前节点,从upss中找到当前节点的一步可达节点,没有找到,将4从upss移除。

upssdistprev
112
20-1
312
423

upss为空,循环终止。dist记录了对应点到源点的最短路径长度,prev记录了最短路径经过的点。

  • 2到1的最短路径为:2->1,长度为1。
  • 2到3的最短路径为:2->3,长度为1。
  • 2到4的最短路径为:2->3->4,长度为2。

最后回到题目,我们只需要返回dist中的最大值的即为本题答案。

def networkDelayTime(self, times: List[List[int]], N: int, K: int) -> int:
    npss, dist, prev = [i+1 for i in range(N)], [float('inf')]*(N+1), [-1]*(N+1)
    cur, dist[K] = K, 0
    while len(npss):
        for edge in times:
            if edge[0] == cur and edge[1] in npss and dist[cur]+edge[2] < dist[edge[1]]:
                dist[edge[1]] = dist[cur]+edge[2]
                prev[edge[1]] = cur
        if cur not in npss: return -1
        npss.remove(cur)
        minn = float('inf')
        for pnt in npss:
            if dist[pnt] < minn:
                minn = dist[pnt]
                cur = pnt
    return max(dist[1:])

需要注意的是,题目并不能保证所有节点都收到信号。