题目描述:
给定整数数组 nums 和整数 k,请返回数组中第 k 个最大的元素。
请注意,你需要找的是数组排序后的第 k 个最大的元素,而不是第 k 个不同的元素。
你必须设计并实现时间复杂度为 O(n) 的算法解决此问题。
示例 1:
输入: [3,2,1,5,6,4], k = 2
输出: 5
示例 2:
输入: [3,2,3,1,2,4,5,5,6], k = 4
输出: 4
提示:
1 <= k <= nums.length <= 10^5-10^4 <= nums[i] <= 10^4
思路(吐槽):
这道题我知道是个可能得用堆的题目。坦白讲,除了用JDK里自带的,确实没有想法。甚至一开始连自带的那个PriorityQueue叫什么都不知道,更别说分得清它是最小堆了。所以啊,即便允许用JDK自带的类都写不出来呢!
暴力解法我是会的,Arrays.sort()然后取第k个值嘛,对不对。
时间复杂度等于O(n),了解之后才知道有个叫快速排序选择的算法。可以理论上平均是O(n)。
还有人提到堆快排,时间复杂度是O(nlogk),也蛮久的其实。
这道题我会提供以下几种解法:
- 暴力解法,使用库排序
- 使用最小堆,库类
- 手写最小堆
- 快速排序选择
接下来开始实现:
1. 暴力解法
排序后,找到第k大的值。注意这里排序是升序排序,所以从后往前找。
class Solution {
public int findKthLargest(int[] nums, int k) {
Arrays.sort(nums);
return nums[nums.length-k];
}
}
2. 库类PriorityQueue最小堆
Java中的PriorityQueue默认就是最小堆,可以直接拿来用。维护一个大小为k的最小堆,遍历数组结束后,取堆顶的就是第k大的值了。
说了这么多,问你最小堆是是啥,能回答上来吗?回答不上来看这个。最小堆就是顶部的值是堆里最小值的那种堆。
class Solution {
public int findKthLargest(int[] nums, int k) {
PriorityQueue<Integer> heap = new PriorityQueue<Integer>(); // also can add param: (a, b) -> a - b, still 最小堆 (min - heap)
for (int i : nums) {
if (heap.isEmpty() || heap.size() < k) {
heap.offer(i);
} else if (!heap.isEmpty() && i >= heap.peek()) {
heap.offer(i);
if (heap.size() > k) {
heap.poll();
}
}
}
return heap.peek();
}
}
上面是我写的,也AC了,但是DS老师写的更好,我实现一下然后贴在下面:
class Solution {
public int findKthLargest(int[] nums, int k) {
PriorityQueue<Integer> heap = new PriorityQueue<Integer>(); // also can add param: (a, b) -> a - b, still 最小堆 (min - heap)
for (int i : nums) {
if (heap.size() < k) {
heap.offer(i);
} else if (i > heap.peek()) {
heap.poll();
heap.offer(i);
}
}
return heap.peek();
}
}
插播一条,我记得当年在卧龙园找工作的时候,面试有人问过我这个问题:有4GB的数据,想要找到其中第k大的值,应该怎么找?可惜啊,5年后的我才知道答案。这种情况建议用最小堆,而不是快速排序选择。原因是数据太多了,有内存限制的话,按照流式数据来处理,最小堆合适;静态类型的数据,快速排序选择可以。
3. 手写一个最小堆
class Solution {
public int findKthLargest(int[] nums, int k) {
MinHeap m = new MinHeap(k);
for (int i: nums) {
if (m.size() < k) {
m.insert(i);
} else if (m.top() < i ) {
m.removeHead();
m.insert(i);
}
}
return m.top();
}
}
class MinHeap {
int size;
int[] nums;
MinHeap(int k) {
nums = new int[k];
}
int top() {
return nums[0];
}
int size() {
return this.size;
}
void insert(int num) {
nums[size] = num;
swim(size);
size++;
}
void removeHead() {
nums[0] = nums[size-1];
size--;
sink(0);
}
void sink(int index) {
int i = index;
int left = i*2 + 1;
while (left < size) {
int smaller = left;
if (left + 1 < size && nums[left] >= nums[left+1]) {
smaller = left + 1;
}
if (nums[i] > nums[smaller]) {
swap(i, smaller);
i = smaller;
left = smaller * 2 + 1;
} else {
break;
}
}
}
void swim(int index) {
while(index > 0 && (index-1)/2 >= 0 && nums[index] < nums[(index-1)/2]) {
swap(index, (index-1)/2);
index = (index-1)/2;
}
}
void swap(int i, int j) {
int temp = nums[i];
nums[i] = nums[j];
nums[j] = temp;
}
}
手写一个最小堆,需要知道以下几点:
- 四个核心方法:
- sink(int index)
- swim(int index)
- insert(int num)
- removeHead()
- 其中sink方法是要比较左右子节点的较小值,此处容易犯错
- swim方法则是寻找父节点,相对容易
手动建堆核心操作:
- 插入元素:插入到堆末尾,然后通过
swim(上浮)调整位置 - 移除堆顶:将末尾元素移到堆顶,然后通过
sink(下沉)调整位置 - 堆性质维护:父节点必须小于子节点(最小堆)
时间复杂度同方法2,都是O(nlogk)。
4. 快速选择排序
快速选择排序的思想是首先在目标区间里确定一个pivot值,然后使用三个指针,分为维护三个区间,分别是大于pivot的,等于pivot的,以及小于pivot的。随后根据k的大小与区间的长度的比较,确定k落在哪里,随后让区间变小,继续执行快速选择排序,直到找到需要的值。
实现如下:
class Solution {
public int findKthLargest(int[] nums, int k) {
return quickSortSelect(nums, 0, nums.length-1, k);
}
private int quickSortSelect(int[] nums, int left, int right, int k) {
if (left == right) {
return nums[left];
}
Random r = new Random();
int p = r.nextInt(right - left + 1) + left;
int pivot = nums[p];
int i = left;
int lt = left;
int gt = right;
// [left, x, x, x, x, right]
while (i <= gt) {
if (nums[i] > pivot) { // another mistake here nums[p]
swap(nums, lt, i);
lt++;
i++;
} else if (nums[i] < pivot) {
swap(nums, gt, i);
gt--;
} else {
i++;
}
}
int largeLen = lt - left;
int mid = gt - lt + 1;
if (k <= largeLen) {
return quickSortSelect(nums, left, lt -1, k);
} else if (k <= largeLen + mid) {
return pivot;
}
return quickSortSelect(nums, gt+1, right, k-(largeLen + mid)); // made mistake here k-gt-1
}
private void swap(int[] nums, int i, int j) {
int temp = nums[i];
nums[i] = nums[j];
nums[j] = temp;
}
}
思想还是很简单的,但是实现起来很容易犯错误,必须要注意以下几点:
- 区间是是闭区间
- pivot的值要确定下来,而不是用nums[p]来维护
- k落在[大,等,小]中的小区间的时候,注意k的计算应该是减去前面两个区间的长度
纸上得来终觉浅,绝知此事要躬行。