算法之快速排序

286 阅读7分钟

基本思路

比如一个数组为

int [] arr = new int []{4,6,2,3,1,5,7,8}

先从该数组中选一个数,将它放到它原本的位置,什么意思呢?就是比如我们选择 4,我们必须保证 4 之前的所有数都比 4 小,4 之后的所有数都比 4 大,这时 4 就是在它排序后的正确位置。
就像这样:

2,3,1,4,6,5,7,8

然后 4 之前数和 4 之后的数再使用同样的方法进行排序。逐渐递归。

如何将上述的 4 移到正确的位置就是快速排序的核心。 这个过程叫做 partition(分割)

通常使用数组的第一个数作为分割数组的标志点。然后通过遍历数组将数组分割为 2 部分。如图:

用 l 记录第一个数的索引,j 记录 < v 的最后一个数的索引,用 i 记录当前数的索引。 这样就满足下面的关系:

[ arr[l+1], arr[j] ] < v

[ arr[j+1], arr[i-1]] > v

对于当前位置 a[i] = e; 它有 2 种情况:

  1. e > v
    直接并入到 >v 区间的后面,,这时 i 后移一个位置,需要把 i 加 1 :i++

  2. e < v
    只需要将 e 和 > v 的第一个位置交换位置,这时 j 多了一个,需要把 j 加 1 :j++

最后将 l 位置的数和 j 位置的数调换一下,就完成了排序。

基础版本代码实现

    private static void sort(Comparable[] arr, int l, int r) {

        if (l >= r) return;

        // 返回一个中间数的索引
        int p = partition(arr, l, r);
        sort(arr, l, p - 1);
        sort(arr, p + 1, r);

    }
    // 对 arr[l...r] 部分进行分割操作
    // 返回一个索引 p,使得 arr[l...p-1] < arr[p] < arr[p+1...r]
    @SuppressWarnings("unchecked")
    private static int partition(Comparable[] arr, int l, int r) {
        Comparable v = arr[l];
        int j = l;

        // 通过循环使得 arr[l...p-1] < arr[p] < arr[p+1...i]
        for (int i = l + 1; i < r + 1; i++) {
            // arr[i] < v 时,挪到 v 左边
            if (arr[i].compareTo(v) < 0) {
                // 交换 i 和 j+1 的位置,然后将 j+1
                swap(arr, ++j, i);
            }
        }
        // 将 l 位置和 j 位置换一下顺序
        swap(arr, l, j);
        return j;
    }

随机化优化

当数据量较小的时候,使用插入排序来优化。

对于近乎有序的数组来说,基础版本的快排比归并排序慢很多,这是为什么?

归并排序每次都将一个数组平均一分为二。

快速排序虽然也是一分为二,但是不是平均分。分出来的数组一大一小。继续分下去也是这样。

我们不能保证快排的这颗二叉树的高度是 logn, 但是归并排序却可以保证。

快排的最差情况是当数组为完全的有序数组时。

右边的数组都比 v 大,所以每次只能分一个数组,这样这颗树就会有 n 层。每层的操作又会有 O(n) 的复杂度,就会产生 0(n^2) 的复杂度。

问题的原因就是我们现在选择的是最左边的元素作为中间分割元素,我们不知道它在数组中处于什么位置,导致分割的数组长度不确定,甚至只能分割成一个数组。

我们现在只要去随机选择这个 v。这种情况下,O(n^2) 几乎不存在。

    @SuppressWarnings("unchecked")
    private static void sort(Comparable[] arr, int l, int r) {

//        if (l >= r) return;
        // 数据量小时,使用快排
        if (r - l < 15) {
            InsertionSortAdvanced.sort(arr);
            return;
        }
        // 返回一个中间数的索引
        int p = partition(arr, l, r);
        sort(arr, l, p - 1);
        sort(arr, p + 1, r);

    }
    // 对 arr[l...r] 部分进行分割操作
    // 返回一个索引 p,使得 arr[l...p-1] < arr[p] < arr[p+1...r]
    @SuppressWarnings("unchecked")
    private static int partition(Comparable[] arr, int l, int r) {
        int vIndex = new Random().nextInt(r + 1 - l) + l;
        swap(arr, vIndex, l);
        Comparable v = arr[l];
        int j = l;

        // 通过循环使得 arr[l...p-1] < arr[p] < arr[p+1...i]
        for (int i = l + 1; i < r + 1; i++) {
            // arr[i] < v 时,挪到 v 左边
            if (arr[i].compareTo(v) < 0) {
                // 交换 i 和 j+1 的位置,然后将 j+1
                swap(arr, ++j, i);
            }
        }
        // 将 l 位置和 j 位置换一下顺序
        swap(arr, l, j);
        return j;
    }

双路快速排序法

如果一个百万级别的数组中含有大量的重复元素。上面的快排会很慢。为什么呢?

大量的重复元素导致分割完的 2 个数组,非常不平衡。要么是 <=v 的部分特别大,要么是 >=v 的部分特别大。 这时我们的排序算法会退化到 O(n^2) 的复杂度。

换一个新的思路,采用新的 partition 方案:

<=v>=v 放在数组的两端。两端都带 = 是因为这样左右同时判断 = 的情况就不会出现相等的数集中一端,而是近似平均的分布在 2 端。

从 i 这个位置开始向右扫描,当前元素是 <=v 的时候就继续右移扫描。直到碰到 >v 的数时就停止右移。

j 位置也一样,从 j 向左移动扫描,当前元素是 >=v 时就继续左移,知道碰到 <v 的时候就停止。

当 i 和 j 都停止的时候,i 位置的数满足 arr[i] > v,j 位置的数满足 arr[j] < v,此时将 i 和 j 位置的数交换位置。使得 v 右边的数满足 arr[0...j] < varr[i...r] < v。 这样数组就排序完了。

修改后的代码:

    /**
     * 双路同时进行分割
     */
    @SuppressWarnings("unchecked")
    private static int[] partition2(Comparable[] arr, int l, int r) {
        int vIndex = new Random().nextInt(r + 1 - l) + l;
        swap(arr, vIndex, l);
        Comparable v = arr[l];
        // j 和 k 初始值是数组的左右两端
        int j = l + 1;
        int k = r;

        // <=v 部分的循环
        for (int i = l + 1; i < k; i++) {
            // arr[i] <= v 时,向右移动,> v 时停止
            if (arr[i].compareTo(v) > 0) break;
            else j++;
        }

        // >=v 部分的循环
        for (int i = r; i > j; i--) {
            if (arr[i].compareTo(v) < 0) break;
            else k--;
        }

        // 将 k 位置和 j 位置换一下顺序
        swap(arr, k, j);
        return new int[]{j, k};
    }
  @SuppressWarnings("unchecked")
    private static void sort(Comparable[] arr, int l, int r) {

//        if (l >= r) return;
        // 数据量小时,使用快排
        if (r - l < 15) {
            InsertionSortAdvanced.sort(arr);
            return;
        }
        // 返回一对数的索引,0 位置是 <=v 的最右端,1 位置是 >=v 的最左端
        int[] p = partition2(arr, l, r);
        sort(arr, l, p[0]);
        sort(arr, p[1], r);
    }

三路快速排序(Quick Sort 3 Ways)

三路排序和之前的思想类似,把数组分为了 < v== v> v 三部分。

  • l : 中间数 v 的位置,这里指定的是最左边的位置
  • lt : < v 的最后一个位置
  • i : 待选数的位置
  • gt : > v 的第一个位置
  • r : 最右边的位置
  1. e == v

e 直接并入 == v 的部分,然后 i++ 右移

  1. e < v

将 e 和 == v 区间的第一个数交换位置, lt++, i++

  1. e > v

将 e 和 gt-1 位置的数互换位置,gt--,i++

最终会变成这样:

gt 和 i 重合的时候就是结束的时候。

此时将 v< v 的最后一个数互换位置,就完成了一次 partition。

代码如下:

    @SuppressWarnings("unchecked")
    private static void sort(Comparable[] arr, int l, int r) {
        // 数据量小时,使用快排
        if (r - l < 15) {
            InsertionSortAdvanced.sort(arr);
            return;
        }
        int[] p = partition3(arr, l, r);
        int lt = p[0];
        int gt = p[1];
        sort(arr, l, lt-1);
        sort(arr, gt, r);
    }
    @SuppressWarnings("unchecked")
    private static int[] partition3(Comparable[] arr, int l, int r) {
        int vIndex = new Random().nextInt(r + 1 - l) + l;
        swap(arr, vIndex, l);
        Comparable v = arr[l];

        int lt = l; // arr[l+1...lt] < v
        int gt = r + 1; // arr[gt...r] > v
        int i = l + 1; // arr[lt+1...i) == v

        while (i < gt) {
            // arr[i] < v,将 e 和 == v 区间的第一个数(arr[lt+1])交换位置, lt++, i++
            if (arr[i].compareTo(v) < 0) {
                swap(arr, i, lt + 1);
                i++;
                lt++;
            }

            // arr[i] > v,将 e 和 gt-1 位置的数互换位置
            else if (arr[i].compareTo(v) > 0) {
                swap(arr, i, gt - 1);
                gt--;
                // 这里 i 不需要 +1,因为交换后的 arr[gt-1] 这个数也需要被判断

            }

            // arr[i] == v
            else {
                i++;
            }
        }

        swap(arr, l, lt);
        return new int[]{lt, gt};
    }

参考

算法与数据结构