JDK1.8源码之Collections.sort()

757 阅读3分钟

JDK存储一组Object的集合框架是Collection。而针对Collection框架的一组操作集合体是Collections,里面包含了多种针对Collection的操作,例如:排序、查找、交换、反转、复制等。

这一篇讲述Collections的排序操作。

public static <T extends Comparable<? super T>> void sort(List<T> list) {
    list.sort(null);
}

我们可以发现Collections.sort()调用的其实是List接口里的默认方法sort(),代码如下:

@SuppressWarnings({"unchecked", "rawtypes"})
default void sort(Comparator<? super E> c) {
    Object[] a = this.toArray();
    Arrays.sort(a, (Comparator) c);
    ListIterator<E> i = this.listIterator();
    for (Object e : a) {
        i.next();
        i.set((E) e);
    }
}

在List.sort()方法中调用Arrays.sort();方法,传入数组a和比较器c,接下来我们看看Arrays类中的sort方法,代码如下:

public static <T> void sort(T[] a, Comparator<? super T> c) {
    if (c == null) {
        sort(a);
    } else {
        if (LegacyMergeSort.userRequested)
            legacyMergeSort(a, c);
        else
            TimSort.sort(a, 0, a.length, c, null, 0, 0);
    }
}

从上面代码可以看出,如果有自定义Comparator则调用TimSort.sort(),没有自定义Comparator则调用ComparableTimSort.sort(),其实这两个方法的实现逻辑是一样的,只是java.util.TimSort.sort可以使用自定义的Comparator,而java.util.ComparableTimSort.sort不使用Comparator而已。

ComparableTimSort.sort

static void sort(Object[] a, int lo, int hi, Object[] work, int workBase, int workLen) {
    assert a != null && lo >= 0 && lo <= hi && hi <= a.length;

    int nRemaining  = hi - lo;
    if (nRemaining < 2)
        return; 

    // 数组长度小于32
    if (nRemaining < MIN_MERGE) {
        // 找出最大的递增或者递减的个数,如果递减,则此段数组严格反一下方向
        int initRunLen = countRunAndMakeAscending(a, lo, hi);
        binarySort(a, lo, hi, lo + initRunLen);
        return;
    }

    // 数组长度大于等于32
        ````````
    /**
     *  从左到右在阵列上行进一次,寻找自然路线,
     *  将短的自然运行扩展到MIN元素,并合并运行
     *  保持堆栈不变。
     */
    ComparableTimSort ts = new ComparableTimSort(a, work, workBase, workLen);
    int minRun = minRunLength(nRemaining);
    do {
        // Identify next run
        int runLen = countRunAndMakeAscending(a, lo, hi);

        // If run is short, extend to min(minRun, nRemaining)
        if (runLen < minRun) {
            int force = nRemaining <= minRun ? nRemaining : minRun;
            binarySort(a, lo, lo + force, lo + runLen);
            runLen = force;
        }

        // Push run onto pending-run stack, and maybe merge
        ts.pushRun(lo, runLen);
        ts.mergeCollapse();

        // Advance to find next run
        lo += runLen;
        nRemaining -= runLen;
    } while (nRemaining != 0);

    // Merge all remaining runs to complete sort
    assert lo == hi;
    ts.mergeForceCollapse();
    assert ts.stackSize == 1;
}

主要步骤:

数组长度小于32时

@SuppressWarnings({"fallthrough", "rawtypes", "unchecked"})
private static void binarySort(Object[] a, int lo, int hi, int start) {
    // 索引小于start的元素均已排好序
    
    assert lo <= start && start <= hi;
    if (start == lo)
        start++;
    for ( ; start < hi; start++) {
        Comparable pivot = (Comparable) a[start];

        // Set left (and right) to the index where a[start] (pivot) belongs
        int left = lo;
        int right = start;
        assert left <= right;
        // 二分查找
        while (left < right) {
            int mid = (left + right) >>> 1;
            if (pivot.compareTo(a[mid]) < 0)
                right = mid;
            else
                left = mid + 1;
        }
        assert left == right;

        /*
         * The invariants still hold: pivot >= all in [lo, left) and
         * pivot < all in [left, start), so pivot belongs at left.  Note
         * that if there are elements equal to pivot, left points to the
         * first slot after them -- that's why this sort is stable.
         * Slide elements over to make room for pivot.
         */
        int n = start - left;  // 需要移动的元素个数
        // 空出插入pivot的位置
        switch (n) {
            case 2:  a[left + 2] = a[left + 1];
            case 1:  a[left + 1] = a[left];
                     break;
            default: System.arraycopy(a, left, a, left + 1, n);
        }
        a[left] = pivot;
    }
}

数组长度大于等于32时

static void sort(Object[] a, int lo, int hi, Object[] work, int workBase, int workLen) {
        //数组个数小于32的时候
        ......
        
        // 数组个数大于32的时候
        /**
         * March over the array once, left to right, finding natural runs,
         * extending short natural runs to minRun elements, and merging runs
         * to maintain stack invariant.
         */
        ComparableTimSort ts = new ComparableTimSort(a, work, workBase, workLen);
        // 计算run的长度
        int minRun = minRunLength(nRemaining);
        do {
            // Identify next run
            // 找出连续升序的最大个数
            int runLen = countRunAndMakeAscending(a, lo, hi);

            // If run is short, extend to min(minRun, nRemaining)
            // 如果run长度小于规定的minRun长度,先进行二分插入排序
            if (runLen < minRun) {
                int force = nRemaining <= minRun ? nRemaining : minRun;
                binarySort(a, lo, lo + force, lo + runLen);
                runLen = force;
            }

            // Push run onto pending-run stack, and maybe merge
            ts.pushRun(lo, runLen);
            // 进行归并
            ts.mergeCollapse();

            // Advance to find next run
            lo += runLen;
            nRemaining -= runLen;
        } while (nRemaining != 0);

        // Merge all remaining runs to complete sort
        assert lo == hi;
        // 归并所有的run
        ts.mergeForceCollapse();
        assert ts.stackSize == 1;
    }
  1. 计算出run的最小的长度minRun

    a)  如果数组大小为2的N次幂,则返回16(MIN_MERGE / 2);

    b)  其他情况下,逐位向右位移(即除以2),直到找到介于16和32间的一个数;

/**
     * Returns the minimum acceptable run length for an array of the specified
     * length. Natural runs shorter than this will be extended with
     * {@link #binarySort}.
     *
     * Roughly speaking, the computation is:
     *
     *  If n < MIN_MERGE, return n (it's too small to bother with fancy stuff).
     *  Else if n is an exact power of 2, return MIN_MERGE/2.
     *  Else return an int k, MIN_MERGE/2 <= k <= MIN_MERGE, such that n/k
     *   is close to, but strictly less than, an exact power of 2.
     *
     * For the rationale, see listsort.txt.
     *
     * @param n the length of the array to be sorted
     * @return the length of the minimum run to be merged
     */
    private static int minRunLength(int n) {
        assert n >= 0;
        int r = 0;      // Becomes 1 if any 1 bits are shifted off
        while (n >= MIN_MERGE) {
            r |= (n & 1);
            n >>= 1;
        }
        return n + r;
    }
  1. 求最小递增的长度,如果长度小于minRun,使用插入排序补充到minRun的个数,操作和小于32的个数是一样。 
  2. 用stack记录每个run的长度,当下面的条件其中一个成立时归并,直到数量不变:
runLen[i - 3] > runLen[i - 2] + runLen[i - 1] 
runLen[i - 2] > runLen[i - 1]
/**
     * Examines the stack of runs waiting to be merged and merges adjacent runs
     * until the stack invariants are reestablished:
     *
     *     1. runLen[i - 3] > runLen[i - 2] + runLen[i - 1]
     *     2. runLen[i - 2] > runLen[i - 1]
     *
     * This method is called each time a new run is pushed onto the stack,
     * so the invariants are guaranteed to hold for i < stackSize upon
     * entry to the method.
     */
    private void mergeCollapse() {
        while (stackSize > 1) {
            int n = stackSize - 2;
            if (n > 0 && runLen[n-1] <= runLen[n] + runLen[n+1]) {
                if (runLen[n - 1] < runLen[n + 1])
                    n--;
                mergeAt(n);
            } else if (runLen[n] <= runLen[n + 1]) {
                mergeAt(n);
            } else {
                break; // Invariant is established
            }
        }
    }

关于归并方法和对一般的归并排序做出了简单的优化。假设两个 run 是 run1,run2 ,先用 gallopRight在 run1 里使用 binarySearch 查找run2 首元素 的位置k,那么 run1 中 k 前面的元素就是合并后最小的那些元素。然后,在run2 中查找run1 尾元素 的位置 len2,那么run2 中 len2 后面的那些元素就是合并后最大的那些元素。最后,根据len1 与len2 大小,调用mergeLo 或者 mergeHi 将剩余元素合并。

    /**
     * Merges the two runs at stack indices i and i+1.  Run i must be
     * the penultimate or antepenultimate run on the stack.  In other words,
     * i must be equal to stackSize-2 or stackSize-3.
     *
     * @param i stack index of the first of the two runs to merge
     */
    @SuppressWarnings("unchecked")
    private void mergeAt(int i) {
        assert stackSize >= 2;
        assert i >= 0;
        assert i == stackSize - 2 || i == stackSize - 3;

        int base1 = runBase[i];
        int len1 = runLen[i];
        int base2 = runBase[i + 1];
        int len2 = runLen[i + 1];
        assert len1 > 0 && len2 > 0;
        assert base1 + len1 == base2;

        /*
         * Record the length of the combined runs; if i is the 3rd-last
         * run now, also slide over the last run (which isn't involved
         * in this merge).  The current run (i+1) goes away in any case.
         */
        runLen[i] = len1 + len2;
        if (i == stackSize - 3) {
            runBase[i + 1] = runBase[i + 2];
            runLen[i + 1] = runLen[i + 2];
        }
        stackSize--;

        /*
         * Find where the first element of run2 goes in run1. Prior elements
         * in run1 can be ignored (because they're already in place).
         */
        int k = gallopRight((Comparable<Object>) a[base2], a, base1, len1, 0);
        assert k >= 0;
        base1 += k;
        len1 -= k;
        if (len1 == 0)
            return;

        /*
         * Find where the last element of run1 goes in run2. Subsequent elements
         * in run2 can be ignored (because they're already in place).
         */
        len2 = gallopLeft((Comparable<Object>) a[base1 + len1 - 1], a,
                base2, len2, len2 - 1);
        assert len2 >= 0;
        if (len2 == 0)
            return;

        // Merge remaining runs, using tmp array with min(len1, len2) elements
        if (len1 <= len2)
            mergeLo(base1, len1, base2, len2);
        else
            mergeHi(base1, len1, base2, len2);
    }
  1. 最后归并还有没有归并的run,知道run的数量为1。

例子

为了演示方便,我将TimSort中的minRun直接设置为2,否则我不能用很小的数组演示。同时把MIN_MERGE也改成2(默认为32),这样避免直接进入二分插入排序。

1.  初始数组为[7,5,1,2,6,8,10,12,4,3,9,11,13,15,16,14]

2.  寻找第一个连续的降序或升序序列:[1,5,7] [2,6,8,10,12,4,3,9,11,13,15,16,14]

3.  stackSize=1,所以不合并,继续找第二个run

4.  找到一个递减序列,调整次序:[1,5,7] [2,6,8,10,12] [4,3,9,11,13,15,16,14]

5.  因为runLen[0] <= runLen[1]所以归并 

  1) gallopRight:寻找run1的第一个元素应当插入run0中哪个位置(”2”应当插入”1”之后),然后就可以忽略之前run0的元素(都比run1的第一个元素小) 

  2) gallopLeft:寻找run0的最后一个元素应当插入run1中哪个位置(”7”应当插入”8”之前),然后就可以忽略之后run1的元素(都比run0的最后一个元素大) 

  这样需要排序的元素就仅剩下[5,7] [2,6],然后进行mergeLow 完成之后的结果: [1,2,5,6,7,8,10,12] [4,3,9,11,13,15,16,14]

6.  寻找连续的降序或升序序列[1,2,5,6,7,8,10,12] [3,4] [9,11,13,15,16,14]

7.  不进行归并排序,因为runLen[0] > runLen[1]

8.  寻找连续的降序或升序序列:[1,2,5,6,7,8,10,12] [3,4] [9,11,13,15,16] [14]

9.  因为runLen[1] <= runLen[2],所以需要归并

  1. 使用gallopRight,发现为正常顺序。得[1,2,5,6,7,8,10,12] [3,4,9,11,13,15,16] [14]

  2. 最后只剩下[14]这个元素:[1,2,5,6,7,8,10,12] [3,4,9,11,13,15,16] [14]

  3. 因为runLen[0] <= runLen[1] + runLen[2]所以合并。因为runLen[0] > runLen[2],所以将run1和run2先合并。(否则将run0和run1先合并) 
    完成之后的结果: [1,2,5,6,7,8,10,12] [3,4,9,11,13,14,15,16]

  4. 完成之后的结果:[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16]