快速排序的五种优化思路

209 阅读4分钟

(主要是作为个人笔记的一篇文章,几乎全是代码,讲解非常少,没兴趣的同学请直接略过)

说起快速排序,相信很多同学都不会很陌生,这个 20 世纪最伟大的算法之一,有着非常巧妙的设计思想和优越的性能,其中最主要的是它的切分思想(partition)也是面试题中常考的,比如荷兰国旗查找第 K大的数

我们今天的重点并不是讲解快速排序,而是去研究一下快速排序的几种优化思路,所以,在这里假定大家都已经了解基础版的快排是如何实现的了。但在看几种优化之前,我们还是给出基本的实现:

首先我们给出通用的模板:

package utils;

public class ArrayUtils {
    public static boolean less(double[] arr, int i, int j) {
        return arr[i] < arr[j];
    }

    public static void swap(double[] arr, int i, int j) {
        double temp = arr[i];
        arr[i] = arr[j];
        arr[j] = temp;
    }

}
public static void sort(double[] arr) {
    sort(arr, 0, arr.length - 1);
}

public static void sort(double[] arr, int lo, int hi) {
    if (lo >= hi) {
        return;
    }

    int j = partition(arr, lo, hi);
    sort(arr, lo, j - 1);
    sort(arr, j + 1, hi);
}

public static int partition(double[] a, int lo, int hi) {
    return null;
}

下面是我们最基础的 partition 的实现方法:

public static int partition(double[] a, int lo, int hi) {
    int i = lo, j = hi + 1;

    while (true) {
        while(less(a, ++i, lo)) {
            if (i == hi) {
                break;
            }
        }

        while (ArrayUtils.less(a, lo, --j)) {
            if (j == lo) {
                break;
            }
        }

        if (i >= j) {
            break;
        }

        ArrayUtils.swap(a, i, j);
    }

    ArrayUtils.swap(a, j, lo);

    return j;
}

移除掉 partition 中的比较操作

partition 函数中,有下面这段代码:

 while(less(a, ++i, lo)) {
    if (i == hi) {
        break;
    }
}

while (less(a, lo, --j)) {
    if (j == lo) {
        break;
    }
}

我们仔细的分析一下, 其实上一段代码中第二个 while 循环内部的逻辑是不需要的,当 j 减小到等于 lo 的时候,肯定就中止了,同时,如果我们想让第一个 while 循环也不走内部的逻辑,可以先把整个数组的最大值放到数组的最右边,那它也不需要比较内部了,因为肯定不会越界。

比起初级版,我们有两个函数需要改动一下:

public static void sort(double[] arr) {

    int maxIndex = 0;
    for (int i = 1; i < arr.length; i++) {
        if (arr[i] > arr[maxIndex]) {
            maxIndex = i;
        }
    }

    ArrayUtils.swap(arr, maxIndex, arr.length - 1);
    sort(arr, 0, arr.length - 1);
}
public static int partition(double[] a, int lo, int hi) {
    int i = lo, j = hi + 1;

    while (true) {
        while(ArrayUtils.less(a, ++i, lo));

        while(ArrayUtils.less(a, lo, --j));

        if (i >= j) {
            break;
        }

        ArrayUtils.swap(a, i, j);
    }

    ArrayUtils.swap(a, j, lo);

    return j;
}

使用插入排序

我们是使用递归的方式去实现的,在数组长度很小的时候使用递归是不划算的,我们想到的最简单的优化就是在子数组长度短的时候使用插入排序。

public static void sort(double[] arr, int lo, int hi) {
    if (hi - lo <= 16) {
        InsertionSort.sort(arr, lo, hi); 
        return;
    }

    int j = partition(arr, lo, hi);
    sort(arr, lo, j-1);
    sort(arr, j + 1, hi);
}
public class InsertionSort {
    public static void sort(double[] arr) {
        sort(arr, 0, arr.length - 1);
    }

    public static void sort(double[] arr, int lo, int hi) {
        for (int i = lo + 1; i <= hi; i++) {
            int j = i - 1;
            double temp = arr[i];
            while(j >= 0 && arr[j] > temp) {
                arr[j + 1] = arr[j];
                j--;
            }
            arr[j+1] = temp;
        }
    }
}

随机选择一个作为分割点

影响快排效率的一个重要原因是切分点选择不好,我们从子数组中选取三个点,在这三个点中再选一个中间值,就尽量避免了选出极端的切分点的可能性。

public static int partition(double[] arr, int lo, int hi) {
    int mid = lo + (hi - lo) / 2;

    int maxIndex = lo;
    if (ArrayUtils.less(arr, lo, mid)) {
        maxIndex = mid;
    }

    if (ArrayUtils.less(arr, maxIndex, hi)) {
        maxIndex = hi;
    }

    ArrayUtils.swap(arr, maxIndex, hi);

    if (ArrayUtils.less(arr, mid, lo)) {
        ArrayUtils.swap(arr, mid, lo);
    }

    int i = lo, j = hi;

    while(true) {
        while(ArrayUtils.less(arr, ++i, lo));

        while(ArrayUtils.less(arr, lo, --j));

        if (i >= j) {
            break;
        }

        ArrayUtils.swap(arr, i, j);
    }

    ArrayUtils.swap(arr, lo, j);

    return j;
}

三路快排

此优化方案主要是针对相同数很多的情况。

public static void sort(double[] arr) {
    sort(arr, 0, arr.length - 1);
}

private static void sort(double[] arr, int lo, int hi) {
    if (lo >= hi) {
        return;
    }

    int lt = lo, gt = hi;
    int i = lo + 1;

    // [lo, lt), [lt, gt], (gt, hi]
    double v = arr[lo];
    while (i <= gt) {
        if (arr[i] < v) {
            ArrayUtils.swap(arr, i++, lt++);
        } else if (arr[i] > v) {
            ArrayUtils.swap(arr, i, gt--);
        } else {
            i++;
        }
    }

    sort(arr, lo, lt - 1);
    sort(arr, gt + 1, hi);
}

进阶版的三路快排

这个版本也是为了解决相同数很多时快速排序比较低效的问题,于上一个版本不同的是,刚开始会把相同的元素放在两端。

public static void sort(double[] arr) {
    sort(arr, 0, arr.length - 1);
}

private static void sort(double[] arr, int lo, int hi) {
    if (hi - lo <= 16) {
        InsertionSort.sort(arr, lo, hi);
        return;
    }

    reSortMidItem(arr, lo, hi);

    int i = lo, j = hi + 1, p = lo, q = hi + 1;

    while (true) {
        if (i > lo && arr[i] == arr[lo]) {
            ArrayUtils.swap(arr, i, ++p);
        }

        if (j <= hi && arr[j] == arr[lo]) {
            ArrayUtils.swap(arr, j, --q);
        }

        while (ArrayUtils.less(arr, ++i, lo)) ;

        while (ArrayUtils.less(arr, lo, --j)) ;

        // 两个都等于 lo 时,也会在后面 break 掉,此时单独处理一下
        if (i == j && arr[i] == arr[lo]) {
            ArrayUtils.swap(arr, ++p, i);
        }

        if (i >= j) {
            break;
        }

        ArrayUtils.swap(arr, i, j);
    }


    i = j + 1;

    for (int k = lo; k <= p; k++) {
        ArrayUtils.swap(arr, k, j--);
    }

    for (int k = hi; k >= q; k--) {
        ArrayUtils.swap(arr, k, i++);
    }

    sort(arr, lo, j);
    sort(arr, i, hi);
}


private static void reSortMidItem(double[] arr, int lo, int hi) {
    int mid = lo + (hi - lo) / 2;
    int maxIndex = lo;

    if (ArrayUtils.less(arr, lo, mid)) {
        maxIndex = mid;
    }

    if (ArrayUtils.less(arr, maxIndex, hi)) {
        maxIndex = hi;
    }

    ArrayUtils.swap(arr, maxIndex, hi);

    if (ArrayUtils.less(arr, mid, lo)) {
        ArrayUtils.swap(arr, mid, lo);
    }
}