Go排序-原理分析与常用写法

335 阅读5分钟

Go 标准库排序原理分析

排序的主要代码在 sort.go 这个文件里。实现的排序算法有: 插入排序(insertionSort)、堆排序(heapSort)、快速排序(quickSort)、希尔排序(ShellSort)和归并排序(SymMerge)。

sort 包根据稳定性,将排序方法分为两类:不稳定排序和稳定排序

不稳定排序

不稳定排序在大部分情况下是通过 **“希尔排序”+“堆排序”+“快速排序”** 的融合改进版实现的(小切片最后还会用插入排序收尾)。

下面是go1.17.5中实现方法

WX20220729-171949@2x.png

// Sort sorts data in place.
// data.Len is called exactly once to obtain the element count, which is used
// both as the upper bound of the range and to derive the quicksort depth
// limit. The sort is not guaranteed to be stable.
func Sort(data Interface) {
   length := data.Len()
   depthLimit := maxDepth(length)
   quickSort(data, 0, length, depthLimit)
}


// maxDepth returns the depth at which quicksort should give up and switch to
// heapsort: twice the number of bits needed to represent n, i.e.
// 2*ceil(lg(n+1)). For n <= 0 the result is 0.
func maxDepth(n int) int {
   limit := 0
   for n > 0 {
      // Each halving of n accounts for one bit, and every bit adds two levels.
      limit += 2
      n >>= 1
   }
   return limit
}

// quickSort sorts data[a:b] in place; a is the first index and b is one past
// the last index. maxDepth is the number of partitioning rounds still allowed
// before the function falls back to heapSort.
func quickSort(data Interface, a, b, maxDepth int) {
    // Ranges longer than 12 elements are handled by quicksort, or by heapsort
    // once the depth budget has been used up.
    for b-a > 12 {
        // The depth budget is exhausted: finish this range with heapsort.
        if maxDepth == 0 {
            heapSort(data, a, b)
            return
        }
        // Every loop iteration counts as one more level of (simulated) recursion.
        maxDepth--
        // doPivot performs a three-way partition of data[a:b] around a pivot
        // it chooses itself and returns the bounds of the middle section:
        //   data[a:mlo]   <= pivot
        //   data[mlo:mhi] == pivot
        //   data[mhi:b]   >  pivot
        mlo, mhi := doPivot(data, a, b)
        // Only one of the two sub-ranges can be continued by this loop; the
        // other one needs a real recursive call. Recursing into the smaller
        // side and looping on the larger one avoids deep recursion on large
        // sub-problems; the recursion depth never exceeds the initial maxDepth
        // because every recursive call receives the already decremented value.
        if mlo-a < b-mhi {
            quickSort(data, a, mlo, maxDepth)
            a = mhi // i.e., quickSort(data, mhi, b)
        } else {
            quickSort(data, mhi, b, maxDepth)
            b = mlo // i.e., quickSort(data, a, mlo)
        }
    }

    // At most 12 elements remain here (on first entry or after partitioning).
    // If there are at least two, finish them with one Shell sort pass of gap 6
    // followed by insertion sort.
    if b-a > 1 {
        // Do ShellSort pass with gap 6
        // It could be written in this simplified form cause b-a <= 12
        for i := a + 6; i < b; i++ {
            if data.Less(i, i-6) {
                data.Swap(i, i-6)
            }
        }
        insertionSort(data, a, b)
    }
}

堆排序

构建最大堆,通过 siftDown 来对 heap 进行调整

// heapSort sorts data[a:b] in place with heapsort. All heap indices are
// relative to base (= a), so the heap root is data[base].
func heapSort(data Interface, a, b int) {
   base := a
   size := b - a

   // Heapify: sift down every node from the middle of the range back to the
   // root so that the greatest element ends up at the top.
   for node := (size - 1) / 2; node >= 0; node-- {
      siftDown(data, node, size, base)
   }

   // Repeatedly move the current maximum to the end of the shrinking heap and
   // restore the heap property on the remaining prefix.
   for end := size - 1; end >= 0; end-- {
      data.Swap(base, base+end)
      siftDown(data, 0, end, base)
   }
}

// siftDown restores the max-heap property for the heap stored in data,
// starting at heap index lo and considering only heap indices below hi.
// first is the offset in data at which heap index 0 (the root) is stored.
func siftDown(data Interface, lo, hi, first int) {
   node := lo
   for child := 2*node + 1; child < hi; child = 2*node + 1 {
      // Prefer the right child when it exists and is greater than the left one.
      if sibling := child + 1; sibling < hi && data.Less(first+child, first+sibling) {
         child = sibling
      }
      // The node is already at least as large as its larger child: done.
      if !data.Less(first+node, first+child) {
         return
      }
      data.Swap(first+node, first+child)
      node = child
   }
}

快速排序

快排的最坏时间复杂度是 O(n²)。最坏情况是每次切分得到的两部分极不均衡:可能全是大于分割点的元素,也可能全是不大于分割点的元素。所以选择合适的分割点是很必要的。

doPivot 在切分之前,先使用 medianOfThree 函数选择一个肯定不是最大和最小的值作为分割点,放在了切片首位(data[lo])。然后把不大于 data[pivot] 的数据放在了 (lo, b) 区间,把大于 data[pivot] 的数据放在了 [c, hi-1) 区间(其中 data[hi-1] >= data[pivot])。

之后,该算法又估算了等于 data[pivot] 的元素数量,如果数量过多,则把与 data[pivot] 相等的数据集中放到中间部分,区间为 [b, c)。最后把 data[pivot] 交换到了 b-1 的位置。

至此,数据被切分成三个区间。 data[lo, b-1) data[b-1, c) data[c, hi)

// doPivot partitions data[lo:hi] around a pivot element for quickSort.
//
// The pivot is chosen with medianOfThree (for ranges longer than 40 elements,
// as the median of three such medians) and is first moved to data[lo]. After
// partitioning, the pivot is swapped to index midlo and the range satisfies
//   data[lo:midlo]    <= pivot
//   data[midlo:midhi] == pivot
//   data[midhi:hi]    >  pivot
// If many elements appear to be equal to the pivot, a second pass gathers
// them into the middle section so that quickSort does not revisit them.
func doPivot(data Interface, lo, hi int) (midlo, midhi int) {
   m := int(uint(lo+hi) >> 1) // Written like this to avoid integer overflow.
   if hi-lo > 40 {
      // Tukey's ``Ninther,'' median of three medians of three.
      // Each call leaves the median of its three positions in the first
      // index argument (lo, m and hi-1 respectively).
      s := (hi - lo) / 8
      medianOfThree(data, lo, lo+s, lo+2*s)
      medianOfThree(data, m, m-s, m+s)
      medianOfThree(data, hi-1, hi-1-s, hi-1-2*s)
   }
   // Move the median of data[lo], data[m], data[hi-1] into data[lo].
   medianOfThree(data, lo, m, hi-1)

   // Invariants are:
   // data[lo] = pivot (set up by ChoosePivot)
   // data[lo < i < a] < pivot
   // data[a <= i < b] <= pivot
   // data[b <= i < c] unexamined
   // data[c <= i < hi-1] > pivot
   // data[hi-1] >= pivot
   pivot := lo
   a, c := lo+1, hi-1

   // Skip the leading run of elements that are strictly less than the pivot.
   for ; a < c && data.Less(a, pivot); a++ {
   }
   b := a
   // Main partition loop: grow the "<= pivot" region from the left and the
   // "> pivot" region from the right, swapping a misplaced pair when both
   // scans stop.
   for {
      for ; b < c && !data.Less(pivot, b); b++ { // data[b] <= pivot
      }
      for ; b < c && data.Less(pivot, c-1); c-- { // data[c-1] > pivot
      }
      if b >= c {
         break
      }
      // data[b] > pivot; data[c-1] <= pivot
      data.Swap(b, c-1)
      b++
      c--
   }
   // If hi-c<3 then there are duplicates (by property of median of nine).
   // Let's be a bit more conservative, and set border to 5.
   protect := hi-c < 5
   if !protect && hi-c < (hi-lo)/4 {
      // The "> pivot" side is small (less than a quarter of the range), so
      // sample three positions to estimate how common the pivot value is.
      // Lets test some points for equality to pivot
      dups := 0
      if !data.Less(pivot, hi-1) { // data[hi-1] = pivot
         data.Swap(c, hi-1)
         c++
         dups++
      }
      if !data.Less(b-1, pivot) { // data[b-1] = pivot
         b--
         dups++
      }
      // m-lo = (hi-lo)/2 > 6
      // b-lo > (hi-lo)*3/4-1 > 8
      // ==> m < b ==> data[m] <= pivot
      if !data.Less(m, pivot) { // data[m] = pivot
         data.Swap(m, b-1)
         b--
         dups++
      }
      // if at least 2 points are equal to pivot, assume skewed distribution
      protect = dups > 1
   }
   if protect {
      // Protect against a lot of duplicates
      // Add invariant:
      // data[a <= i < b] unexamined
      // data[b <= i < c] = pivot
      // This second pass splits the "<= pivot" region into "< pivot" on the
      // left and "== pivot" on the right.
      for {
         for ; a < b && !data.Less(b-1, pivot); b-- { // data[b] == pivot
         }
         for ; a < b && data.Less(a, pivot); a++ { // data[a] < pivot
         }
         if a >= b {
            break
         }
         // data[a] == pivot; data[b-1] < pivot
         data.Swap(a, b-1)
         a++
         b--
      }
   }
   // Swap pivot into middle
   data.Swap(pivot, b-1)
   return b - 1, c
}

// medianOfThree moves the median of the three values data[m0], data[m1], data[m2] into data[m1].
// It does so by sorting the three positions, using at most three comparisons
// and three swaps.
func medianOfThree(data Interface, m1, m0, m2 int) {
   // Order the first pair so that data[m0] <= data[m1].
   if data.Less(m1, m0) {
      data.Swap(m1, m0)
   }
   // If data[m2] is not smaller than data[m1], the triple is already ordered.
   if !data.Less(m2, m1) {
      return
   }
   // Otherwise move the larger value to m2; afterwards
   // data[m0] <= data[m2] && data[m1] < data[m2].
   data.Swap(m2, m1)
   // The value that came from m2 may still be smaller than data[m0].
   if data.Less(m1, m0) {
      data.Swap(m1, m0)
   }
   // now data[m0] <= data[m1] <= data[m2]
}

希尔排序(改进的插入排序)

以6作为间隔,先做一次筛选排序;然后以1作为间隔,做一次插入排序

// quickSort (excerpt): the partitioning loop is elided ("..."); only the final
// step that handles the remaining small range data[a:b] is shown.
func quickSort(data Interface, a, b, maxDepth int) {
   ...
   if b-a > 1 {
      // Do a single Shell sort pass with gap 6.
      // This simplified single-pass form is sufficient because b-a <= 12 here.
      for i := a + 6; i < b; i++ {
         if data.Less(i, i-6) {
            data.Swap(i, i-6)
         }
      }
      // Finish with a plain insertion sort (gap 1).
      insertionSort(data, a, b)
   }
}

// insertionSort sorts data[a:b] in place using insertion sort: each element
// is swapped backwards until the element before it is not greater.
func insertionSort(data Interface, a, b int) {
   for next := a + 1; next < b; next++ {
      pos := next
      for pos > a && data.Less(pos, pos-1) {
         data.Swap(pos, pos-1)
         pos--
      }
   }
}

稳定排序

sort 包中使用的稳定排序算法是 **“归并排序”+“插入排序”** 的改进版。首先,它把 slice 按每 blockSize=20 个元素划分为一个 block,对每个 block 分别进行插入排序;然后循环合并相邻的两个 block,每轮循环后 blockSize 扩大为原来的两倍,直到 blockSize >= n 为止。

WX20220729-172121@2x.png

// Stable sorts data while keeping the original order of equal elements.
//
// data.Len is called once to determine n; the sort then makes O(n*log(n))
// calls to data.Less and O(n*log(n)*log(n)) calls to data.Swap.
func Stable(data Interface) {
   n := data.Len()
   stable(data, n)
}

// stable sorts data[0:n] stably: it insertion-sorts consecutive blocks of 20
// elements and then merges neighbouring blocks with symMerge, doubling the
// block size after every round until a single block covers the whole range.
func stable(data Interface, n int) {
   blockSize := 20 // must be > 0

   // Phase 1: insertion-sort every full block, then the shorter remainder.
   lo := 0
   for hi := blockSize; hi <= n; hi += blockSize {
      insertionSort(data, lo, hi)
      lo = hi
   }
   insertionSort(data, lo, n)

   // Phase 2: merge pairs of adjacent sorted blocks.
   for ; blockSize < n; blockSize *= 2 {
      width := 2 * blockSize
      lo = 0
      for hi := width; hi <= n; hi += width {
         symMerge(data, lo, lo+blockSize, hi)
         lo = hi
      }
      // A trailing full block followed by a partial one still needs merging.
      if mid := lo + blockSize; mid < n {
         symMerge(data, lo, mid, n)
      }
   }
}

归并排序

// symMerge merges the two adjacent subsequences data[a:m] and data[m:b] in
// place. The caller (stable) passes ranges that have each been sorted already.
//
// Two special cases, where one side holds exactly one element, are handled by
// a binary search plus a run of adjacent swaps. Otherwise the function locates
// a split point by binary search, rotates the out-of-place middle section and
// recurses on the two halves data[a:mid] and data[mid:b].
func symMerge(data Interface, a, m, b int) {
   // Avoid unnecessary recursions of symMerge
   // by direct insertion of data[a] into data[m:b]
   // if data[a:m] only contains one element.
   if m-a == 1 {
      // Use binary search to find the lowest index i
      // such that data[i] >= data[a] for m <= i < b.
      // Exit the search loop with i == b in case no such index exists.
      i := m
      j := b
      for i < j {
         h := int(uint(i+j) >> 1) // midpoint computed without overflow
         if data.Less(h, a) {
            i = h + 1
         } else {
            j = h
         }
      }
      // Swap values until data[a] reaches the position before i.
      for k := a; k < i-1; k++ {
         data.Swap(k, k+1)
      }
      return
   }

   // Avoid unnecessary recursions of symMerge
   // by direct insertion of data[m] into data[a:m]
   // if data[m:b] only contains one element.
   if b-m == 1 {
      // Use binary search to find the lowest index i
      // such that data[i] > data[m] for a <= i < m.
      // Exit the search loop with i == m in case no such index exists.
      i := a
      j := m
      for i < j {
         h := int(uint(i+j) >> 1) // midpoint computed without overflow
         if !data.Less(m, h) {
            i = h + 1
         } else {
            j = h
         }
      }
      // Swap values until data[m] reaches the position i.
      for k := m; k > i; k-- {
         data.Swap(k, k-1)
      }
      return
   }

   // General case. mid is the centre of data[a:b]; n = mid+m is used so that
   // index c on one side is compared with its mirror index p-c on the other.
   mid := int(uint(a+b) >> 1)
   n := mid + m
   var start, r int
   // Choose the binary-search window [start, r) depending on which side of
   // mid the boundary m lies.
   if m > mid {
      start = n - b
      r = mid
   } else {
      start = a
      r = m
   }
   p := n - 1

   // Binary search for the first c whose mirror element data[p-c] is less
   // than data[c].
   for start < r {
      c := int(uint(start+r) >> 1)
      if !data.Less(p-c, c) {
         start = c + 1
      } else {
         r = c
      }
   }

   end := n - start
   // NOTE(review): rotate is defined outside this excerpt; presumably it
   // exchanges the neighbouring sections data[start:m] and data[m:end].
   if start < m && m < end {
      rotate(data, start, m, end)
   }
   // Recurse on each half, skipping halves where one of the two parts is empty.
   if a < start && start < mid {
      symMerge(data, a, start, mid)
   }
   if mid < end && end < b {
      symMerge(data, mid, end, b)
   }
}

插入排序

同不稳定排序中的希尔排序,当间隔变为1时为插入排序

// insertionSort sorts data[a:b] in place using insertion sort: every element
// is moved backwards by adjacent swaps until its predecessor is not greater.
func insertionSort(data Interface, a, b int) {
   for next := a + 1; next < b; next++ {
      cur := next
      for cur > a && data.Less(cur, cur-1) {
         data.Swap(cur, cur-1)
         cur--
      }
   }
}

Go排序写法

自定义比较器写法(匿名函数)

// Person is one family member in the example data.
type Person struct {
   Name string
   Age  int
}

func main() {
   family := []Person{
      {"Alice", 23},
      {"David", 2},
      {"Eve", 2},
      {"Bob", 25},
   }
   // Sort by age in ascending order, keeping the original order of equal
   // elements. sort.Slice does not guarantee stability, so sort.SliceStable
   // is used to make sure David stays in front of Eve.
   sort.SliceStable(family, func(i, j int) bool {
      return family[i].Age < family[j].Age
   })
   fmt.Println(family) // [{David 2} {Eve 2} {Alice 23} {Bob 25}]
}

实现sort.Interface写法

// Person is one family member in the example data.
type Person struct{
    Name string
    Age  int
}

// Family is a slice of Person that implements sort.Interface, ordered by Age.
type Family []Person

// Len, Less and Swap are the three methods that make up sort.Interface.
func (a Family) Len() int           { return len(a) }
func (a Family) Less(i, j int) bool { return a[i].Age < a[j].Age }  // ascending by age
func (a Family) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }

func main() {
   family := []Person{
      {"Alice", 23},
      {"David", 2},
      {"Eve", 2},
      {"Bob", 25},
   }
   // sort.Sort is not guaranteed to be stable: the relative order of David
   // and Eve (both aged 2) is not guaranteed by its contract.
   sort.Sort(Family(family))
   fmt.Println(family)  // [{David 2} {Eve 2} {Alice 23} {Bob 25}]
}

常用数据类型写法

对于 []int, []float64, []string 基础类型的切片使用 sort 包提供的下面几个函数进行排序。

sort.Ints

sort.Float64s

sort.Strings

// main demonstrates the convenience helpers for the basic slice types; each
// one sorts its argument in place in ascending order.
func main() {
   numbers := []int{4, 2, 3, 1}
   sort.Ints(numbers)
   fmt.Println(numbers) // prints [1 2 3 4]

   words := []string{"qwb", "qwa", "wwa"}
   sort.Strings(words)
   fmt.Println(words) // prints [qwa qwb wwa]

   floats := []float64{12.31, 12.21, 13.54}
   sort.Float64s(floats)
   fmt.Println(floats) // prints [12.21 12.31 13.54]
}

附:

七大排序算法与Go实现

  • 选择排序
// 选择排序 (selection sort)
package sorts

// SelectionSort sorts arr in place in ascending order and returns the same
// slice. On every pass it finds the smallest element of the unsorted suffix
// and swaps it to the front of that suffix.
func SelectionSort(arr []int) []int {
    for start := range arr {
        smallest := start
        for k := start + 1; k < len(arr); k++ {
            if arr[k] < arr[smallest] {
                smallest = k
            }
        }
        arr[start], arr[smallest] = arr[smallest], arr[start]
    }
    return arr
}
  • 冒泡排序
// 冒泡排序 (bubble sort)
package sorts

// bubbleSort sorts arr in place in ascending order and returns the same slice.
//
// Adjacent out-of-order pairs are swapped in repeated passes; the function
// stops after the first pass that performs no swap. After each pass the
// largest remaining element is in its final position, so the next pass scans
// one element less.
func bubbleSort(arr []int) []int {
    swapped := true
    for end := len(arr) - 1; swapped; end-- {
        swapped = false
        for i := 0; i < end; i++ {
            if arr[i+1] < arr[i] {
                arr[i+1], arr[i] = arr[i], arr[i+1]
                swapped = true
            }
        }
    }
    return arr
}
  • 插入排序
// 插入排序 (insertion sort)
package sorts

// InsertionSort sorts arr in place in ascending order and returns the same
// slice. Each element is lifted out, the greater-or-equal elements before it
// are shifted one slot to the right, and the element is dropped into the gap.
func InsertionSort(arr []int) []int {
    for i := 1; i < len(arr); i++ {
        key := arr[i]
        slot := i
        for slot > 0 && arr[slot-1] >= key {
            arr[slot] = arr[slot-1]
            slot--
        }
        arr[slot] = key
    }
    return arr
}
  • 希尔排序
// 希尔排序 (shell sort)
package sorts

// ShellSort sorts arr in place in ascending order and returns the same slice.
// It runs gapped insertion sorts with gaps len/2, len/4, ..., 1.
func ShellSort(arr []int) []int {
    n := len(arr)
    for gap := n / 2; gap > 0; gap /= 2 {
        for i := gap; i < n; i++ {
            // Swap arr[i] backwards in steps of gap until its gap-predecessor
            // is not greater.
            j := i
            for j >= gap && arr[j-gap] > arr[j] {
                arr[j-gap], arr[j] = arr[j], arr[j-gap]
                j -= gap
            }
        }
    }
    return arr
}
  • 归并排序
// 归并排序 (merge sort)
package sorts

// merge combines the two ascending slices a and b into a newly allocated
// ascending slice of length len(a)+len(b). Neither input is modified. When
// elements compare equal, the one from a is placed first.
func merge(a []int, b []int) []int {
    r := make([]int, len(a)+len(b))
    i, j := 0, 0
    // Take the smaller head element until one input is exhausted.
    for i < len(a) && j < len(b) {
        if a[i] <= b[j] {
            r[i+j] = a[i]
            i++
        } else {
            r[i+j] = b[j]
            j++
        }
    }
    // At most one of the inputs still has elements; copy the remainder.
    k := i + j
    k += copy(r[k:], a[i:])
    copy(r[k:], b[j:])
    return r
}

// Mergesort 合并两个数组
func Mergesort(items []int) []int {
    if len(items) < 2 {
        return items
    }
    var middle = len(items) / 2var a = Mergesort(items[:middle])
    var b = Mergesort(items[middle:])
    return merge(a, b)
}
  • 快速排序
// 三向切分快速排序 (quick sort)
package sorts

import (
    "math/rand"
)

// QuickSort returns the elements of arr in ascending order using a three-way
// quicksort with a randomly chosen pivot. Inputs with at most one element are
// returned unchanged; for longer inputs the result is built in newly
// allocated slices and arr is not modified.
func QuickSort(arr []int) []int {
    n := len(arr)
    if n <= 1 {
        return arr
    }

    pivot := arr[rand.Intn(n)]

    smaller := make([]int, 0, n)
    larger := make([]int, 0, n)
    equal := make([]int, 0, n)

    // Distribute every element into exactly one of the three buckets.
    for _, v := range arr {
        if v < pivot {
            smaller = append(smaller, v)
        } else if v > pivot {
            larger = append(larger, v)
        } else {
            equal = append(equal, v)
        }
    }

    // Sort both outer buckets, then concatenate smaller + equal + larger.
    sortedSmaller := QuickSort(smaller)
    sortedLarger := QuickSort(larger)

    result := append(sortedSmaller, equal...)
    return append(result, sortedLarger...)
}
  • 堆排序
// 堆排序 (heap sort)
package sorts

// maxHeap is a binary max-heap stored in the first heapSize elements of slice.
type maxHeap struct {
    slice    []int
    heapSize int
}

// buildMaxHeap turns slice into a max-heap in place and returns the heap
// descriptor covering the whole slice.
func buildMaxHeap(slice []int) maxHeap {
    heap := maxHeap{slice: slice, heapSize: len(slice)}
    for node := len(slice) / 2; node >= 0; node-- {
        heap.MaxHeapify(node)
    }
    return heap
}

// MaxHeapify restores the max-heap property for the subtree rooted at index i
// by repeatedly swapping the node with its larger child until neither child is
// greater. Only indices below the current heap size are considered.
func (h maxHeap) MaxHeapify(i int) {
    for {
        largest := i
        if left := 2*i + 1; left < h.size() && h.slice[left] > h.slice[largest] {
            largest = left
        }
        if right := 2*i + 2; right < h.size() && h.slice[right] > h.slice[largest] {
            largest = right
        }
        if largest == i {
            return
        }
        h.slice[i], h.slice[largest] = h.slice[largest], h.slice[i]
        i = largest
    }
}

// size reports how many leading elements of slice currently belong to the heap.
func (h maxHeap) size() int {
    return h.heapSize
}

// HeapSort sorts slice in place in ascending order and returns the same slice.
// It builds a max-heap and then repeatedly swaps the maximum to the end of the
// shrinking heap.
func HeapSort(slice []int) []int {
    heap := buildMaxHeap(slice)
    for last := len(heap.slice) - 1; last > 0; last-- {
        heap.slice[0], heap.slice[last] = heap.slice[last], heap.slice[0]
        // The element at last is now final; exclude it from the heap.
        heap.heapSize = last
        heap.MaxHeapify(0)
    }
    return heap.slice
}

image.png

参考文章

Go 排序 sort.Slice 及其他方法

常见排序算法总结和 Go 标准库排序源码分析

Go语言中文网 sort -- 排序算法

本文正在参加技术专题18期-聊聊Go语言框架