多路归并(最小堆)

24 阅读7分钟

1. 基本概念

多路归并是将多个有序序列合并成一个有序序列的问题。经典例子是合并 K 个有序链表/数组。(注意,在合并之前,各个list已经排好序了)

是一个具体的场景问题

多路归并(k-way merge)  是一种在多个已排序序列中合并出一个整体有序序列的算法技术。它是经典“归并”思想的扩展,常用于外部排序、优先队列实现、流数据处理等场景。

  • 两路归并(2-way merge) :最常见形式,比如归并排序中的合并两个已排序子数组。
  • 多路归并(k-way merge) :将 k 个(k ≥ 2)已排序的序列 合并成一个有序序列。

例如:

输入:
[1, 4, 7]
[2, 5, 8]
[3, 6, 9]

输出:
[1, 2, 3, 4, 5, 6, 7, 8, 9]

2. 实现方法

方法1:暴力法(不推荐)

暴力法思路

  • 每次遍历所有 k 个序列的当前“头部”元素(即每个序列中尚未被取走的第一个元素)。
  • 找出其中最小的一个,将其加入结果,并在对应序列中“移除”该元素(通过维护每个序列的读取索引实现)。
  • 重复直到所有元素都被处理完。

时间复杂度:O(N × k),其中 N 是总元素数,k 是序列数量。

空间复杂度:O(1) 额外空间(不计输出)。

package main

import (
    "fmt"
    "math"
)

// kWayMergeBruteForce 使用暴力法实现多路归并
// 输入:lists 是一个二维切片,每个子切片都是升序排列的整数数组
// 输出:合并后的单一升序数组
func kWayMergeBruteForce(lists [][]int) []int {
    k := len(lists)
    if k == 0 {
        return []int{}
    }

    // 初始化每个列表的当前读取位置(指针),预先分配长度k,每个list的起始位置都是0
    pointers := make([]int, k)

    // 计算总元素数,用于预分配结果切片容量,也就是N 
    total := 0
    for _, list := range lists {
        total += len(list)
    }
    
    // 长度为0,预分配容量为total
    result := make([]int, 0, total)

    // 当还有未处理的元素时,继续循环,也就是循环total=k*n(N)次,每一次都是k选1,所以时间复杂度是N*k
    for len(result) < total {
        minVal := math.MaxInt32 // 初始化为最大整数
        minIdx := -1            // 记录哪个列表提供了最小值

        // 遍历所有列表,找出当前可取的最小元素
        for i := 0; i < k; i++ {
            // 如果该列表还没遍历完
            if pointers[i] < len(lists[i]) {  // pointers[k]=第k个list当前的索引,由于pointers[minIdx]++,最后会=len(lists[i])
                if lists[i][pointers[i]] < minVal {
                    minVal = lists[i][pointers[i]]
                    minIdx = i
                }
            }
        }

        // 将最小值加入结果,并移动对应指针
        result = append(result, minVal)
        pointers[minIdx]++
    }

    return result
}

// 示例与测试
func main() {
    lists := [][]int{
        {1, 4, 7},
        {2, 5, 8},
        {3, 6, 9},
    }

    merged := kWayMergeBruteForce(lists)
    fmt.Println("合并结果:", merged) // 输出: [1 2 3 4 5 6 7 8 9]
}

为什么说“暴力法不好”?

时间复杂度过高:O(N × k),例如假设:

  • k = 序列数量(例如 1000 个有序日志流)
  • N = 所有元素总数(例如 1 亿个整数)

那么总操作次数 ≈ 1000 × 1亿 = 1000 亿次比较
而用最小堆优化的方法只需要 O(N log k) ≈ 1亿 × log₂(1000) ≈ 1亿 × 10 = 10 亿次,快了 100 倍

方法2:使用最小堆(优先队列)

先介绍一下,什么是最小堆?

其数据结构是完全二叉树:也就是一种优先从上到下,从左到右的节点分布树

最小堆除了是完全二叉树之外,还要满足一个条件,就是父节点的值,要小于或等于其子节点的值

image.png

添加节点:

添加节点时,为了满足完全二叉树的特性,都是先添加在当前树最下最左的位置,然后和其父节点做比较,从而判断是否要做一次换位,而由于当前的链路已经满足从小到大(13,41,90),所以每次只用考虑换位 image.png

删除节点:

删除节点比较巧妙,每一次都从根节点删除,也就是删除最小值的节点,然后将最后一个节点换上来,然后再向下判断换位

image.png

image.png

ok,了解了最小堆是啥之后,接着看是怎么利用最小堆来处理多路归并的问题!

由于自己写一个最小堆还是比较麻烦,我们可以直接使用现成的,在go语言中,提供了一个包,container/heap,这个包提供了一个堆接口,但它本身并不直接提供一个预定义的最小堆或最大堆。你需要自己实现一个满足 heap.Interface 接口的数据结构(一个切片),然后就可以使用 heap 包提供的函数来操作它。

heap.Interface 接口:这是一个组合接口,继承了 sort.Interface

type Interface interface {
    sort.Interface
    Push(x any) // Add x as element Len()
    Pop() any   // Remove and return element Len() - 1.
}

sort.Interface 本身包含:

type Interface interface {
    Len() int           // 返回元素数量
    Less(i, j int) bool // 比较元素 i 和 j,决定它们的顺序
    Swap(i, j int)      // 交换元素 i 和 j
}

// Less` 方法决定堆类型:
// **最小堆**:`Less(i, j)` 应该返回 `slice[i] < slice[j]`。
// **最大堆**:`Less(i, j)` 应该返回 `slice[i] > slice[j]`。

也就是说定义好一个切片之后,只要对应实现这5方法,就可以使用了

核心思路如下:用一个最小堆来管理所有序列当前待比较的元素,每次只需 O(log k) 时间就能取出全局最小值,并放入下一个来自同序列的候选元素,从而避免了每次都 O(k) 扫描的开销。在k比较大时,效果显著

暴力法 vs 最小堆法

方面暴力法最小堆法
核心操作每次从 k 个序列的当前头部元素中 扫描找最小值每次从一个大小为 k 的 最小堆中取出堆顶(即最小值)
时间复杂度每次取最小值 O(k),总共 N 次 → O(N * k)每次取/放元素 O(log k),总共 N 次 → O(N * log k)
关键优化点用  替换了 线性扫描

为什么更快?

  • 暴力法:每次都需要看 k 个元素才能知道谁最小(O(k) )。
  • 最小堆法:堆这个数据结构本身就维护了“最小元素在顶部”的性质,获取最小值和插入新元素都只要 O(log k)  的时间。对于 N 个总元素,这个优势会放大 N 次。
package main

import (
	"container/heap"
	"fmt"
)

// Item 定义堆中存放的元素(信息要尽可能的全)
type Item struct {
	value int // 元素的值
	index int // 该元素在其所属列表中的索引
	listId int // 该元素所属列表的 ID
}

// IntMinHeap 定义一个 Item 类型的最小堆
type IntMinHeap []*Item

// 实现接口的5个方法
func (h IntMinHeap) Len() int           { return len(h) }
// Less 定义最小堆的规则:值小的优先级高(放在堆顶)
func (h IntMinHeap) Less(i, j int) bool { return h[i].value < h[j].value }
func (h IntMinHeap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
func (h *IntMinHeap) Push(x any) { *h = append(*h, x.(*Item)) }
func (h *IntMinHeap) Pop() any {
	old := *h
	n := len(old)
	item := old[n-1]
	*h = old[0 : n-1]
	return item
}

// kWayMergeWithHeap 使用最小堆实现多路归并
func kWayMergeWithHeap(lists [][]int) []int {
	h := &IntMinHeap{} 
	heap.Init(h) // 关键声明

	// 1. 初始化堆:将每个非空列表的第一个元素放入堆
	for i, list := range lists {
		if len(list) > 0 {
			heap.Push(h, &Item{
				value:  list[0],
				index:  0, // 初始索引为 0
				listId: i, // 记录来自哪个列表
			})
		}
	}

	var result []int

	// 2. 主循环:不断从堆中取最小值,并将下一个元素放入堆
	for h.Len() > 0 {
		// 取出堆顶(当前最小元素)
		minItem := heap.Pop(h).(*Item)
		result = append(result, minItem.value)

		// 检查该元素来源的列表是否还有下一个元素
		nextIndex := minItem.index + 1
		if nextIndex < len(lists[minItem.listId]) {
			// 如果有,则将下一个元素加入堆
			nextValue := lists[minItem.listId][nextIndex]
			heap.Push(h, &Item{
				value:  nextValue,
				index:  nextIndex,
				listId: minItem.listId,
			})
		}
	}

	return result
}

func main() {
	lists := [][]int{
		{1, 4, 7},
		{2, 5, 8},
		{3, 6, 9},
	}

	merged := kWayMergeWithHeap(lists)
	fmt.Println("合并结果:", merged) // 输出: [1 2 3 4 5 6 7 8 9]
}