简单说说堆排序

175 阅读3分钟

堆满足下列性质:

  1. 堆中某个节点的值总是不大于或不小于其父节点的值
  2. 堆总是一颗完全二叉树。

由以上性质可得,堆可以只用数组而不需要指针就可以表示,具体方法是将二叉树的节点按照层级顺序放入数组中,根节点在位置1,它的子节点在位置2、3,而子节点的子节点分别在位置4、5、6、7,依此类推。

堆排序主要分成两个阶段:

  1. 构造堆

    堆的构造是将子堆从小子堆到大子堆依次进行有序化的过程。

    堆有序化有两种方法:

    • 由下至上的堆有序化

        private void swim(Comparable[] input, int targetIndex, int startIndex) {
            while (targetIndex / 2 >= startIndex && less(input, targetIndex / 2, targetIndex)) {
                exch(input, targetIndex / 2, targetIndex);
                targetIndex = targetIndex / 2;
            }
        }
      
    • 由上至下的堆有序化

          private void sink(Comparable[] input, int targetIndex, int endIndex) {
           while (2 * targetIndex <= endIndex) {
               int leftSon = 2 * targetIndex;
               int rightSon = leftSon + 1;
               int son = leftSon;
               //找出左右子节点中大的那个用来比较并且右子节点不能多于endIndex
               if (less(input, leftSon, rightSon) && rightSon <= endIndex) {
                   son = rightSon;
               }
               //如果当前节点大于大的子节点,则当前位置正确。
               if (!less(input, targetIndex, son)) {
                   break;
               }
               exch(input, targetIndex, son);
               targetIndex = son;
      
           }
       }
      

    因此,堆的构造也有两种方式:

    • 由顶部到底部的构造堆(使用由下到上的堆有序化)
      for (int i = 1; i <= input.length && (2 * i <= input.length || 2 * i + 1 <= input.length); i++) {
                if (2 * i <= input.length) {
                    swim(input, 2 * i, 1);
                }
                if (2 * i + 1 <= input.length) {
                    swim(input, 2 * i + 1, 1);
                }
            }
      
    • 由底部到顶部的构造堆(使用由上往下的堆有序化)
      for (int i = input.length / 2; i >= 1; i--) {
                sink(input, i, input.length);
            }
      

    容易得出,由底部到顶部的构造堆更加高效。

  2. 排序

    根据堆的特性可知,顶部(位置 1)为最大值(最小值),我们将位置 1 与最后一位进行交换,然后将位置 1 到倒数第二位的堆进行位置 1 的堆有序化,依此循环,即可得到有序的数组。

     while (endIndex > 1) {
         exch(input, 1, endIndex);
         endIndex--;
         sink(input, 1, endIndex);
     }
    

完整代码:

    public static class HeapSort {
        private boolean less(Comparable[] input, int startIndex, int endIndex) {
            return input[startIndex - 1].compareTo(input[endIndex - 1]) < 0;
        }

        private void exch(Comparable[] input, int startIndex, int endIndex) {
            Comparable temp = input[startIndex - 1];
            input[startIndex - 1] = input[endIndex - 1];
            input[endIndex - 1] = temp;
        }

        private void swim(Comparable[] input, int targetIndex, int startIndex) {
            while (targetIndex / 2 >= startIndex && less(input, targetIndex / 2, targetIndex)) {
                exch(input, targetIndex / 2, targetIndex);
                targetIndex = targetIndex / 2;
            }
        }

        private void sink(Comparable[] input, int targetIndex, int endIndex) {
            while (2 * targetIndex <= endIndex) {
                int leftSon = 2 * targetIndex;
                int rightSon = leftSon + 1;
                int son = leftSon;
                //找出左右子节点中大的那个用来比较并且右子节点不能多于endIndex
                if (less(input, leftSon, rightSon) && rightSon <= endIndex) {
                    son = rightSon;
                }
                //如果当前节点大于大的子节点,则当前位置正确。
                if (!less(input, targetIndex, son)) {
                    break;
                }
                exch(input, targetIndex, son);
                targetIndex = son;

            }
        }

        public void sort(Comparable[] input) {
            //构造堆
            for (int i = input.length / 2; i >= 1; i--) {
                sink(input, i, input.length);
            }
            int endIndex = input.length;
            //排序
            while (endIndex > 1) {
                exch(input, 1, endIndex);
                endIndex--;
                sink(input, 1, endIndex);
            }
        }

        public void toHeap(Comparable[] input) {
            for (int i = 1; i <= input.length && (2 * i <= input.length || 2 * i + 1 <= input.length); i++) {
                if (2 * i <= input.length) {
                    swim(input, 2 * i, 1);
                }
                if (2 * i + 1 <= input.length) {
                    swim(input, 2 * i + 1, 1);
                }
            }
        }
    }