(主要是作为个人笔记的一篇文章,几乎全是代码,讲解非常少,没兴趣的同学请直接略过)
说起快速排序,相信很多同学都不会很陌生,这个 20 世纪最伟大的算法之一,有着非常巧妙的设计思想和优越的性能,其中最主要的是它的切分思想(partition)也是面试题中常考的,比如荷兰国旗和查找第 K大的数。
我们今天的重点并不是讲解快速排序,而是去研究一下快速排序的几种优化思路,所以,在这里假定大家都已经了解基础版的快排是如何实现的了。但在看几种优化之前,我们还是给出基本的实现:
首先我们给出通用的模板:
package utils;
public class ArrayUtils {
public static boolean less(double[] arr, int i, int j) {
return arr[i] < arr[j];
}
public static void swap(double[] arr, int i, int j) {
double temp = arr[i];
arr[i] = arr[j];
arr[j] = temp;
}
}
public static void sort(double[] arr) {
sort(arr, 0, arr.length - 1);
}
public static void sort(double[] arr, int lo, int hi) {
if (lo >= hi) {
return;
}
int j = partition(arr, lo, hi);
sort(arr, lo, j - 1);
sort(arr, j + 1, hi);
}
public static int partition(double[] a, int lo, int hi) {
return null;
}
下面是我们最基础的 partition
的实现方法:
public static int partition(double[] a, int lo, int hi) {
int i = lo, j = hi + 1;
while (true) {
while(less(a, ++i, lo)) {
if (i == hi) {
break;
}
}
while (ArrayUtils.less(a, lo, --j)) {
if (j == lo) {
break;
}
}
if (i >= j) {
break;
}
ArrayUtils.swap(a, i, j);
}
ArrayUtils.swap(a, j, lo);
return j;
}
移除掉 partition 中的比较操作
在 partition
函数中,有下面这段代码:
while(less(a, ++i, lo)) {
if (i == hi) {
break;
}
}
while (less(a, lo, --j)) {
if (j == lo) {
break;
}
}
我们仔细的分析一下, 其实上一段代码中第二个 while
循环内部的逻辑是不需要的,当 j
减小到等于 lo
的时候,肯定就中止了,同时,如果我们想让第一个 while
循环也不走内部的逻辑,可以先把整个数组的最大值放到数组的最右边,那它也不需要比较内部了,因为肯定不会越界。
比起初级版,我们有两个函数需要改动一下:
public static void sort(double[] arr) {
int maxIndex = 0;
for (int i = 1; i < arr.length; i++) {
if (arr[i] > arr[maxIndex]) {
maxIndex = i;
}
}
ArrayUtils.swap(arr, maxIndex, arr.length - 1);
sort(arr, 0, arr.length - 1);
}
public static int partition(double[] a, int lo, int hi) {
int i = lo, j = hi + 1;
while (true) {
while(ArrayUtils.less(a, ++i, lo));
while(ArrayUtils.less(a, lo, --j));
if (i >= j) {
break;
}
ArrayUtils.swap(a, i, j);
}
ArrayUtils.swap(a, j, lo);
return j;
}
使用插入排序
我们是使用递归的方式去实现的,在数组长度很小的时候使用递归是不划算的,我们想到的最简单的优化就是在子数组长度短的时候使用插入排序。
public static void sort(double[] arr, int lo, int hi) {
if (hi - lo <= 16) {
InsertionSort.sort(arr, lo, hi);
return;
}
int j = partition(arr, lo, hi);
sort(arr, lo, j-1);
sort(arr, j + 1, hi);
}
public class InsertionSort {
public static void sort(double[] arr) {
sort(arr, 0, arr.length - 1);
}
public static void sort(double[] arr, int lo, int hi) {
for (int i = lo + 1; i <= hi; i++) {
int j = i - 1;
double temp = arr[i];
while(j >= 0 && arr[j] > temp) {
arr[j + 1] = arr[j];
j--;
}
arr[j+1] = temp;
}
}
}
随机选择一个作为分割点
影响快排效率的一个重要原因是切分点选择不好,我们从子数组中选取三个点,在这三个点中再选一个中间值,就尽量避免了选出极端的切分点的可能性。
public static int partition(double[] arr, int lo, int hi) {
int mid = lo + (hi - lo) / 2;
int maxIndex = lo;
if (ArrayUtils.less(arr, lo, mid)) {
maxIndex = mid;
}
if (ArrayUtils.less(arr, maxIndex, hi)) {
maxIndex = hi;
}
ArrayUtils.swap(arr, maxIndex, hi);
if (ArrayUtils.less(arr, mid, lo)) {
ArrayUtils.swap(arr, mid, lo);
}
int i = lo, j = hi;
while(true) {
while(ArrayUtils.less(arr, ++i, lo));
while(ArrayUtils.less(arr, lo, --j));
if (i >= j) {
break;
}
ArrayUtils.swap(arr, i, j);
}
ArrayUtils.swap(arr, lo, j);
return j;
}
三路快排
此优化方案主要是针对相同数很多的情况。
public static void sort(double[] arr) {
sort(arr, 0, arr.length - 1);
}
private static void sort(double[] arr, int lo, int hi) {
if (lo >= hi) {
return;
}
int lt = lo, gt = hi;
int i = lo + 1;
// [lo, lt), [lt, gt], (gt, hi]
double v = arr[lo];
while (i <= gt) {
if (arr[i] < v) {
ArrayUtils.swap(arr, i++, lt++);
} else if (arr[i] > v) {
ArrayUtils.swap(arr, i, gt--);
} else {
i++;
}
}
sort(arr, lo, lt - 1);
sort(arr, gt + 1, hi);
}
进阶版的三路快排
这个版本也是为了解决相同数很多时快速排序比较低效的问题,于上一个版本不同的是,刚开始会把相同的元素放在两端。
public static void sort(double[] arr) {
sort(arr, 0, arr.length - 1);
}
private static void sort(double[] arr, int lo, int hi) {
if (hi - lo <= 16) {
InsertionSort.sort(arr, lo, hi);
return;
}
reSortMidItem(arr, lo, hi);
int i = lo, j = hi + 1, p = lo, q = hi + 1;
while (true) {
if (i > lo && arr[i] == arr[lo]) {
ArrayUtils.swap(arr, i, ++p);
}
if (j <= hi && arr[j] == arr[lo]) {
ArrayUtils.swap(arr, j, --q);
}
while (ArrayUtils.less(arr, ++i, lo)) ;
while (ArrayUtils.less(arr, lo, --j)) ;
// 两个都等于 lo 时,也会在后面 break 掉,此时单独处理一下
if (i == j && arr[i] == arr[lo]) {
ArrayUtils.swap(arr, ++p, i);
}
if (i >= j) {
break;
}
ArrayUtils.swap(arr, i, j);
}
i = j + 1;
for (int k = lo; k <= p; k++) {
ArrayUtils.swap(arr, k, j--);
}
for (int k = hi; k >= q; k--) {
ArrayUtils.swap(arr, k, i++);
}
sort(arr, lo, j);
sort(arr, i, hi);
}
private static void reSortMidItem(double[] arr, int lo, int hi) {
int mid = lo + (hi - lo) / 2;
int maxIndex = lo;
if (ArrayUtils.less(arr, lo, mid)) {
maxIndex = mid;
}
if (ArrayUtils.less(arr, maxIndex, hi)) {
maxIndex = hi;
}
ArrayUtils.swap(arr, maxIndex, hi);
if (ArrayUtils.less(arr, mid, lo)) {
ArrayUtils.swap(arr, mid, lo);
}
}