快速排序

61 阅读3分钟

自以为很理解快速排序算法,但是在写代码时,却久久没有处理好边界情况。

多数对于快速排序的理解为:

  1. partition 步骤: 将数组分成两份,左边一份都小于等于 x (通常取值为序列中间索引对应的值),右边一份都大于等于 x (通常取值为序列中间索引对应的值)
  2. 递归调用 quickSort 步骤

基于这种理解很容易写出如下代码:

import java.util.Scanner;

class Main {

    static void quickSort(int[] arr, int l, int r) {
        if (l >= r) {
            return;
        }

        // partition.
        int x = arr[l + r >> 1];
        int[] left = new int[r - l + 1];
        int[] right = new int[r -l + 1];
        int leftCnt = 0, rightCnt = 0;
        // <=x
        for (int i = l; i <= r; i ++ ) {
            if (arr[i] <= x) {
                left[leftCnt ++] = arr[i];
            } else {
                right[rightCnt ++] = arr[i];
            }
        }
        for (int i = l; i < l + leftCnt; i ++ ) {
            arr[i] = left[i - l];
        }
        for (int i = l + leftCnt; i <= r; i ++ ) {
            arr[i] = right[i - l - leftCnt];
        }
        
        quickSort(arr, l, l + leftCnt - 1);
        quickSort(arr, l + leftCnt, r);
    }


    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        int n = scanner.nextInt();
        int[] arr = new int[n];
        for (int i = 0; i < n; i ++ ) {
            arr[i] = scanner.nextInt();
        }

        quickSort(arr, 0, n - 1);
        
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < n; i ++ ) {
            sb.append(arr[i]).append(" ");
        }
        System.out.println(sb);
    }
}

但是这个写法明显是错的,我们一定要记得一种特殊情况,就是如果序列中 所有的元素都等于 x , 那么这种递归就会死循环

递归时,一定要避免,左右两边,一边全占,一边为空。为此可以改进如下下代码,虽然显得比较笨,但是代码复杂度还是 O(NlogN) 并没增加,就是复杂度的常数增加了。

import java.util.Scanner;

class Main {

    static void quickSort(int[] arr, int l, int r) {
        if (l >= r) {
            return;
        }

        // partition.
        int x = arr[l + r >> 1];
        int[] left = new int[r - l + 1];
        int[] right = new int[r -l + 1];
        int[] eq = new int[r - l + 1];
        int leftCnt = 0, rightCnt = 0, eqCnt = 0;
        // <=x
        for (int i = l; i <= r; i ++ ) {
            if (arr[i] < x) {
                left[leftCnt ++] = arr[i];
            } else if (arr[i] > x) {
                right[rightCnt ++] = arr[i];
            } else {
                eq[eqCnt ++] = arr[i];
            }
        }
        for (int i = l; i < l + leftCnt; i ++ ) {
            arr[i] = left[i - l];
        }
        for (int i = l + leftCnt; i < l + leftCnt + eqCnt; i ++ ){
            arr[i] = eq[i - l - leftCnt];
        }
        for (int i = l + leftCnt + eqCnt; i <= r; i ++ ) {
            arr[i] = right[i - l - leftCnt - eqCnt];
        }
        
        if (leftCnt == 0 && rightCnt == 0) {
            // 不用排序了,都等于 x
            return;
        }
        
        if (leftCnt == 0) {
            // 把相等的都划分到左边
            quickSort(arr, l, l + leftCnt + eqCnt - 1);
            quickSort(arr, l + leftCnt + eqCnt, r);
            return;
        }
        
        // 都不为 0 时,随便处理
        quickSort(arr, l, l + leftCnt - 1);
        quickSort(arr, l + leftCnt + eqCnt, r);
        
        
    }


    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        int n = scanner.nextInt();
        int[] arr = new int[n];
        for (int i = 0; i < n; i ++ ) {
            arr[i] = scanner.nextInt();
        }

        quickSort(arr, 0, n - 1);

        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < n; i ++ ) {
            sb.append(arr[i]).append(" ");
        }
        System.out.println(sb);
    }
} 

快排的模版代码通常写为:

import java.util.Scanner;

class Main {

   static void quickSort(int[] arr, int l, int r) {
       if (l >= r) {
           return;
       }

       // partition. 明显 parittion 保证 索引<= i 对应的值都 <=x ; 索引 >= j 对应的值都 >= x
       int x = arr[l + r >> 1];
       int i = l - 1, j = r + 1;
       while (i < j) {
           do {
               i++;
           } while (arr[i] < x && i <= r);

           do {
               j--;
           } while (arr[j] > x && j >= l);

           if (i < j) {
               int tmp = arr[i];
               arr[i] = arr[j];
               arr[j] = tmp;
           }
       }

       // j 一定大于等于 l, 所以 [l, j] 一定不为空
       quickSort(arr, l,  j);
       // 如果 j + 1 > r ,说明整个序列都 <= x 。但是这样的话, 因为 l 一定小于 r, 所以 l + r >> 1 一定小于 r,
       // 所以一定在某个小于 r 的地方 arr[i] == x ,这样一来 j 一定会至少做一次 -- 操作,所以 [j + 1, r] 一定也不为空。
       // 综上所属 ,分成的左右两段一定都不会为空,保证了不会死循环。
       quickSort(arr, j + 1, r);
   }


   public static void main(String[] args) {
       Scanner scanner = new Scanner(System.in);
       int n = scanner.nextInt();
       int[] arr = new int[n];
       for (int i = 0; i < n; i ++ ) {
           arr[i] = scanner.nextInt();
       }

       quickSort(arr, 0, n - 1);

       StringBuilder sb = new StringBuilder();
       for (int i = 0; i < n; i ++ ) {
           sb.append(arr[i]).append(" ");
       }
       System.out.println(sb);
   }
}

参考注释,就能理解,为啥 x 取的是 l + r >> 1 索引对应的值,为啥分割点取的是 j