二叉堆

184 阅读9分钟

0. 开篇

设计一种数据结构,用来存放整数,要求提供 3 个接口:
√ 添加元素;
√ 获取最大值;
√ 删除最大值。
数组?数组排序获取最大值?如下图1:

image.png

当然还有我们学过的链表、平衡二叉搜索树:

image.png

还有没有更好的数据结构呢?

1. 二叉堆

1.1 堆的分类

堆(Heap)也是一种树状的数据结构(不要跟内存模型中的“堆空间”混淆),常见的堆实现有:
 二叉堆(Binary Heap,完全二叉堆)
 多叉堆(D-heap、D-ary Heap)
 索引堆(Index Heap)
 二项堆(Binomial Heap)
 斐波那契堆(Fibonacci Heap)
......

堆的一个重要性质:任意节点的值总是 ≥( ≤ )子节点的值:

  • 如果任意节点的值总是 ≥ 子节点的值,称为:最大堆、大根堆、大顶堆。
  • 如果任意节点的值总是 ≤ 子节点的值,称为:最小堆、小根堆、小顶堆。

由此可见,堆中的元素必须具备可比较性(跟二叉搜索树一样)。

我们这里只探讨 二叉堆
二叉堆的逻辑结构就是一棵完全二叉树,所以也叫完全二叉堆。它分为两种:最大堆和最小堆。

image.png

image.png

鉴于完全二叉树的一些特性,二叉堆的底层(物理结构)一般用数组实现即可。

图5,用数组实现二叉堆。 image.png

由于二叉堆的逻辑结构是一个完全二叉树,所以二叉堆也有以下性质:
◼ 索引 i 的规律( n 是元素数量)
1) 如果 i = 0 ,它是根节点
2) 如果 i > 0 ,它的父节点的索引为 floor( (i – 1) / 2 )
3) 如果 2i + 1 ≤ n – 1,它的左子节点的索引为 2i + 1
4) 如果 2i + 1 > n – 1 ,它无左子节点
5) 如果 2i + 2 ≤ n – 1 ,它的右子节点的索引为 2i + 2
6) 如果 2i + 2 > n – 1 ,它无右子节点

1.2 接口设计

二叉堆的基本接口设计如下:

 {
    int size();	// 元素的数量
    boolean isEmpty();	// 是否为空
    void clear();	// 清空
    void add(E element);	 // **添加元素**
    E get();	// 获得堆顶元素
    E remove(); // 删除堆顶元素
    E replace(E element); // 删除堆顶元素的同时插入一个新元素
}

"最大堆"和"最小堆"是对称关系。接下来分析"最大堆"。

1.3 添加元素

最大堆的任意节点值总是大于等于子节点的值,我们在添加元素的时候,将新元素添加在数组末尾,然后将新元素循环与其父节点的值做比较,如果大于父节点则与父节点交换位置,最终可恢复最大堆。

添加元素的流程图如下图6所示:

image.png

添加元素的总结

  1. 将新添加元素添加到数组末尾;

  2. 循环执行以下操作(新添加元素 简称为 node):
    1)如果 node > 父节点,则与父节点交换位置;
    2)如果 node ≤ 父节点,或者 node 没有父节点,退出循环。

    这个过程,叫做上滤(Sift Up)

添加元素的时间复杂度:O(logn)。

添加元素的代码

/**
 * 添加元素
 */
public void add(E element) {
    // 添加元素非空检查
    elementNotNullCheck(element);

    // 数组容量检查
    ensureCapacity(size + 1);

    // 元素添加到数组末尾 且 数组元素大小+1
    elements[size++] = element;

    // 调用上滤
    siftUp(size - 1);
}

/**
 * 让 index 位置的元素上滤
 * @param index
 */
private void siftUp(int index) {
    // 执行上滤的元素
    E e = elements[index];
    
    // 循环执行:1)上滤的元素 > 父节点值,交换位置;2)上滤的元素 <= 父节点值 或者无父节点,退出循环
    while (index > 0) {
        // 完全二叉树:某元素的父节点的索引:floor( (i – 1) / 2 )
        int pindex = (index - 1) >> 1;
        E p = elements[pindex];
        if (compare(e, p) <= 0) return;
        
        // 交换index、pindex位置的内容
        E tmp = elements[index];
        elements[index] = elements[pindex];
        elements[pindex] = tmp;
        
        // 重新赋值index
        index = pindex;
    }
}

上面代码中上滤代码可以稍稍优化:
每次同父节点比较值的时候,执行了交换位置及内容。其实可以将新添加节点备份,每次比较只需记录节点变化的位置,等确定最终位置才摆放上去。

优化后的上滤代码

/**
 * 让 index 位置的元素上滤
 * @param index
 */
private void siftUp(int index) {
    // 执行上滤的元素
    E element = elements[index];

    // 循环执行:1)上滤的元素 > 父节点值,父节点赋值到该上滤元素的位置,同时记录下滤的位置(即需要交换的父节点的位置);
    //          2)上滤的元素 <= 父节点值 或者无父节点,退出循环;
    //          3)退出循环后,根据最终记录的位置,赋值
    while (index > 0) {
        // 完全二叉树:某元素的父节点的索引:floor( (i – 1) / 2 )
        int parentIndex = (index - 1) >> 1;
        E parent = elements[parentIndex];
        if (compare(element, parent) <= 0) break;

        // 将父元素存储在index位置
        elements[index] = parent;

        // 重新赋值index:
        index = parentIndex;
    }
    // 都比较完成后,上滤元素找到最终位置,赋值
    elements[index] = element;
}

对比优化前后的代码,仅从交换位置的代码角度看:可以由大概的 3 * O(logn) 优化到 1 * O(logn) + 1。

1.4 删除堆顶元素

删除堆顶元素,将数组最后一个节点覆盖根节点,然后根节点循环同较大子节点比较,如果小于较大子节点,则互换位子,直至大于等于较大子节点,可最终恢复最大堆。

删除堆顶元素的流程如下图7所示:

image.png

删除堆顶元素的总结

  1. 用最后一个节点覆盖根节点;
  2. 删除最后一个节点;
  3. 循环执行以下操作(图7中的 33 简称为 node):
    1)如果 node < 最大的子节点,与最大的子节点交换位置。
    2)如果 node ≥ 最大的子节点, 或者 node 没有子节点,退出循环。

这个过程,叫做下滤(Sift Down),时间复杂度:O(logn)

同样的,交换位置的操作可以像添加那样进行优化。

删除堆顶元素代码

/**
 * 删除堆顶元素
 */
public E remove() {
    // 二叉堆空检查
    emptyCheck();

    // 最后一个元素位置(size - 1) 同时元素大小 -1
    int lastIndex = --size;

    // 删除的元素
    E root = elements[0];
    // 最后一个元素覆盖根节点
    elements[0] = elements[lastIndex];
    elements[lastIndex] = null;

    // 调用下滤方法
    siftDown(0);
    return root;
}

/**
 * 让 index 位置的元素下滤
 * @param index
 */
private void siftDown(int index) {
    // 执行下滤的元素
    E element = elements[index];

    // 下滤的前提是节点用子节点,二叉树的性质可以知道,度为1和2的节点总数等于总结点数的一半,即size / 2
    int half = size >> 1;

    // 第一个叶子节点的索引 == 非叶子节点的数量
    // index < 第一个叶子节点的索引
    // 必须保证index位置是非叶子节点
    while (index < half) { 
        // index的节点有2种情况
        // 1.只有左子节点
        // 2.同时有左右子节点

        // 默认为左子节点跟它进行比较
        int childIndex = (index << 1) + 1;
        E child = elements[childIndex];

        // 右子节点
        int rightIndex = childIndex + 1;

        // 选出左右子节点最大的那个
        if (rightIndex < size && compare(elements[rightIndex], child) > 0) {
            childIndex = rightIndex;
            child = elements[rightIndex];
        }

        // 大于等于较大子节点值,跳出循环
        if (compare(element, child) >= 0) break;

        // 将子节点存放到index位置
        elements[index] = child;
        // 重新设置index(每次都记录元素下滤的位置)
        index = childIndex;
    }
    // 都比较完成后,下滤元素找到最终位置,赋值
    elements[index] = element;
}

1.5 删除堆顶元素的同时插入一个新元素

删除堆顶元素的同时插入一个新元素的,用新元素替换堆顶元素,然后对堆顶元素执行下滤操作即可。
代码如下:

/**
 * 删除堆顶元素的同时插入一个新元素
 */
public E replace(E element) {
    // 新元素非空检查
    elementNotNullCheck(element);

    E root = null;

    // 空堆,直接赋值
    if (size == 0) {
        elements[0] = element;
        size++;

    // 新元素替换堆顶元素,然后下滤
    } else {
        root = elements[0];
        elements[0] = element;
        siftDown(0);
    }
    return root;
}

2. 最大堆的批量建堆

给定一个任意数组,创建一个对应的最大堆,创建方式可以分为 “自上而下的上滤 批量建堆 最大堆” 和 “自下而上的下滤 批量建堆 最大堆”。

2.1 自上而下的上滤 批量建堆 最大堆

自上而下的上滤:循环给定数组,并从给定数组的第2个元素(索引为1)开始的每一个元素执行上滤,循环结束后,最大堆创建完成。

详细的过程如下图8所示:

image.png

代码如下:

/**
* 批量建堆
*/
private void heapify() {
    // 自上而下的上滤
    for (int i = 1; i < size; i++) {
            siftUp(i);
    }
}

2.2 自下而上的下滤 批量建堆 最大堆

自上而下的上滤:循环给定数组,并从给定数组索引为 size / 2 - 1 的位置开始(向前遍历)的每一个元素执行下滤,循环结束后,最大堆创建完成。

size 为数组元素的个数。完全二叉树的叶子节点和非叶子节点的个数相等为总结点(即为数组的个数size)的数的 1/2,而叶子节点无需开路下滤,故从 size / 2 - 1 开始,向前遍历处理。

详细的过程如下图9所示:

image.png

代码如下:

/**
* 批量建堆
*/
private void heapify() {
    // 自下而上的下滤
    for (int i = (size >> 1) - 1; i >= 0; i--) {
            siftDown(i);
    }
}

自上而下的上滤的算法复杂度为:O(nlogn) 。
自下而上的下滤的算法复杂度为:O(n)。

因此,批量建堆时采用 自下而上的下滤 方式。

3. 完整代码

/**
* 二叉堆(最大堆)
*
* @param <E>
*/
@SuppressWarnings("unchecked")
public class BinaryHeap<E> {
    private E[] elements;
    private static final int DEFAULT_CAPACITY = 10;

    public BinaryHeap(E[] elements, Comparator<E> comparator)  {
        super(comparator);

        if (elements == null || elements.length == 0) {
            this.elements = (E[]) new Object[DEFAULT_CAPACITY];
        } else {
            size = elements.length;
            int capacity = Math.max(elements.length, DEFAULT_CAPACITY);
            this.elements = (E[]) new Object[capacity];
            for (int i = 0; i < elements.length; i++) {
                this.elements[i] = elements[i];
            }
            // 调用批量建堆
            heapify();
        }
    }

    public BinaryHeap(E[] elements)  {
        this(elements, null);
    }

    public BinaryHeap(Comparator<E> comparator) {
        this(null, comparator);
    }

    public BinaryHeap() {
        this(null, null);
    }
    
    /**
     * 清空
     */
    public void clear() {
        for (int i = 0; i < size; i++) {
            elements[i] = null;
        }
        size = 0;
    }
    
    /**
     * 添加元素
     */
    public void add(E element) {
        // 添加元素非空检查
        elementNotNullCheck(element);

        // 数组容量检查
        ensureCapacity(size + 1);

        // 元素添加到数组末尾 且 数组元素大小+1
        elements[size++] = element;

        // 调用上滤
        siftUp(size - 1);
    }

    /**
     * 获得堆顶元素
     */
    public E get() {
        emptyCheck();
        return elements[0];
    }
    
    /**
     * 删除堆顶元素
     */
    public E remove() {
        // 二叉堆空检查
        emptyCheck();

        // 最后一个元素位置(size - 1) 同时元素大小 -1
        int lastIndex = --size;

        // 删除的元素
        E root = elements[0];
        // 最后一个元素覆盖根节点
        elements[0] = elements[lastIndex];
        elements[lastIndex] = null;

        // 调用下滤方法
        siftDown(0);
        return root;
    }
    
    /**
     * 删除堆顶元素的同时插入一个新元素
     */
    public E replace(E element) {
        // 新元素非空检查
        elementNotNullCheck(element);

        E root = null;

        // 空堆,直接赋值
        if (size == 0) {
            elements[0] = element;
            size++;

        // 新元素替换堆顶元素,然后下滤
        } else {
            root = elements[0];
            elements[0] = element;
            siftDown(0);
        }
        return root;
    }

    /**
     * 批量建堆
     */
    private void heapify() {
        // 自下而上的下滤
        for (int i = (size >> 1) - 1; i >= 0; i--) {
            siftDown(i);
        }
    }

    /**
     * 让 index 位置的元素下滤
     * @param index
     */
    private void siftDown(int index) {
        // 执行下滤的元素
        E element = elements[index];

        // 下滤的前提是节点用子节点,二叉树的性质可以知道,度为1和2的节点总数等于总结点数的一半,即size / 2
        int half = size >> 1;

        // 第一个叶子节点的索引 == 非叶子节点的数量
        // index < 第一个叶子节点的索引
        // 必须保证index位置是非叶子节点
        while (index < half) { 
            // index的节点有2种情况
            // 1.只有左子节点
            // 2.同时有左右子节点

            // 默认为左子节点跟它进行比较
            int childIndex = (index << 1) + 1;
            E child = elements[childIndex];

            // 右子节点
            int rightIndex = childIndex + 1;

            // 选出左右子节点最大的那个
            if (rightIndex < size && compare(elements[rightIndex], child) > 0) {
                childIndex = rightIndex;
                child = elements[rightIndex];
            }

            // 大于等于较大子节点值,跳出循环
            if (compare(element, child) >= 0) break;

            // 将子节点存放到index位置
            elements[index] = child;
            // 重新设置index(每次都记录元素下滤的位置)
            index = childIndex;
        }
        // 都比较完成后,下滤元素找到最终位置,赋值
        elements[index] = element;
    }

    /**
     * 让 index 位置的元素上滤
     * @param index
     */
    private void siftUp(int index) {
        // 执行上滤的元素
        E element = elements[index];

        // 循环执行:1)上滤的元素 > 父节点值,父节点赋值到该上滤元素的位置,同时记录下滤的位置(即需要交换的父节点的位置);
        //          2)上滤的元素 <= 父节点值 或者无父节点,退出循环;
        //          3)退出循环后,根据最终记录的位置,赋值
        while (index > 0) {
            // 完全二叉树:某元素的父节点的索引:floor( (i – 1) / 2 )
            int parentIndex = (index - 1) >> 1;
            E parent = elements[parentIndex];
            if (compare(element, parent) <= 0) break;

            // 将父元素存储在index位置
            elements[index] = parent;

            // 重新赋值index:
            index = parentIndex;
        }
        // 都比较完成后,上滤元素找到最终位置,赋值
        elements[index] = element;
    }
    
    /**
     * 处理数组容量
     */
    private void ensureCapacity(int capacity) {
        int oldCapacity = elements.length;
        if (oldCapacity >= capacity) return;

        // 新容量为旧容量的1.5倍
        int newCapacity = oldCapacity + (oldCapacity >> 1);
        E[] newElements = (E[]) new Object[newCapacity];
        for (int i = 0; i < size; i++) {
            newElements[i] = elements[i];
        }
        elements = newElements;
    }
    
    /**
     * 二叉堆非空检查
     */
    private void emptyCheck() {
        if (size == 0) {
            throw new IndexOutOfBoundsException("Heap is empty");
        }
    }
    
    /**
     * 添加元素非空检查
     */
    private void elementNotNullCheck(E element) {
        if (element == null) {
            throw new IllegalArgumentException("element must not be null");
        }
    }
}

4. 新建最小堆

最小堆和最大堆刚好相反,因此只需在 compare 中调整返回值的计算逻辑,与默认最大堆的运算逻辑相反即可。

static void minHeap() {
    Integer[] data = {88, 44, 53, 41, 16, 6, 70, 18, 85, 98, 81, 23, 36, 43, 37};
    BinaryHeap<Integer> heap = new BinaryHeap<>(data, new Comparator<Integer>() {
        public int compare(Integer o1, Integer o2) {
            return o2 - o1;
        }
    });
    BinaryTrees.println(heap);
}

5. Top K问题

如果要从 n 个整数中,找出最大的前 k 个数( k 远远小于 n ),应该如何处理?

  • 使用排序算法进行全排序,需要 O(nlogn) 的时间复杂度。
  • 使用二叉堆来解决,可以使用 O(nlogk) 的时间复杂度来解决。
    二叉堆思路:
  1. 新建一个小顶堆;
  2. 扫描 n 个整数;
    2.1 先将遍历到的前 k 个数放入堆中;
    2.2 从第 k + 1 个数开始,如果大于堆顶元素,就使用 replace 操作(删除堆顶元素,将第 k + 1 个数添加到堆中)
  3. 扫描完毕后,堆中剩下的就是最大的前 k 个数。

代码如下:

public static void main(String[] args) {
    // 新建一个小顶堆
    BinaryHeap<Integer> minHeap = new BinaryHeap<>(new Comparator<Integer>() {
        public int compare(Integer o1, Integer o2) {
            return o2 - o1;
        }
    });

    // 找出最大的前k个数
    int k = 3;
    Integer[] data = {51, 30, 39, 92, 74, 25, 16, 93,
                    91, 19, 54, 47, 73, 62, 76, 63, 35, 18,
                    90, 6, 65, 49, 3, 26, 61, 21, 48};
    for (int i = 0; i < data.length; i++) {
        if (minHeap.size() < k) { // 前k个数添加到小顶堆
            minHeap.add(data[i]); // logk
        } else if (data[i] > minHeap.get()) { // 如果是第k + 1个数,并且大于堆顶元素
            minHeap.replace(data[i]); // logk
        }
    }
    // O(nlogk)
    BinaryTrees.println(minHeap);
}

如果要从 n 个整数中,找出最大的前 k 个数( k 远远小于 n ),则通过最大堆来实现。