并查集概念及用法分析

4,380 阅读17分钟

这篇文章主要会一步步分析并回答下面的三个问题

什么是并查集?
并查集有什么优化?
并查集可以用来解决什么样的问题?

最后文章会给出 LeetCode 相关问题的分析和总结


什么是并查集

并查集可以看作是一个数据结构,如果你根本没有听说过这个数据结构,那么你第一眼看到 “并查集” 这三个字的时候,脑海里会浮现一个什么样的数据结构呢?基于我们之前所学的知识来思考并推导一个问题,这相比直接去理解,你会收获得更多。

我们就来逐字拆解一下,并、查、集 这个三个字,其中前面两个字都是动词,第三个字是个名词。我们先看名词,因为只有知道了这个东西是什么,才能去理解它能干什么。集就是集合,我们中学时期就学过这个东西,集合用大白话说就是将一堆元素没有顺序地摆放在同一个地方。现在我们已经知道了其实并查集本质就是集合,那它能做什么呢?这就要看前两个字 - “并” 和 “查”,集合的一些操作想必你肯定记得,例如,交集,并集等等,这里的 “并” 其实就指的是并集操作,两个集合合并后就会变成一个集合,如下所示:

{1,3,5,7} U {2,4,6,8} = {1,2,3,4,5,6,7,8}

那 “查” 又是什么呢?集合本身只是容器,我们最终还是要知道里面存的是什么元素,因此这里的 “查” 是对于集合中存放的元素来说的,我们不仅要查找这个元素在不在集合中,我们还要确定这个元素是在哪个集合中,对于前一个操作,Java 中普通的 set 就可以做到,但是对于后一个操作,仅仅使用一个 set,比较难做到。好了,现在你应该知道并查集是什么,以及它能干什么了,总结下来就是:

  • 并查集可以进行集合合并的操作(并)
  • 并查集可以查找元素在哪个集合中(查)
  • 并查集维护的是一堆集合(集)

知道了这些后,并查集的概念就清楚了。


并查集的实现

仅仅靠嘴说,还是差点意思,作为程序员,必然是要将我们了解的东西用代码的形式呈现出来,相信通过上面的表述,你已经知道,并查集维护的是一堆集合而不是一个集合,用什么样的数据结构表示并查集?set 吗?这里有两个东西我们是必须要知道的,元素的值,集合的标号,一个元素仅可能同时存在于一个集合中,元素对集合是多对一的关系,这么看来我们可以用一个健值对的结构来表示并查集,Map 是肯定可以,但是如果对元素本身没有特定要求的话,我们可以使用数组,这样速度更快,使用起来也更加简单,可以看看下面的例子:

{0}, {1}, {2}, {3}, {4}, {5} => [0,1,2,3,4,5]
{0,1,2}, {3,4,5} => [0,0,0,3,3,3] or [1,1,1,4,4,4] ...

在解释上面的数组表示方式之前,不知道你有没有发现一个事实就是,“元素本身的值是固定不变的,但是元素所属的集合是可以变化的”,因此我们可以使用数组的 index 来代表元素,数组 index 上存放的值表示元素所属的集合,另外一个问题就是,集合怎么表示,标号吗?最直接的办法就是就地取材,我们直接从集合中选出一个元素来代表这个集合。相信到这里,你心里还是有存留一堆问题,不急,我们接着看。

说完了集合的表示,我们来看看如何基于这种表示去实现 “并” 和 “查”,也就是集合的合并和元素的查找,这两个操作是相互影响的,因此最好是放在一起讲,合并其实就是改变数组中存放的值,这个值表示的是该元素(index 表示)所在的集合,但是这里有一个问题就是一个集合合并到另一个集合中,我们是不是需要把集合中所有的元素对应的值都更改掉,其实是不需要的,举个例子你就理解了:

{0,1,2}, {3,4}, {5,6} => [1,1,1,3,3,6,6]
{3,4} U {5,6} => {0,1,2}, {3,4,5,6} => [1,1,1,6,3,6,6]
{0,1,2} U {3,4,5,6} => {0,1,2,3,4,5,6} => [1,6,1,6,3,6,6]

仔细观察,你会发现每次合并操作,我们仅仅更新了代表元素所对应的集合,并没有把整个集合里的元素对应的集合都更改了。“代表元素(根元素)” 也就是代表集合的元素,这个元素所在位置存放的就是其本身,两个集合和并的时候,我们仅仅需要更新代表元素即可,因为查找元素所在集合的过程其实就是查找代表元素的过程,代表元素变了,其余元素对应的集合也就变了。依据上面合并的例子,我们来看看如何查找:

[1,1,1,3,3,6,6] 查找元素 4 所在集合 4 -> 3, 元素 4 在集合 3 里面
[1,1,1,6,3,6,6] 查找元素 4 所在集合 4 -> 3 -> 6,元素 4 在集合 6 里面
[1,6,1,6,3,6,6] 查找元素 0 所在集合 0 -> 1 -> 6, 元素 0 在集合 6 里面

我们在来看看代码的实现,首先是查找:

public int find(int element) {
    if (roots[element] == element) {
        return element;
    }
    
    int father = element;
    while (roots[father] != father) {
        father = roots[father];
    }
    
    return father;
}

为了让代码更加简洁,我们也可以写成递归的形式:

public int find(int element) {
    if (roots[element] == element) {
        return element;
    }
    
    return find(roots[element]);
}

查找就是查找代表元素的过程。

另外就是合并,当两个元素相遇,我们合并是将这两个元素所在的集合进行合并,因此我们依然要借助 find 找到这两个元素所在的集合,如果是相同的集合就不需要合并,不同的集合,就将其中一个代表元素进行更改,使其指向另一个代表元素:

public void union(int element1, int element2) {
    int root1 = find(element1);
    int root2 = find(element2);
    
    if (root1 != root2) {
        roots[root1] = root2;
    }
}

不知道你有没有思考这个地方为什么数组的名字要命名成 roots,画画图,你会发现并查集其实可以看成是多棵树,也就是森林,只不过这些树的边的方向是从 child 到 parent 的。形象一点解释就是倒着的树。


优化 - 路径压缩

我们可以分析一下上面的代码的时间复杂度,上面的两个函数操作都是基于数组的,其中 union 操作又是依赖于 find 的,因此 find 操作的时间复杂度等同于并查集操作的时间复杂度。首先我们来看一个例子:

{0}, {1}, {2}, {3}, ... {n} => [0,1,2,3,...,n]
{0} U {1} => {1,1,2,3,...,n}
{1} U {2} => {1,2,2,3,...,n}
{2} U {3} => {1,2,3,3,...,n}
...

上面一步步合并,到最后 find(1) 的时间复杂度是 O(n) 的,find 操作的最差时间是 O(n),有没有办法优化呢?有一个路径压缩的思路,还是上面的例子,上面的例子我们最后得到的是一个长长的搜索链:

1 -> 2 -> 3 -> ... -> n

优化的思路就是让这个链变短,如果我们 find(1) 的话,到最后我们可以找到 “代表元素”,但是从 1 到 “代表元素” 中间的所有元素其实都是同一个集合的,从它们开始 find 也都是找的同一个代表元素,那我们是不是可以将 find(1) 找到的代表元素赋值给中间经过的这些元素呢?这样下次,从这些元素直接就可以找到代表元素。也就是将上面的链变成:

1 -> n
2 -> n
3 -> n
...

find 函数优化后就会是:

public int find(int element) {
    if (root[element] == element) {
        return element;
    }

    // 找到代表元素
    int father = element;
    while (root[father] != father) {
        father = root[father];
    }
    
    // 将路径上所有进过的元素都直接指向代表元素
    int compressPointer = element;
    while (root[compressPointer] != father) {
        int tmp = root[compressPointer];
        root[compressPointer] = father;
        compressPointer = tmp;
    }

    return father;
}

写成递归的形式会比较简洁:

private int find(int element) {
    if (roots[element] == element) {
        return element;
    }
    
    return roots[element] = find(roots[element]);
}

find 经过路径压缩后时间复杂度会变成 O(1),可能你不太理解为什么这里的时间复杂度是 O(1),代码当中明明就有 while 循环啊。可以思考一下动态数组的扩容,动态数组就像是 Java 中的 ArrayList 和 C++ 中的 vector,这些动态数组是基于静态数组实现的,一开始的大小不会太大,如果元素装满了,它就会重新开辟一个为原来两倍大小的静态数组,然后把之前的值都拷贝到新的数组中,这一步的 add 操作是 O(n) 的,但是将其均摊到前面 n - 1 步 add 操作,时间复杂度还是 O(1)。同样,思考并查集的时间复杂度的时候也可以往这个方向去想,基于之前的例子 find(1) 一次 O(n) 的操作过后,find(1), find(2), find(3) ... 都变成 O(1) 了。

我并不想从数学理论上面去证明,这个非常复杂,有些论文证明说并查集时间复杂度最坏会到 O(logn),也有些论文说时间复杂度是 O(an),这里的 a 是一个极小极小的数,这两种表示方式都没有错,但是并查集退化到 O(logn) 的几率会非常非常的小,只有非常极端的例子才会出现,仅仅记住一个时间复杂度对于我们理解算法本身并没有特别大的帮助。举个例子,快速排序 作为当今最伟大的 十大算法 之一,我们总说快排的时间复杂度是 O(nlgn),你是否了解过这个算法最差的时间复杂度是 O(n^2) 的,而且不稳定,归并排序稳定且最差时间复杂度是 O(nlgn),但是相比于快排,还是会逊色不少。可见分析一个算法的性能还是不能仅仅看最差时间复杂度。

如果我们把前面讲到的东西拼接在一起,我们就可以组合成一个真正的并查集结构:

private class UnionFind {
    private int[] roots;
    
    private UnionFind(int size) {
        this.roots = new int[size];
        
        for (int i = 0; i < size; ++i) {
            roots[i] = i;
        }
    }
    
    private int find(int element) {
        if (roots[element] == element) {
            return element;
        }
        
        return roots[element] = find(roots[element]);
    }
    
    private void union(int element1, int element2) {
        int root1 = find(element1);
        int root2 = find(element2);
        
        if (root1 != root2) {
            roots[root1] = root2;
        }
    }
}

并查集还有一个优化叫做启发式合并,就是在 union 操作上优化,之前说过并查集可以看作是一堆倒着的树,这个优化主要是考虑树的深度,合并的时候需要将深度小的树连到深度大的树上面去,因为这个优化对时间的影响并没有路径压缩这么大,因此这里跳过,有兴趣可以了解一下,对于一般的问题,使用路径压缩就完全够了。


并查集可以用来解决什么问题

并查集往往用于解决图上的问题,并查集只有两个操作,“并” 和 “查”,但是通过这两个操作可以派生出一些其他的应用:

  • 图的连通性问题
  • 集合的个数
  • 集合中元素的个数

图的连通性很好理解,一个图是不是连通的是指,“如果是连通图,那么从图上的任意节点出发,我们可以遍历到图上所有的节点”, 这里我们只需要将在图上的节点放到相同的集合中去,然后去看是不是所有的节点均指向同一个集合即可;集合的个数就是找代表元素的个数,查找某个集合中元素的个数最简单的方式就是直接遍历 roots 数组,然后挨个 find,另外一种方法是在结构中多保存一个数组用来记录每个集合中元素的个数,并根据具体的操作来更改。

反过来我们也要思考一个问题就是,什么问题是并查集所不能解决的?并查集的合并操作是不可逆的,你可以理解成只合不分,也就是说两个集合合并之后就不会再分开来了,另外并查集只会保存并维护集合和元素的关系,至于元素之间的关系,比如图上节点与节点的边,这种信息并查集是不会维护的,如果遇到题目让你分析诸如此类的问题,那么并查集并不是一个好的出发点,你可能需要往其他的算法上去考虑。

下面我们就结合实际的 LeetCode 题目来分析看看如何利用并查集解决一些算法题。


并查集相关问题

LC 323. Number of Connected Components in an Undirected Graph

题目分析

给一个图,让你找出图上的连通区域,这里我们只需要利用并查集的合并操作将连在一起的点放在同一个集合里面,最后统计一下 roots 数组中有多少个 代表元素,也就是有多少个集合。

参考代码

private class UnionFind {
    private int[] roots;
    
    private UnionFind(int size) {
        roots = new int[size];
        
        for (int i = 0; i < size; ++i) {
            roots[i] = i;
        }
    }
    
    private int find(int element) {
        if (roots[element] == element) {
            return element;
        }
        
        return roots[element] = find(roots[element]);
    }
    
    private void union(int element1, int element2) {
        int root1 = find(element1);
        int root2 = find(element2);
        
        if (root1 != root2) {
            roots[root1] = root2;
        }
    }
}

public int countComponents(int n, int[][] edges) {
    if (edges == null && edges.length == 0) {
        return n;
    }
    
    UnionFind unionFind = new UnionFind(n);
    
    for (int i = 0; i < edges.length; ++i) {
        unionFind.union(edges[i][0], edges[i][1]);
    }
    
    int[] roots = unionFind.roots;
    
    int count = 0;
    
    for (int i = 0; i < roots.length; ++i) {
        if (roots[i] == i) {
            count++;
        }
    }
    
    return count;
}

LC 547. Friend Circles

题目分析

题目要你找到输入矩阵表示了多少个朋友圈,A 和 B 是朋友,B 和 C 是朋友,A 和 C 就算是间接地朋友,然后 A,B,C 构成了一个朋友圈。解决这道题也非常简单,把具有朋友关系的两个人合并在一起,然后查找有多少个朋友圈(集合)即可,思路和上面那道题目类似。

参考代码

private class UnionFind {
    private int[] roots;
    
    private UnionFind(int size) {
        this.roots = new int[size];
    }

    private int find(int element) {
        while (element == roots[element]) {
            return element;
        }
        
        return roots[element] = find(roots[element]);
    }

    private void union(int element1, int element2) {
        int root1 = find(element1);
        int root2 = find(element2);
        
        if (root1 != root2) {
            roots[root1] = root2;                
        }
    }
}

public int findCircleNum(int[][] M) {
    if (M.length == 0) {
        return 0;
    }
    
    int n = M.length;
    
    UnionFind unionFind = new UnionFind(n);
    
    for (int i = 0; i < n; ++i) {
        unionFind.roots[i] = i;
    }
    
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < i; ++j) {
            if (M[i][j] == 1) {
                unionFind.union(i, j);
            }
        }
    }
    
    int count = 0;
    for (int i = 0; i < n; ++i) {
        if (unionFind.roots[i] == i) {
            count++;
        }
    }
    
    return count;
}

LC 305. Number of Islands II

题目分析

Number of Islands 的变形题目,一开始没有任何的岛屿,每一个时刻会填一格子让其变成陆地,题目问在每个时刻地图上有多少个连通陆地。这道题目是可以用传统的搜索,比如深度优先搜索来解决的,问题是时间复杂度是多少?每次遍历矩阵求岛屿个数时间复杂度是 O(m*n),这里的 m 和 n 分别代表矩阵的长宽,这里假设有 k 个时刻,那么整个问题的时间复杂度就会变成 O(k*m*n)。

这里我们需要明确下面这几点:

  • 每填充一块陆地,岛屿数量加 1
  • 每合并两个陆地,岛屿数量减 1
  • 每块陆地只能和其 上、下、左、右四块陆地合并

不知道你能不能从上面的分析中看出并查集的影子,你可以把陆地看成是集合中的元素,岛屿是集合,合并操作就是将两个岛屿打通,将两个岛屿合并,和之前的题目不一样的是这里我们还需要用一个变量表示当前岛屿的数量。

这里有一个小技巧用于将二维矩阵转换成一维数组来方便并查集计算,矩阵中任意元素都有一个二维坐标 (i, j),那么对应到一维中的 index 就会是 index = i * n + j,这里 n 表示矩阵中每一行元素的个数。

我们再来看看并查集的时间复杂度,一开始初始化 roots 数组需要遍历所有格子,时间复杂度是 O(m*n),有 k 次操作,每次操作 O(1),那么整体时间复杂度就是 O(k + m*n),可以看到,使用并查集将之前搜索的时间复杂度降低了一个维度。

参考代码

private class UnionFind {
    private int[] roots;
    private int count;
    
    private UnionFind(int size) {
        this.roots = new int[size];
        
        Arrays.fill(roots, -1);
        
        this.count = 0;
    }
    
    private int find(int element) {
        if (roots[element] == element) {
            return element;
        }
        
        return roots[element] = find(roots[element]);
    }
    
    private void union(int element1, int element2) {
        int root1 = find(element1);
        int root2 = find(element2);
        
        if (root1 != root2) {
            roots[root1] = root2;
            this.count--;
        }
    }
    
    private int getCount() {
        return this.count;
    }
    
    private void setFather(int element, int father) {
        this.roots[element] = father;
        this.count++;
    }
}

public List<Integer> numIslands2(int m, int n, int[][] positions) {
    List<Integer> results = new ArrayList<>();
    
    if (positions == null || positions.length == 0) {
        return results;
    }
    
    UnionFind unionFind = new UnionFind(m * n);
    
    int[] dirX = {0, 0, -1, 1};
    int[] dirY = {-1, 1, 0, 0};
    
    for (int i = 0; i < positions.length; ++i) {
        int currentId = positions[i][0] * n + positions[i][1];
        
        if (unionFind.roots[currentId] != -1) {
            results.add(unionFind.count);
            continue;
        }
        
        unionFind.setFather(currentId, currentId);
        
        for (int j = 0; j < 4; ++j) {
            int neighbourX = positions[i][0] + dirX[j];
            int neighbourY = positions[i][1] + dirY[j];
            
            if (neighbourX < 0 || neighbourX >= m || neighbourY < 0 || neighbourY >= n) {
                continue;
            }
            
            int neighbourId = neighbourX * n + neighbourY;
            
            if (unionFind.roots[neighbourId] != -1) {
                unionFind.union(currentId, neighbourId);
            }
        }
        
        results.add(unionFind.getCount());
    }
    
    return results;
}

LC 130. Surrounded Regions

题目分析

题目要求将被 X 完全包围住的 O 改成 X,这道题比较 ticky 的地方是,最外层的 O 是没有被包围的,如果内部的 O 和最外层的 O 相连,那么内层的 O 也是不会被改写的。比较直接的思路是先对最外层的 O 做预处理,这里我们用到一个类似 “哨兵节点” 的思想,我们开并查集数组的时候多开一格,用最后一个节点来表示一个集合,在这个集合里面的元素不会被改写,因此我们首先将最外圈的 O 放到这个集合中,然后遍历数组进行和并操作,合并操作我们规定标号小的集合必须放在大的集合中,保证哨兵节点所代表的集合一直都在,最后将不在哨兵集合中的 O 改写成 X 即可。

参考代码

private class UnionFind {
    private int[] roots;
    
    private UnionFind(int size) {
        this.roots = new int[size];
        
        for (int i = 0; i < size; ++i) {
            this.roots[i] = i;
        }
    }
    
    private int find(int element) {
        if (roots[element] == element) {
            return element;
        }
        
        return roots[element] = find(roots[element]);
    }
    
    private void union(int element1, int element2) {
        int root1 = find(element1);
        int root2 = find(element2);
        
        if (root1 != root2) {
            if (root2 > root1) {
                roots[root1] = root2;
            } else {
                roots[root2] = root1;
            }
        }
    }
}

public void solve(char[][] board) {
    if (board == null || board.length == 0 || board[0] == null || board[0].length == 0) {
        return;
    }
    
    int m = board.length, n = board[0].length;
    
    UnionFind unionFind = new UnionFind(m * n + 1);
    
    for (int i = 0; i < m; ++i) {
        if (board[i][0] == 'O') {
            unionFind.union(i * n, m * n);
        }
        
        if (board[i][n - 1] == 'O') {
            unionFind.union((i + 1) * n - 1, m * n);
        }
    }
    
    for (int i = 0; i < n; ++i) {
        if (board[0][i] == 'O') {
            unionFind.union(i, m * n);
        }
        
        if (board[m - 1][i] == 'O') {
            unionFind.union((m - 1) * n + i, m * n);
        }
    }
    
    int[] dirX = {0, 0, -1, 1};
    int[] dirY = {-1, 1, 0, 0};
    
    for (int i = 1; i < m - 1; ++i) {
        for (int j = 1; j < n - 1; ++j) {
            if (board[i][j] != 'O') {
                continue;
            }
            
            int currentId = i * n + j;
            
            for (int k = 0; k < 4; ++k) {
                int neighborX = i + dirX[k];
                int neighborY = j + dirY[k];
                
                if (board[neighborX][neighborY] == 'X') {
                    continue;
                }
                
                int neighborId = neighborX * n + neighborY;
                
                unionFind.union(neighborId, currentId);
            }
        }
    }
    
    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < n; ++j) {
            int currentId = i * n + j;

            if (unionFind.find(currentId) != m * n) {
                board[i][j] = 'X';
            }
        }
    }
}

总结

并查集的题目还有很多,这里就不一一列举,如果一道题目中涉及的条件可以抽象成集合和元素这两个概念,那么这道题差不多就可以用并查集来解决,如果定义清楚 find 以及 union 操作所表示的意义,那么整个问题会变得非常简单。以上是这次的分享,希望对你有所帮助。