堆与优先队列:Top K问题的完整攻略

前言

堆(Heap)和优先队列(Priority Queue)是处理Top K问题的利器。很多人搞不清楚什么时候用大根堆,什么时候用小根堆,其实记住一个口诀就行:求前K大用小根堆,求前K小用大根堆

我并没有能力让你看完就精通所有堆的应用,我只是想让你理解堆的结构、大小根堆的选择技巧、以及Top K问题的通用模板。

摘要

从"找数组第K大元素"问题出发,剖析堆的核心结构与Top K问题的解题技巧。通过大根堆与小根堆的对比、堆调整的图解演示、以及数据流中位数的双堆解法,揭秘优先队列的巧妙应用。配合LeetCode高频题目,给出Top K问题的完整套路。


一、从找第K大元素说起

周五早上,哈吉米遇到一道题:

LeetCode 215 - 数组中的第K个最大元素

给定整数数组 nums 和整数 k,请返回数组中第 k 个最大的元素。

请注意,你需要找的是数组排序后的第 k 个最大的元素,而不是第 k 个不同的元素。

你必须设计并实现时间复杂度为 O(n) 的算法解决此问题。

示例:
输入:nums = [3,2,1,5,6,4], k = 2
输出:5

输入:nums = [3,2,3,1,2,4,5,5,6], k = 4
输出:4

哈吉米:"这不就是排序吗?"

Java版本(暴力)

public int findKthLargest(int[] nums, int k) {
    Arrays.sort(nums);
    return nums[nums.length - k];
}

南北绿豆:"可以,但时间复杂度O(nlogn),不是最优。"

阿西噶阿西:"用,可以做到O(nlogk),甚至O(n)。"


二、什么是堆

南北绿豆:"堆是一种特殊的完全二叉树。"

2.1 堆的定义

大根堆(Max Heap)

  • 每个节点的值 ≥ 子节点的值
  • 根节点是最大值

小根堆(Min Heap)

  • 每个节点的值 ≤ 子节点的值
  • 根节点是最小值

图示

flowchart TB
    subgraph 大根堆
        A1["10<br/>最大"]
        B1["8"]
        C1["9"]
        D1["5"]
        E1["6"]
        F1["7"]
        G1["4"]
        
        A1 --> B1
        A1 --> C1
        B1 --> D1
        B1 --> E1
        C1 --> F1
        C1 --> G1
    end
    
    subgraph 小根堆
        A2["1<br/>最小"]
        B2["3"]
        C2["2"]
        D2["5"]
        E2["6"]
        F2["7"]
        G2["4"]
        
        A2 --> B2
        A2 --> C2
        B2 --> D2
        B2 --> E2
        C2 --> F2
        C2 --> G2
    end
    
    style A1 fill:#ffe6e6
    style A2 fill:#e1ffe1

2.2 生活化场景

阿西噶阿西:"堆就像排队看病。"

大根堆(VIP优先)

队列:[病情9, 病情5, 病情7, 病情3]

每次叫号:病情最严重的(9)先看
看完后:[病情7, 病情5, 病情3]
下一个:病情7

小根堆(先来后到)

队列:[等待5分钟, 等待10分钟, 等待8分钟]

每次叫号:等待时间最短的(5分钟)

哈吉米:"堆就是能快速拿到最大/最小值的数据结构。"


三、堆的存储:数组实现

南北绿豆:"堆虽然是树,但用数组存储。"

3.1 数组下标规律

完全二叉树的特性

数组:[10, 8, 9, 5, 6, 7, 4]
索引: 0  1  2  3  4  5  6

对应的树:
      10 (index=0)
     /  \
    8    9 (index=1,2)
   / \  / \
  5  6 7  4 (index=3,4,5,6)

规律:
父节点index=i
左子节点index=2*i+1
右子节点index=2*i+2

子节点index=i
父节点index=(i-1)/2

哈吉米:"用数组存树,省空间!"


四、Top K问题的核心技巧

阿西噶阿西:"重点来了:求第K大用小根堆,求第K小用大根堆。"

哈吉米:"为啥反着来?"

4.1 为什么求第K大用小根堆

南北绿豆:"想象选班级前K名。"

场景:班级50人,选前3名(K=3)

方法1:用大根堆(错误思路)
  → 把所有50人加入大根堆
  → pop 3次,得到前3名
  → 时间O(nlogn),空间O(n)

方法2:用小根堆(正确思路)
  → 维护一个大小为3的小根堆
  → 遍历50人,如果比堆顶大,替换堆顶
  → 最后堆里的3个人就是前3名
  → 时间O(nlogk),空间O(k)

示例nums = [3,2,1,5,6,4], k = 2(找第2大)

小根堆过程

遍历当前数堆(小根堆,容量2)操作
13[3]堆未满,直接加
22[2,3]堆未满,直接加
31[2,3]1<堆顶2,不加
45[3,5]5>堆顶2,移除2,加5
56[5,6]6>堆顶3,移除3,加6
64[5,6]4<堆顶5,不加

最终:堆顶5就是第2大的元素

哈吉米:"懂了!小根堆保证堆顶是第K大,堆里其他元素都≥堆顶。"

阿西噶阿西:"对,小根堆的堆顶是堆里最小的,也就是第K大的。"


五、例题1:数组中的第K个最大元素

5.1 代码实现

Java版本

public int findKthLargest(int[] nums, int k) {
    // 小根堆(Java默认是小根堆)
    PriorityQueue<Integer> heap = new PriorityQueue<>();
    
    for (int num : nums) {
        heap.offer(num);
        
        // 堆大小超过k,移除堆顶(最小的)
        if (heap.size() > k) {
            heap.poll();
        }
    }
    
    // 堆顶就是第K大
    return heap.peek();
}

C++版本

int findKthLargest(vector<int>& nums, int k) {
    // 小根堆
    priority_queue<int, vector<int>, greater<int>> heap;
    
    for (int num : nums) {
        heap.push(num);
        
        if (heap.size() > k) {
            heap.pop();
        }
    }
    
    return heap.top();
}

Python版本

def findKthLargest(nums, k):
    import heapq
    
    heap = []
    
    for num in nums:
        heapq.heappush(heap, num)  # Python默认小根堆
        
        if len(heap) > k:
            heapq.heappop(heap)
    
    return heap[0]

时间复杂度:O(nlogk)


六、例题2:前K个高频元素

6.1 题目

LeetCode 347 - 前 K 个高频元素

给你一个整数数组 nums 和一个整数 k ,请你返回其中出现频率前 k 高的元素。

示例:
输入:nums = [1,1,1,2,2,3], k = 2
输出:[1,2]

输入:nums = [1], k = 1
输出:[1]

6.2 思路分析

南北绿豆:"这题分两步:统计频率 + Top K。"

思路

  1. 用HashMap统计每个元素的频率
  2. 用小根堆维护频率前K高的元素(按频率比较)

为什么用小根堆?

阿西噶阿西:"因为要找频率最高的K个,相当于找频率第K大,用小根堆。"

6.3 代码实现

Java版本

public int[] topKFrequent(int[] nums, int k) {
    // 统计频率
    Map<Integer, Integer> count = new HashMap<>();
    for (int num : nums) {
        count.put(num, count.getOrDefault(num, 0) + 1);
    }
    
    // 小根堆:按频率排序
    PriorityQueue<int[]> heap = new PriorityQueue<>((a, b) -> a[1] - b[1]);
    
    for (Map.Entry<Integer, Integer> entry : count.entrySet()) {
        heap.offer(new int[]{entry.getKey(), entry.getValue()});
        
        if (heap.size() > k) {
            heap.poll(); // 移除频率最小的
        }
    }
    
    // 提取结果
    int[] result = new int[k];
    for (int i = 0; i < k; i++) {
        result[i] = heap.poll()[0];
    }
    
    return result;
}

C++版本

vector<int> topKFrequent(vector<int>& nums, int k) {
    unordered_map<int, int> count;
    for (int num : nums) {
        count[num]++;
    }
    
    // 小根堆:按频率排序
    auto cmp = [](pair<int, int>& a, pair<int, int>& b) {
        return a.second > b.second;
    };
    priority_queue<pair<int, int>, vector<pair<int, int>>, decltype(cmp)> heap(cmp);
    
    for (auto& p : count) {
        heap.push(p);
        
        if (heap.size() > k) {
            heap.pop();
        }
    }
    
    vector<int> result;
    while (!heap.empty()) {
        result.push_back(heap.top().first);
        heap.pop();
    }
    
    return result;
}

Python版本

def topKFrequent(nums, k):
    from collections import Counter
    import heapq
    
    count = Counter(nums)
    
    # 小根堆:存储(频率, 元素)
    heap = []
    
    for num, freq in count.items():
        heapq.heappush(heap, (freq, num))
        
        if len(heap) > k:
            heapq.heappop(heap)
    
    return [num for freq, num in heap]

七、例题3:数据流的中位数(双堆)

7.1 题目

LeetCode 295 - 数据流的中位数(Hard)

中位数是有序整数列表中的中间值。如果列表的大小是偶数,则没有中间值,中位数是两个中间值的平均值。

实现 MedianFinder 类:
- MedianFinder() 初始化对象。
- void addNum(int num) 将数据流中的整数 num 添加到数据结构中。
- double findMedian() 返回到目前为止所有元素的中位数。

示例:
输入:
["MedianFinder", "addNum", "addNum", "findMedian", "addNum", "findMedian"]
[[], [1], [2], [], [3], []]
输出:
[null, null, null, 1.5, null, 2.0]

解释:
MedianFinder medianFinder = new MedianFinder();
medianFinder.addNum(1);    // arr = [1]
medianFinder.addNum(2);    // arr = [1, 2]
medianFinder.findMedian(); // 返回 1.5 ((1 + 2) / 2)
medianFinder.addNum(3);    // arr = [1, 2, 3]
medianFinder.findMedian(); // 返回 2.0

7.2 思路分析

南北绿豆:"这题是堆的经典应用:用两个堆维护中位数。"

核心思想

  • 小根堆:存储较大的一半(堆顶是较大一半中最小的)
  • 大根堆:存储较小的一半(堆顶是较小一半中最大的)
  • 中位数:就在两个堆顶之间

图示

数据:[1, 2, 3, 4, 5]

大根堆(较小一半):[1, 2, 3],堆顶3
小根堆(较大一半):[4, 5],堆顶4

中位数 = 3(或(3+4)/2,取决于总数奇偶)

结构图

flowchart LR
    A["大根堆<br/>较小一半<br/>堆顶=max"]
    B["中位数<br/>在两个堆顶之间"]
    C["小根堆<br/>较大一半<br/>堆顶=min"]
    
    A --> B
    B --> C
    
    style B fill:#e1ffe1

维护规则

  1. 两个堆的大小差不超过1
  2. 大根堆的size ≥ 小根堆的size(多出来的元素放大根堆)

哈吉米:"两个堆的堆顶就是中间的两个数!"

7.3 代码实现

Java版本

class MedianFinder {
    private PriorityQueue<Integer> maxHeap; // 大根堆:存较小一半
    private PriorityQueue<Integer> minHeap; // 小根堆:存较大一半
    
    public MedianFinder() {
        maxHeap = new PriorityQueue<>((a, b) -> b - a); // 大根堆
        minHeap = new PriorityQueue<>(); // 小根堆
    }
    
    public void addNum(int num) {
        // 先加到大根堆
        maxHeap.offer(num);
        
        // 把大根堆的最大值移到小根堆(保证大根堆所有元素≤小根堆所有元素)
        minHeap.offer(maxHeap.poll());
        
        // 平衡两个堆的大小(大根堆size >= 小根堆size)
        if (maxHeap.size() < minHeap.size()) {
            maxHeap.offer(minHeap.poll());
        }
    }
    
    public double findMedian() {
        if (maxHeap.size() > minHeap.size()) {
            // 总数是奇数,大根堆多一个元素
            return maxHeap.peek();
        } else {
            // 总数是偶数,取两个堆顶的平均值
            return (maxHeap.peek() + minHeap.peek()) / 2.0;
        }
    }
}

C++版本

class MedianFinder {
private:
    priority_queue<int> maxHeap; // 大根堆
    priority_queue<int, vector<int>, greater<int>> minHeap; // 小根堆
    
public:
    MedianFinder() {}
    
    void addNum(int num) {
        maxHeap.push(num);
        
        minHeap.push(maxHeap.top());
        maxHeap.pop();
        
        if (maxHeap.size() < minHeap.size()) {
            maxHeap.push(minHeap.top());
            minHeap.pop();
        }
    }
    
    double findMedian() {
        if (maxHeap.size() > minHeap.size()) {
            return maxHeap.top();
        } else {
            return (maxHeap.top() + minHeap.top()) / 2.0;
        }
    }
};

Python版本

class MedianFinder:
    def __init__(self):
        self.maxHeap = []  # 大根堆(取负数模拟)
        self.minHeap = []  # 小根堆
    
    def addNum(self, num):
        import heapq
        
        # 先加到大根堆(取负数)
        heapq.heappush(self.maxHeap, -num)
        
        # 把大根堆最大值移到小根堆
        heapq.heappush(self.minHeap, -heapq.heappop(self.maxHeap))
        
        # 平衡两个堆
        if len(self.maxHeap) < len(self.minHeap):
            heapq.heappush(self.maxHeap, -heapq.heappop(self.minHeap))
    
    def findMedian(self):
        if len(self.maxHeap) > len(self.minHeap):
            return -self.maxHeap[0]
        else:
            return (-self.maxHeap[0] + self.minHeap[0]) / 2.0

7.4 执行过程演示

示例:依次添加 [1, 2, 3]

操作大根堆小根堆中位数
初始[][]-
add(1)[1][]1
add(2)[1][2]1.5
add(3)[2,1][3]2

详细过程

add(1):
  maxHeap.add(1) → [1]
  minHeap.add(maxHeap.poll()) → minHeap=[1], maxHeap=[]
  平衡:maxHeap.add(minHeap.poll()) → maxHeap=[1], minHeap=[]
  
add(2):
  maxHeap.add(2) → [2,1](大根堆,堆顶2)
  minHeap.add(maxHeap.poll()) → minHeap=[2], maxHeap=[1]
  不需要平衡(size相等)
  
add(3):
  maxHeap.add(3) → [3,1]
  minHeap.add(maxHeap.poll()) → minHeap=[2,3], maxHeap=[1]
  平衡:maxHeap.add(minHeap.poll()) → maxHeap=[2,1], minHeap=[3]

哈吉米:"两个堆配合,巧妙维护中位数。"


八、堆与Top K总结

8.1 核心技巧

南北绿豆总结:

问题使用的堆原因
前K大小根堆(容量K)堆顶是第K大,踢掉比它小的
前K小大根堆(容量K)堆顶是第K小,踢掉比它大的
中位数大根堆+小根堆两个堆顶夹着中位数

8.2 时间复杂度

阿西噶阿西

操作复杂度说明
插入O(logn)堆调整
删除堆顶O(logn)堆调整
查看堆顶O(1)直接返回
Top KO(nlogk)n个元素,堆大小k

8.3 识别技巧

南北绿豆

  • 看到第K大/小、前K个,想堆
  • 看到中位数、动态数据流,想双堆
  • 看到优先级、最值,想堆

8.4 通用模板

Java版本

// 求前K大:用小根堆
public int[] topK(int[] nums, int k) {
    PriorityQueue<Integer> heap = new PriorityQueue<>();
    
    for (int num : nums) {
        heap.offer(num);
        if (heap.size() > k) {
            heap.poll();
        }
    }
    
    int[] result = new int[k];
    for (int i = 0; i < k; i++) {
        result[i] = heap.poll();
    }
    return result;
}

C++版本

vector<int> topK(vector<int>& nums, int k) {
    priority_queue<int, vector<int>, greater<int>> heap;
    
    for (int num : nums) {
        heap.push(num);
        if (heap.size() > k) {
            heap.pop();
        }
    }
    
    vector<int> result;
    while (!heap.empty()) {
        result.push_back(heap.top());
        heap.pop();
    }
    return result;
}

Python版本

def topK(nums, k):
    import heapq
    
    heap = []
    
    for num in nums:
        heapq.heappush(heap, num)
        if len(heap) > k:
            heapq.heappop(heap)
    
    return list(heap)

参考资料

  • 《算法第四版》- Robert Sedgewick
  • 《算法导论》- Thomas H. Cormen
  • LeetCode题解 - 堆专题