需求:从N个整数中,找出最大的前K个数(K远小于N)
很明显我们可以用二叉堆来解决,对于二叉堆不熟悉的可以参考我之前的文章:数据结构与算法-二叉堆
解决思路:
- 新建一个小顶堆
- 扫描N个整数 -- 把前K个数放入堆中 -- 从第K+1个元素开始,如果大于堆顶元素,则删除堆顶元素,再插入新的元素
- 扫描完N个整数之后,堆中的元素就是最大的前K个数
/**
* 小顶堆
*
* @author jun.zhang6
* @date 2020/9/21
*/
public class SmallBinaryHeap {
public static void main(String[] args) {
int[] arrays = {63, 57, 5, 22, 61, 74, 23, 58, 50, 33, 138, 52, 72, 13, 233, 86, 80, 43, 96, 90, 66, 26, 189, 76, 32};
int k = 3;
SmallBinaryHeap heap = new SmallBinaryHeap();
for (int i = 0; i < arrays.length; i++) {
if (i < k) {
heap.add(arrays[i]);
} else {
if (heap.get() < arrays[i]) {
heap.replace(arrays[i]);
}
}
}
}
private int[] elements;
private int size;
public SmallBinaryHeap() {
this.elements = new int[10];
}
public void add(int element) {
ensureCapacity(size + 1);
elements[size++] = element;
siftUp(size - 1);
}
public void replace(int element) {
if (size == 0) {
elements[0] = element;
size++;
} else {
elements[0] = element;
siftDown(0);
}
}
public int get() {
return elements[0];
}
private void siftDown(int index) {
int value = elements[index];
int halfIndex = size / 2;
while (index < halfIndex) {
int leftIndex = 2 * index + 1;
int leftValue = elements[leftIndex];
int rightIndex = leftIndex + 1;
int rightValue = elements[rightIndex];
if (rightIndex < size) {
if (rightValue < leftValue) {
leftValue = rightValue;
leftIndex = rightIndex;
}
}
if (leftValue > value) {
break;
}
elements[index] = leftValue;
index = leftIndex;
}
elements[index] = value;
}
private void siftUp(int index) {
int value = elements[index];
while (index > 0) {
int parentIndex = (index - 1) / 2;
int parentValue = elements[parentIndex];
if (parentValue <= value) {
break;
}
elements[index] = parentValue;
index = parentIndex;
}
elements[index] = value;
}
private void ensureCapacity(int expectCapacity) {
int oldCapacity = elements.length;
if (expectCapacity > oldCapacity) {
int[] newElements = new int[oldCapacity << 1];
for (int i = 0; i < oldCapacity; i++) {
newElements[i] = elements[i];
}
elements = newElements;
}
}
}
总结:使用小顶堆来解决Top K问题,时间复杂度为:O(logK)