图的基础算法1 - MST

743 阅读4分钟

Prim

PQ Implementation - 稀疏图 O(ElogV)

  • 从某一个节点u开始,寻找当前该节点u所有可以访问的edges
  • 在这些edges中找到最小权重edge (u, v),同时这条edge必须有一端v尚未被visited。将这个尚未访问过的节点v加入集合S,记录添加的edge (u, v)
  • 寻找当前集合S所有可以访问的edges,重复step2,直到没有新的节点可以再加入集合S
  • 此时构成的树即为MST

模版:

class Solution:
    def minimumCost(self, n: int, connections: List[List[int]]) -> int:
        graph = [[] for _ in range(n + 1)]
        for u, v, w in connections:
            graph[u].append((v, w))
            graph[v].append((u, w))
        visited = set()
        pq = []
        heapq.heappush(pq, (0, 1, 1))  # (cost, from, to)
        costs = 0
        while pq:
            w, u, v = heapq.heappop(pq)
            if v not in visited:  # 如果有一端尚未visited
                visited.add(v)
                costs += w
                for v_next, w_next in graph[v]:
                    heapq.heappush(pq, (w_next, v, v_next))
        return costs if len(visited) == n else -1

Naive Implementation - 稠密图 O(V^2)

暴力扫一遍所有往出走的边, 找最小

模版:

image.png

class Solution:
    def minCostConnectPoints(self, points: List[List[int]]) -> int:
        # 任意两个节点之间都可以被连接 -> 稠密图
        n = len(points)
        matrix = [[0] * n for _ in range(n)]  # matrix[i][j] = i和j之间的"曼哈顿距离"
        for i in range(n):
            for j in range(n):
                matrix[i][j] = abs(points[i][0] - points[j][0]) + abs(points[i][1] - points[j][1])
        visited = [False] * n
        dist = [float("inf")] * n  # dist[i] = i到MST的最短距离
        dist[0] = 0  # 起始点
        for i in range(n):
            nextClose = -1
            for j in range(n):
                # 找距离当前MST最近的一个节点
                if not visited[j] and (nextClose == -1 or dist[j] < dist[nextClose]):
                    nextClose = j
            visited[nextClose] = True  # 加入MST
            for y in range(n):
                if not visited[y]:
                    # 更新所有尚不在MST中节点们到MST的最短距离 = min(dist[y], matrix[nextClose][y])
                    dist[y] = min(dist[y], matrix[nextClose][y])
        return sum(dist)  # total weight of MST


Kruskal - 稀疏图 O(ElogE)

  • 按照edge的weight进行排序(从小到大)
  • 依次将每条edge加入MST,除非这条edge的加入会导致cycle(基于union-find来detect cycle)

模版:

class UF:
    def __init__(self, n):
        self.parent = [i for i in range(n)]
    
    def find(self, x):
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]
    
    def union(self, x, y):
        self.parent[self.find(x)] = self.find(y)


class Solution:
    def minimumCost(self, n: int, connections: List[List[int]]) -> int:
        connections.sort(key=lambda x: x[2])
        uf = UF(n + 1)
        costs = 0
        union_times = 0  # 理论上,应该要union (n-1)次,因为总共要添加n-1条edges
        for u, v, w in connections:
            if uf.find(u) != uf.find(v):
                uf.union(u, v)
                costs += w
                union_times += 1
        return costs if union_times == n - 1 else -1


1135. 最低成本联通所有城市(Medium)

image.png

Solu 1:Prim

见pq implementation模版,略

Code 1:

class Solution:
    def minimumCost(self, n: int, connections: List[List[int]]) -> int:
        graph = [[] for _ in range(n + 1)]
        for u, v, w in connections:
            graph[u].append((v, w))
            graph[v].append((u, w))
        visited = set()
        pq = []
        heapq.heappush(pq, (0, 1, 1))  # (cost, from, to)
        costs = 0
        while pq:
            w, u, v = heapq.heappop(pq)
            if v not in visited:  # 如果有一端尚未visited
                visited.add(v)
                costs += w
                for v_next, w_next in graph[v]:
                    heapq.heappush(pq, (w_next, v, v_next))
        return costs if len(visited) == n else -1

Solu 2:Kruskal

见Kruskal - union find模版,略

Code 2:

class UF:
    def __init__(self, n):
        self.parent = [i for i in range(n)]
    
    def find(self, x):
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]
    
    def union(self, x, y):
        self.parent[self.find(x)] = self.find(y)


class Solution:
    def minimumCost(self, n: int, connections: List[List[int]]) -> int:
        connections.sort(key=lambda x: x[2])
        uf = UF(n + 1)
        costs = 0
        union_times = 0  # 理论上,应该要union (n-1)次,因为总共要添加n-1条edges
        for u, v, w in connections:
            if uf.find(u) != uf.find(v):
                uf.union(u, v)
                costs += w
                union_times += 1
        return costs if union_times == n - 1 else -1


1168. 水资源分配优化(Hard)

image.png

Solu 1:“超级源点” + Prim

  • 引入“超级源点” superweight(super, u) = cost(在u上挖井)
    • 在一个节点自身上挖井 = 直接从“超级源点”super取水

Code 1:

class Solution:
    def minCostToSupplyWater(self, n: int, wells: List[int], pipes: List[List[int]]) -> int:
        # build graph
        graph = [[] for _ in range(n + 1)]  # 节点0为dummy node(超级源点),其到每个节点的距离为wells[i]
        # 一个node在自身上挖井 = 直接从"超级源点"取水
        for idx, cost in enumerate(wells):
            graph[0].append((idx + 1, cost))
        for u, v, w in pipes:
            graph[u].append((v, w))
            graph[v].append((u, w))
        # prim with pq
        pq = []
        visited = {0}
        for u, w in graph[0]:
            heapq.heappush(pq, (w, 0, u))  # (cost, from, to)
        costs = 0
        while pq:
            cost, u, v = heapq.heappop(pq)
            if v not in visited:
                visited.add(v)
                costs += cost
                for next, next_cost in graph[v]:
                    if next not in visited:
                        heapq.heappush(pq, (next_cost, v, next))
        return costs

Solu 2:“超级源点” + Kruskal

  • 同上,引入“超级源点”super,将weight(super, u)一道进行排序

Code 2:

class UF:
    def __init__(self, n):
        self.parent = [i for i in range(n)]
    
    def find(self, x):
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]
    
    def union(self, x, y):
        self.parent[self.find(x)] = self.find(y)


class Solution:
    def minCostToSupplyWater(self, n: int, wells: List[int], pipes: List[List[int]]) -> int:
        for idx, cost in enumerate(wells):
            pipes.append([0, idx + 1, cost])
        pipes.sort(key=lambda x: x[2])
        uf = UF(n + 1)
        res = 0
        for u, v, w in pipes:
            if uf.find(u) != uf.find(v):
                res += w
                uf.union(u, v)
        return res


1584. 连接所有点的最小费用(Medium)

image.png

image.png

Solu:

见naive implementation模版,略

Code:

class Solution:
    def minCostConnectPoints(self, points: List[List[int]]) -> int:
        # 任意两个节点之间都可以被连接 -> 稠密图
        n = len(points)
        matrix = [[0] * n for _ in range(n)]  # matrix[i][j] = i和j之间的"曼哈顿距离"
        for i in range(n):
            for j in range(n):
                matrix[i][j] = abs(points[i][0] - points[j][0]) + abs(points[i][1] - points[j][1])
        visited = [False] * n
        dist = [float("inf")] * n  # dist[i] = i到MST的最短距离
        dist[0] = 0  # 起始点
        for i in range(n):
            nextClose = -1
            for j in range(n):
                # 找距离当前MST最近的一个节点
                if not visited[j] and (nextClose == -1 or dist[j] < dist[nextClose]):
                    nextClose = j
            visited[nextClose] = True  # 加入MST
            for y in range(n):
                if not visited[y]:
                    # 更新所有尚不在MST中节点们到MST的最短距离 = min(dist[y], matrix[nextClose][y])
                    dist[y] = min(dist[y], matrix[nextClose][y])
        return sum(dist)  # total weight of MST


References: