并查集

96 阅读5分钟
  • 基础问题选讲

    • 990. 等式方程的可满足性

      class Solution:
          def equationsPossible(self, equations: List[str]) -> bool:
              """Decide whether all ``x==y`` / ``x!=y`` equations can hold together.

              Each variable letter becomes a union-find node.  Every equality is
              merged first; afterwards an inequality whose two sides share a
              root is a contradiction.
              """
              # Give each distinct letter a dense integer id, in first-seen order.
              ids = {}
              for eq in equations:
                  for ch in (eq[0], eq[3]):
                      ids.setdefault(ch, len(ids))

              parent = list(range(len(ids)))

              def find(node: int) -> int:
                  # Recursive lookup with path compression (at most 26 nodes).
                  if parent[node] == node:
                      return node
                  parent[node] = find(parent[node])
                  return parent[node]

              equalities = [eq for eq in equations if eq[1] == "="]
              inequalities = [eq for eq in equations if eq[1] == "!"]

              for eq in equalities:
                  parent[find(ids[eq[0]])] = find(ids[eq[3]])

              # Satisfiable iff no inequality joins two already-equal variables.
              return all(find(ids[eq[0]]) != find(ids[eq[3]]) for eq in inequalities)
      
    • 547. 省份数量

      class Solution:
          def findCircleNum(self, isConnected: List[List[int]]) -> int:
              """Count provinces (connected components) in an adjacency matrix.

              ``isConnected[a][b] == 1`` means cities a and b are directly linked.
              """
              size = len(isConnected)
              leader = list(range(size))

              def find(city: int) -> int:
                  # Recursive lookup with path compression.
                  if leader[city] != city:
                      leader[city] = find(leader[city])
                  return leader[city]

              for a in range(size):
                  row = isConnected[a]
                  for b in range(size):
                      if row[b] == 1:
                          leader[find(a)] = find(b)

              # A city that is still its own leader heads exactly one province.
              return sum(1 for city in range(size) if leader[city] == city)
      
    • 684. 冗余连接

      class Solution:
          def findRedundantConnection(self, edges: List[List[int]]) -> List[int]:
              """Return the edge whose addition closes a cycle.

              Edges are added in input order and the first edge whose endpoints
              are already connected is returned.  The problem guarantees the
              graph is a tree plus exactly one extra edge.  Under that guarantee,
              the first cycle-closing edge is also the last cycle edge in input
              order.

              Nodes are labelled 1..n where n == len(edges).
              """
              n = len(edges)
              roots = list(range(n + 1))

              def find(x: int) -> int:
                  # Iterative find with path halving.  A recursive find combined
                  # with naive union exceeds Python's default recursion limit on
                  # long chains such as [1,2], [2,3], ..., [999,1000].
                  while roots[x] != x:
                      roots[x] = roots[roots[x]]
                      x = roots[x]
                  return x

              for u, v in edges:
                  ru, rv = find(u), find(v)
                  if ru == rv:
                      return [u, v]
                  roots[ru] = rv
              # Unreachable for valid input (exactly one redundant edge exists).
              return []
      
    • 1319. 连通网络的操作次数

      class Solution:
          def makeConnected(self, n: int, connections: List[List[int]]) -> int:
              """Minimum number of cable moves needed to connect all n computers.

              Returns -1 when fewer than n - 1 cables exist, because a connected
              graph on n nodes needs at least n - 1 edges.  Otherwise redundant
              cables can be relocated, and each move joins two components.  The
              answer is therefore (number of components) - 1.
              """
              if len(connections) < n - 1:
                  return -1

              roots = list(range(n))

              def find(x: int) -> int:
                  # Iterative path halving.  A recursive find overflows the
                  # Python stack when unions build a long chain (n can be ~1e5).
                  while roots[x] != x:
                      roots[x] = roots[roots[x]]
                      x = roots[x]
                  return x

              # Count components while merging instead of rescanning afterwards.
              components = n
              for u, v in connections:
                  ru, rv = find(u), find(v)
                  if ru != rv:
                      roots[ru] = rv
                      components -= 1
              return components - 1
      
    • 765. 情侣牵手

      class Solution:
          def minSwapsCouples(self, row: List[int]) -> int:
              """Minimum swaps so that every couple (2k, 2k+1) sits side by side."""
              couples = len(row) // 2
              boss = list(range(couples))

              def find(c: int) -> int:
                  # Recursive lookup with path compression.
                  if boss[c] != c:
                      boss[c] = find(boss[c])
                  return boss[c]

              # Person p belongs to couple p // 2.  Each seat pair links the
              # couples of the two people sitting in it.
              for left, right in zip(row[0::2], row[1::2]):
                  boss[find(left // 2)] = find(right // 2)

              # A component of k tangled couples needs k - 1 swaps.  Drawing a
              # small example helps to see this.  The total is couples minus
              # the number of components.
              components = sum(1 for c in range(couples) if boss[c] == c)
              return couples - components
      
  • 进阶问题选讲

    • 399. 除法求值

      class Solution:
          def calcEquation(self, equations: List[List[str]], values: List[float], queries: List[List[str]]) -> List[float]:
              """Answer division queries from known ratios with a weighted union-find.

              ``roots[i] == (parent, w)`` records ``var_i / var_parent == w``.
              ``find(i)`` returns ``(root, var_i / var_root)``.  A query on an
              unknown variable, or on two variables in different components,
              yields -1.0.
              """
              # Give each distinct variable name a dense integer id.
              str2idx = {}
              for u, v in equations:
                  for name in (u, v):
                      if name not in str2idx:
                          str2idx[name] = len(str2idx)

              # Every node starts as its own root with ratio 1.0 to itself.
              roots = [(i, 1.0) for i in range(len(str2idx))]

              def find(x: int) -> Tuple[int, float]:
                  px, wx = roots[x]
                  if px != x:
                      r, w = find(px)
                      # Path compression: x / r == (x / px) * (px / r).
                      roots[x] = (r, wx * w)
                  return roots[x]

              def union(u: int, v: int, w: float):
                  """Record the relation u / v == w."""
                  ru, wu = find(u)
                  rv, wv = find(v)
                  if ru == rv:
                      # Already related.  Re-linking the root to itself would
                      # replace its 1.0 self-weight with a recomputed value that
                      # may differ by rounding.  find() multiplies that value
                      # into the children's weights on later lookups.
                      return
                  # ru / rv == (u / wu) / (v / wv) == w * wv / wu
                  roots[ru] = (rv, wv / wu * w)

              for (x, y), weight in zip(equations, values):
                  union(str2idx[x], str2idx[y], weight)

              ans = []
              for x, y in queries:
                  if x not in str2idx or y not in str2idx:
                      ans.append(-1.0)
                      continue
                  rx, wx = find(str2idx[x])
                  ry, wy = find(str2idx[y])
                  # With a shared root: x / y == (x / root) / (y / root).
                  ans.append(wx / wy if rx == ry else -1.0)
              return ans
      
    • 959. 由斜杠划分区域

      class Solution:
          def regionsBySlashes(self, grid: List[str]) -> int:
              """Count the regions an n x n grid of '/' and '\\' strokes is cut into.

              Nodes are the (n+1)^2 lattice corners, and the whole outer border
              is merged into one set up front.  A stroke whose two endpoints are
              already in the same set closes a loop, which splits off one new
              region.  Each row string is read at columns 0..n-1.
              """
              size = len(grid)
              parent = {(r, c): (r, c) for r in range(size + 1) for c in range(size + 1)}

              def find(p: Tuple[int, int]) -> Tuple[int, int]:
                  # Recursive lookup with path compression.
                  if parent[p] != p:
                      parent[p] = find(parent[p])
                  return parent[p]

              # Merge every border corner into the set containing (0, 0).
              corner = (0, 0)
              for k in range(size + 1):
                  for edge_point in ((k, 0), (k, size), (0, k), (size, k)):
                      parent[find(corner)] = find(edge_point)

              regions = 1
              for r in range(size):
                  line = grid[r]
                  for c in range(size):
                      ch = line[c]
                      if ch == "/":
                          a, b = (r + 1, c), (r, c + 1)
                      elif ch == "\\":
                          a, b = (r, c), (r + 1, c + 1)
                      else:
                          continue
                      ra, rb = find(a), find(b)
                      if ra == rb:
                          regions += 1
                      parent[ra] = rb
              return regions
      
    • 778. 水位上升的泳池中游泳

      class Solution:
          def swimInWater(self, grid: List[List[int]]) -> int:
              """Smallest time t at which (0, 0) and (n-1, n-1) become connected.

              Cells are activated in increasing order of elevation.  Each newly
              activated cell is merged with its already-submerged 4-neighbours.
              The first elevation at which start and goal share a root is
              returned, or -1 if that never happens (e.g. an empty grid).
              """
              n = len(grid)
              owner = {(r, c): (r, c) for r in range(n) for c in range(n)}

              def find(cell: Tuple[int, int]) -> Tuple[int, int]:
                  # Recursive lookup with path compression.
                  if owner[cell] != cell:
                      owner[cell] = find(owner[cell])
                  return owner[cell]

              where = {}
              for r in range(n):
                  for c in range(n):
                      where[grid[r][c]] = (r, c)

              start, goal = (0, 0), (n - 1, n - 1)
              for t in range(n * n):
                  r, c = where[t]
                  # Neighbour order: up, down, left, right.
                  for nr, nc in ((r - 1, c), (r + 1, c), (r, c - 1), (r, c + 1)):
                      if 0 <= nr < n and 0 <= nc < n and grid[nr][nc] <= t:
                          owner[find((r, c))] = find((nr, nc))
                  if find(start) == find(goal):
                      return t
              return -1
      
    • 1202. 交换字符串中的元素

      class Solution:
          def smallestStringWithSwaps(self, s: str, pairs: List[List[int]]) -> str:
              """Lexicographically smallest string reachable by swapping index pairs.

              Swaps may be repeated, so indices in one connected component of
              ``pairs`` can be permuted arbitrarily.  Each component's characters
              are sorted and written back to its indices in ascending order.
              """
              n = len(s)
              roots = list(range(n))

              def find(x: int) -> int:
                  # Iterative path halving.  Recursion overflows the Python stack
                  # on long chains such as [[0,1],[1,2],...] (n can be ~1e5).
                  while roots[x] != x:
                      roots[x] = roots[roots[x]]
                      x = roots[x]
                  return x

              for u, v in pairs:
                  roots[find(u)] = find(v)

              # i increases monotonically, so every group list is already sorted.
              groups = defaultdict(list)
              for i in range(n):
                  groups[find(i)].append(i)

              ans = [""] * n
              for indices in groups.values():
                  for idx, ch in zip(indices, sorted(s[j] for j in indices)):
                      ans[idx] = ch
              return "".join(ans)
      
    • 947. 移除最多的同行或同列石头

      class Solution:
          def removeStones(self, stones: List[List[int]]) -> int:
              """Maximum number of stones removable under the shared row/column rule.

              Stones sharing an x or a y coordinate end up in one component.
              Each component can be reduced to a single stone, so the answer is
              the number of stones minus the number of components.
              """
              total = len(stones)
              parent = list(range(total))

              def find(i: int) -> int:
                  # Recursive lookup with path compression.
                  if parent[i] != i:
                      parent[i] = find(parent[i])
                  return parent[i]

              # Link each stone to the first stone seen on its column, then its row.
              first_in_x, first_in_y = {}, {}
              for i, (x, y) in enumerate(stones):
                  for seen, key in ((first_in_x, x), (first_in_y, y)):
                      if key in seen:
                          parent[find(i)] = find(seen[key])
                      else:
                          seen[key] = i

              leaders = {find(i) for i in range(total)}
              return total - len(leaders)
      
    • 803. 打砖块

      class Solution:
          def hitBricks(self, grid: List[List[int]], hits: List[List[int]]) -> List[int]:
              """For each hit, return how many bricks fall as a result.

              The method works backwards.  All hit cells are removed first, and
              a union-find is built over the bricks that remain.  Hits are then
              restored in reverse order.  When a brick is put back, the growth of
              the roof-attached component, minus the restored brick itself, is
              the number of bricks the forward hit knocked down.  The same
              backwards trick appears in 174 (Dungeon Game) and 312 (Burst
              Balloons).
              """
              rows, cols = len(grid), len(grid[0])
              # Extra virtual node: everything attached (directly or indirectly)
              # to the top row is unioned into it.
              roof = rows * cols
              parent = list(range(roof + 1))
              rank = [0] * (roof + 1)
              size = [1] * (roof + 1)

              def find(x: int) -> int:
                  # Recursive lookup with path compression.
                  if parent[x] != x:
                      parent[x] = find(parent[x])
                  return parent[x]

              def union(a: int, b: int) -> None:
                  # Union by rank; the lower-rank root goes under the higher one,
                  # and the component size is accumulated at the surviving root.
                  ra, rb = find(a), find(b)
                  if ra == rb:
                      return
                  if rank[ra] < rank[rb]:
                      ra, rb = rb, ra
                  if rank[ra] == rank[rb]:
                      rank[ra] += 1
                  parent[rb] = ra
                  size[ra] += size[rb]

              def bricks_on_roof() -> int:
                  # Size of the roof component, not counting the virtual node.
                  return size[find(roof)] - 1

              def cell(r: int, c: int) -> int:
                  return r * cols + c

              # Work on a copy with every hit cell emptied.
              board = [list(line) for line in grid]
              for hr, hc in hits:
                  board[hr][hc] = 0

              # Connect the surviving bricks: roof first, then up and left links.
              for r, line in enumerate(board):
                  for c, val in enumerate(line):
                      if val != 1:
                          continue
                      here = cell(r, c)
                      if r == 0:
                          union(here, roof)
                      if r > 0 and board[r - 1][c]:
                          union(here, cell(r - 1, c))
                      if c > 0 and board[r][c - 1]:
                          union(here, cell(r, c - 1))

              fallen = []
              for r, c in reversed(hits):
                  before = bricks_on_roof()
                  if grid[r][c] == 0:
                      # The hit struck an empty cell, so nothing changed.
                      fallen.append(0)
                      continue
                  here = cell(r, c)
                  for nr, nc in ((r - 1, c), (r + 1, c), (r, c - 1), (r, c + 1)):
                      if 0 <= nr < rows and 0 <= nc < cols and board[nr][nc]:
                          union(here, cell(nr, nc))
                  if r == 0:
                      union(here, roof)
                  board[r][c] = 1
                  fallen.append(max(0, bricks_on_roof() - before - 1))
              fallen.reverse()
              return fallen