并查集简介
并查集定义
并查集(Union Find) :一种树型的数据结构,用于处理一些不交集(Disjoint Sets)的合并及查询问题。不交集指的是一系列没有重复元素的集合。
并查集主要支持两种操作:
- 合并(Union) :将两个集合合并成一个集合。
- 查找(Find) :确定某个元素属于哪个集合。通常是返回集合内的一个「代表元素」。
隔代压缩实现代码
class UnionFind:
def __init__(self, n): # 初始化
self.fa = [i for i in range(n)] # 每个元素的集合编号初始化为数组 fa 的下标索引
def find(self, x): # 查找元素根节点的集合编号内部实现方法
while self.fa[x] != x: # 递归查找元素的父节点,直到根节点
self.fa[x] = self.fa[self.fa[x]] # 隔代压缩优化
x = self.fa[x]
return x # 返回元素根节点的集合编号
def union(self, x, y): # 合并操作:令其中一个集合的树根节点指向另一个集合的树根节点
root_x = self.find(x)
root_y = self.find(y)
if root_x == root_y: # x 和 y 的根节点集合编号相同,说明 x 和 y 已经同属于一个集合
return False
self.fa[root_x] = root_y # x 的根节点连接到 y 的根节点上,成为 y 的根节点的子节点
return True
def is_connected(self, x, y): # 查询操作:判断 x 和 y 是否同属于一个集合
return self.find(x) == self.find(y)
本文列举了我这次刷的 3 道并查集算法题目。
684. 冗余连接
解题思路
采用并查集判断是否存在冗余边。遍历列表 edges
中的边,对每条边的两个节点进行检查,如果父节点相同,表示两个节点已经连通,则该条边是冗余的。由于连接 n
个节点最少需要 n - 1
条边,而题目中给出 n == edges.length
,因此本题 edges
只冗余一条边, 遍历过程中发现的冗余边直接返回即可。
代码
class Solution:
def findRedundantConnection(self, edges: List[List[int]]) -> List[int]:
n = len(edges)
fa = list(range(n + 1))
def find(x):
while x != fa[x]:
fa[x] = fa[fa[x]]
x = fa[x]
return x
def union(x, y):
fa[find(x)] = fa[find(y)]
for x, y in edges:
if find(x) != find(y):
union(x, y)
else:
return [x, y]
return []
1319. 连通网络的操作次数
解题思路
题目的含义就是找出 n
个节点的连通分量。首先连通 n
个节点所需最少连接的边个数是 n - 1
个。
通过并查集找出连通的网络的个数。
操作方式是初始网络中未连接节点个数 self.cnt = n
,通过并查集判断 connections
中的边连接了多少节点。每次 union
节点则 self.cnt - 1
,最后返回结果是 self.cnt - 1
。
代码
class Solution:
def makeConnected(self, n: int, connections: List[List[int]]) -> int:
if len(connections) < n - 1:
return -1
fa = list(range(n))
self.cnt = n
def find(x):
while x != fa[x]:
fa[x] = fa[fa[x]]
x = fa[x]
return x
def union(x, y):
root_x, root_y = find(x), find(y)
if root_x == root_y:
return
fa[root_x] = root_y
self.cnt -= 1
for x, y in connections:
union(x, y)
return self.cnt - 1
947. 移除最多的同行或同列石头
解题思路
将石头坐标看成是节点,建立并查集。如果石头同行或同列,则进行合并。根据题意,同一个集合中移除到只剩下1个,因此计算连通分量,移除的个数是 总的石头个数 - 并查集的集合个数
。
代码
class UFS():
def __init__(self, n):
self.p = list(range(n))
def find(self, x):
if self.p[x] != x:
self.p[x] = self.find(self.p[x])
return self.p[x]
def union(self, x, y):
xr = self.find(x)
yr = self.find(y)
self.p[xr] = yr
class Solution:
def removeStones(self, stones: List[List[int]]) -> int:
ufs = UFS(20000)
for x, y in stones:
ufs.union(x, y + 10001)
return len(stones) - len({ufs.find(x) for x, y in stones})