Go 标准库排序原理分析
排序的主要代码在 sort.go 这个文件里。实现的排序算法有: 插入排序(insertionSort)、堆排序(heapSort)、快速排序(quickSort)、希尔排序(quickSort 内部一趟间隔为 6 的 ShellSort,并非独立函数)和归并排序(symMerge)。
sort 包根据稳定性,将排序方法分为两类: 不稳定排序和稳定排序。
不稳定排序
不稳定排序在大多数情况下是通过 **“希尔排序” + “堆排序” + “快速排序”** 的融合改进版实现的。
下面是 Go 1.17.5 中的实现。注意: 从 Go 1.19 起,sort.Sort 已改用 pdqsort(pattern-defeating quicksort),下面的代码不再是最新实现。
// Sort sorts data in ascending order as defined by data.Less.
// data.Len is called exactly once; data.Less and data.Swap are called
// O(n*log(n)) times. Equal elements may end up reordered (not stable).
func Sort(data Interface) {
	length := data.Len()
	depthLimit := maxDepth(length)
	quickSort(data, 0, length, depthLimit)
}
// maxDepth returns the depth budget after which quickSort abandons
// partitioning and falls back to heapSort: two levels per bit needed to
// represent n, i.e. 2*ceil(lg(n+1)). Non-positive n yields 0.
func maxDepth(n int) int {
	depth := 0
	for ; n > 0; n >>= 1 {
		depth += 2
	}
	return depth
}
// quickSort sorts the half-open range data[a:b] (a is the first index, b is
// one past the last). maxDepth is the remaining depth budget; once it reaches
// zero the current range is finished with heapSort instead.
func quickSort(data Interface, a, b, maxDepth int) {
// While more than 12 elements remain, partition with quicksort (or fall
// back to heapsort once the depth budget is used up).
for b-a > 12 {
// Depth budget exhausted: sort the rest of this range with heapsort.
if maxDepth == 0 {
heapSort(data, a, b)
return
}
// Each loop iteration uses up one level of the depth budget, exactly as
// a recursive call would.
maxDepth--
// Partition data[a:b] with doPivot, which switches to a three-way
// partition when it detects many elements equal to the pivot.
// On return:
//   data[a:mlo]   <= pivot
//   data[mlo:mhi] == pivot
//   data[mhi:b]   >= pivot
// so only data[a:mlo] and data[mhi:b] still need sorting.
mlo, mhi := doPivot(data, a, b)
// Recurse on the smaller side and loop on the larger one. Every recursive
// call then covers at most half of the current range, which keeps the
// call stack shallow; the larger side is handled by the next iteration.
if mlo-a < b-mhi {
quickSort(data, a, mlo, maxDepth)
a = mhi // i.e., quickSort(data, mhi, b)
} else {
quickSort(data, mhi, b, maxDepth)
b = mlo // i.e., quickSort(data, a, mlo)
}
}
// At most 12 elements remain (either the input was small or the loop above
// narrowed the range). Finish with one gap-6 Shell sort pass followed by a
// plain insertion sort.
if b-a > 1 {
// Do ShellSort pass with gap 6.
// It can be written in this simplified form because b-a <= 12: every
// element has at most one partner 6 positions to its left.
for i := a + 6; i < b; i++ {
if data.Less(i, i-6) {
data.Swap(i, i-6)
}
}
insertionSort(data, a, b)
}
}
堆排序
先构建最大堆(通过 siftDown 对堆进行调整),再依次把堆顶(当前最大值)交换到末尾并缩小堆的范围。
// heapSort sorts data[a:b] with heapsort. The heap is addressed relative to
// offset a: heap index k corresponds to data[a+k].
func heapSort(data Interface, a, b int) {
	size := b - a

	// Heapify: sift down every internal node, from the last parent up to the
	// root, so that the greatest element ends up at data[a].
	for parent := (size - 1) / 2; parent >= 0; parent-- {
		siftDown(data, parent, size, a)
	}

	// Repeatedly move the current maximum to the end of the shrinking heap
	// and repair the heap from the root.
	for end := size - 1; end >= 0; end-- {
		data.Swap(a, a+end)
		siftDown(data, 0, end, a)
	}
}

// siftDown restores the max-heap property for the heap whose indices run
// over [lo, hi), starting at heap index lo: the element there moves down
// until neither child is larger. first is the offset of heap index 0 within
// data, so heap index k lives at data[first+k].
func siftDown(data Interface, lo, hi, first int) {
	for node := lo; ; {
		left := 2*node + 1
		if left >= hi {
			return
		}
		// Pick the larger child; the right child only counts if it exists.
		bigger := left
		if right := left + 1; right < hi && data.Less(first+left, first+right) {
			bigger = right
		}
		if !data.Less(first+node, first+bigger) {
			return
		}
		data.Swap(first+node, first+bigger)
		node = bigger
	}
}
快速排序
快排最坏时间复杂度是 $O(n^2)$。最坏情况是每次切分都极不均衡: 元素可能全部落在大于分割点的部分,也可能全部落在不大于分割点的部分。所以选择合适的分割点是很必要的。
doPivot 在切分之前,先用 medianOfThree 取三个样本的中位数(区间长度大于 40 时取九个样本,即 Tukey's ninther)作为分割点,放在切片首位 data[lo]。然后把不大于 data[pivot] 的数据放在 [lo, b) 区间,把大于 data[pivot] 的数据放在 [c, hi-1) 区间(其中 data[hi-1] >= data[pivot])。
之后,该算法又估算了等于 data[pivot] 的元素数量: 如果数量过多,则把与 data[pivot] 相等的数据集中到中间区间 [b, c)。最后把 data[pivot] 交换到 b-1 的位置。
至此,数据被切分成三个区间: data[lo, b-1)(不大于 pivot)、data[b-1, c)(等于 pivot)和 data[c, hi)(不小于 pivot)。
// doPivot partitions data[lo:hi] around a pivot chosen by median-of-three
// (Tukey's ninther when hi-lo > 40) and returns midlo, midhi such that
//
//   data[lo:midlo]    <= pivot
//   data[midlo:midhi] == pivot (data[midlo] holds the pivot itself)
//   data[midhi:hi]    >= pivot
//
// so the caller only has to sort data[lo:midlo] and data[midhi:hi].
// When the scan suggests many elements equal to the pivot, the equal
// elements are gathered into the middle range (three-way partition).
func doPivot(data Interface, lo, hi int) (midlo, midhi int) {
m := int(uint(lo+hi) >> 1) // Written like this to avoid integer overflow.
if hi-lo > 40 {
// Tukey's ``Ninther,'' median of three medians of three.
// Move the median of three samples (spacing s) into each of data[lo],
// data[m] and data[hi-1].
s := (hi - lo) / 8
medianOfThree(data, lo, lo+s, lo+2*s)
medianOfThree(data, m, m-s, m+s)
medianOfThree(data, hi-1, hi-1-s, hi-1-2*s)
}
// Move the median of data[lo], data[m], data[hi-1] into data[lo]; this also
// leaves data[m] <= data[lo] <= data[hi-1].
medianOfThree(data, lo, m, hi-1)
// Invariants are:
// data[lo] = pivot (set up by the medianOfThree call above)
// data[lo < i < a] < pivot
// data[a <= i < b] <= pivot
// data[b <= i < c] unexamined
// data[c <= i < hi-1] > pivot
// data[hi-1] >= pivot
pivot := lo
a, c := lo+1, hi-1
// Skip the leading run of elements strictly less than the pivot.
for ; a < c && data.Less(a, pivot); a++ {
}
b := a
// Hoare-style scan: grow the <= side from the left and the > side from the
// right, swapping each out-of-place pair, until the two scans meet.
for {
for ; b < c && !data.Less(pivot, b); b++ { // data[b] <= pivot
}
for ; b < c && data.Less(pivot, c-1); c-- { // data[c-1] > pivot
}
if b >= c {
break
}
// data[b] > pivot; data[c-1] <= pivot
data.Swap(b, c-1)
b++
c--
}
// If hi-c<3 then there are duplicates (by property of median of nine).
// Let's be a bit more conservative, and set border to 5.
protect := hi-c < 5
// If the > pivot side is small but not tiny (under a quarter of the range),
// probe three positions for equality with the pivot before deciding.
if !protect && hi-c < (hi-lo)/4 {
// Let's test some points for equality to pivot.
dups := 0
if !data.Less(pivot, hi-1) { // data[hi-1] = pivot
data.Swap(c, hi-1)
c++
dups++
}
if !data.Less(b-1, pivot) { // data[b-1] = pivot
b--
dups++
}
// m-lo = (hi-lo)/2 > 6
// b-lo > (hi-lo)*3/4-1 > 8
// ==> m < b ==> data[m] <= pivot
if !data.Less(m, pivot) { // data[m] = pivot
data.Swap(m, b-1)
b--
dups++
}
// if at least 2 points are equal to pivot, assume skewed distribution
protect = dups > 1
}
if protect {
// Protect against a lot of duplicates.
// Rescan data[a:b] (all <= pivot) and move elements equal to the pivot
// to its right end, growing the equal block that ends at c.
// Add invariant:
// data[a <= i < b] unexamined
// data[b <= i < c] = pivot
for {
for ; a < b && !data.Less(b-1, pivot); b-- { // data[b-1] == pivot
}
for ; a < b && data.Less(a, pivot); a++ { // data[a] < pivot
}
if a >= b {
break
}
// data[a] == pivot; data[b-1] < pivot
data.Swap(a, b-1)
a++
b--
}
}
// Swap pivot into middle: move it from data[lo] to data[b-1], the slot just
// left of the equal block, so that data[b-1:c] are all equal to the pivot.
data.Swap(pivot, b-1)
return b - 1, c
}
// medianOfThree orders the elements at indices m0, m1 and m2 so that
// data[m0] <= data[m1] <= data[m2]; in particular the median of the three
// ends up at data[m1]. Note the parameter order: m1 comes first.
func medianOfThree(data Interface, m1, m0, m2 int) {
	// fixPair swaps data[hi] and data[lo] when data[hi] < data[lo] and
	// reports whether it swapped.
	fixPair := func(hi, lo int) bool {
		if !data.Less(hi, lo) {
			return false
		}
		data.Swap(hi, lo)
		return true
	}

	fixPair(m1, m0) // now data[m0] <= data[m1]
	if fixPair(m2, m1) {
		// data[m2] now holds the largest value, but the new data[m1] may
		// have dropped below data[m0].
		fixPair(m1, m0)
	}
}
希尔排序(改进的插入排序)
以 6 作为间隔,先做一趟希尔排序(此时 b-a <= 12,每个元素最多只需与前面相距 6 的一个元素比较);然后以 1 作为间隔,做一次插入排序。
// Excerpt of quickSort: only the final small-range step is shown here.
func quickSort(data Interface, a, b, maxDepth int) {
// (The partitioning loop that runs while b-a > 12 is elided.)
...
// Here b-a <= 12: one gap-6 Shell sort pass, then a plain insertion sort.
if b-a > 1 {
// Do ShellSort pass with gap 6.
// It can be written in this simplified form because b-a <= 12.
for i := a + 6; i < b; i++ {
if data.Less(i, i-6) {
data.Swap(i, i-6)
}
}
insertionSort(data, a, b)
}
}
// insertionSort sorts data[a:b] in place: each element in turn is swapped
// leftwards, one position at a time, while it is less than its left
// neighbour (and has not reached index a).
func insertionSort(data Interface, a, b int) {
	for next := a + 1; next < b; next++ {
		j := next
		for j > a && data.Less(j, j-1) {
			data.Swap(j, j-1)
			j--
		}
	}
}
稳定排序
sort 包中的稳定排序使用 **“归并排序” + “插入排序”** 的改进版。首先,它把 slice 按每 blockSize=20 个元素划分为一个块(block),对每个块进行插入排序;然后循环合并相邻的两个块,每轮循环 blockSize 扩大为原来的两倍,直到 blockSize >= n 为止。
// Stable sorts data in ascending order while preserving the relative order
// of equal elements.
//
// data.Len is called once to determine n; data.Less is called O(n*log(n))
// times and data.Swap O(n*log(n)*log(n)) times.
func Stable(data Interface) {
	n := data.Len()
	stable(data, n)
}
// stable sorts data[0:n] stably. It first insertion-sorts consecutive runs
// of 20 elements, then repeatedly merges neighbouring runs with symMerge,
// doubling the run length each round until one run covers all n elements.
func stable(data Interface, n int) {
	run := 20 // initial run length; must be > 0

	// Phase 1: insertion-sort every full run, then the (possibly empty) tail.
	start := 0
	for ; start+run <= n; start += run {
		insertionSort(data, start, start+run)
	}
	insertionSort(data, start, n)

	// Phase 2: bottom-up merging of adjacent runs.
	for ; run < n; run *= 2 {
		lo := 0
		for ; lo+2*run <= n; lo += 2 * run {
			symMerge(data, lo, lo+run, lo+2*run)
		}
		// Merge a trailing full run with a shorter remainder, if there is one.
		if mid := lo + run; mid < n {
			symMerge(data, lo, mid, n)
		}
	}
}
归并排序
// symMerge merges the two consecutive sorted runs data[a:m] and data[m:b]
// into a single sorted run data[a:b], in place. stable relies on it to keep
// equal elements in their original relative order.
//
// Two single-element shortcuts insert the lone element with a binary search
// followed by a chain of adjacent swaps. The general case finds a split
// point by binary search, exchanges the middle blocks with rotate (defined
// elsewhere, not shown here) and recurses on both halves around mid.
func symMerge(data Interface, a, m, b int) {
// Avoid unnecessary recursions of symMerge
// by direct insertion of data[a] into data[m:b]
// if data[a:m] only contains one element.
if m-a == 1 {
// Use binary search to find the lowest index i
// such that data[i] >= data[a] for m <= i < b.
// Exit the search loop with i == b in case no such index exists.
// Stopping before equal elements keeps data[a] ahead of them (stability).
i := m
j := b
for i < j {
h := int(uint(i+j) >> 1)
if data.Less(h, a) {
i = h + 1
} else {
j = h
}
}
// Swap values until data[a] reaches the position before i.
for k := a; k < i-1; k++ {
data.Swap(k, k+1)
}
return
}
// Avoid unnecessary recursions of symMerge
// by direct insertion of data[m] into data[a:m]
// if data[m:b] only contains one element.
if b-m == 1 {
// Use binary search to find the lowest index i
// such that data[i] > data[m] for a <= i < m.
// Exit the search loop with i == m in case no such index exists.
// Skipping past equal elements keeps data[m] behind them (stability).
i := a
j := m
for i < j {
h := int(uint(i+j) >> 1)
if !data.Less(m, h) {
i = h + 1
} else {
j = h
}
}
// Swap values until data[m] reaches the position i.
for k := m; k > i; k-- {
data.Swap(k, k-1)
}
return
}
// General case: mid splits [a, b) in half. With n = mid + m and p = n - 1,
// each candidate index c is compared against its mirror image p-c.
mid := int(uint(a+b) >> 1)
n := mid + m
// Set the search bounds [start, r) for the split point.
var start, r int
if m > mid {
start = n - b
r = mid
} else {
start = a
r = m
}
p := n - 1
// Binary search for the lowest c in [start, r) with data[p-c] < data[c].
for start < r {
c := int(uint(start+r) >> 1)
if !data.Less(p-c, c) {
start = c + 1
} else {
r = c
}
}
end := n - start
// rotate (not shown) is expected to exchange the blocks data[start:m] and
// data[m:end]; skipped when either block is empty.
if start < m && m < end {
rotate(data, start, m, end)
}
// Recursively merge the left half [a, mid) and the right half [mid, b).
if a < start && start < mid {
symMerge(data, a, start, mid)
}
if mid < end && end < b {
symMerge(data, mid, end, b)
}
}
插入排序
同不稳定排序中的希尔排序,当间隔变为1时为插入排序
// insertionSort sorts data[a:b] in place by insertion sort. For each
// position i, the element is swapped one step left until its left
// neighbour is no longer greater (or the start of the range is reached).
func insertionSort(data Interface, a, b int) {
	for i := a + 1; i < b; i++ {
		for j := i; j > a; j-- {
			if !data.Less(j, j-1) {
				break
			}
			data.Swap(j, j-1)
		}
	}
}
Go排序写法
自定义比较器写法(匿名函数)
// Person holds a name and an age; the example below sorts people by Age.
type Person struct {
	Name string
	Age  int
}

// main sorts family by ascending age, keeping people of equal age in their
// original order, and prints the result.
func main() {
	family := []Person{
		{"Alice", 23},
		{"David", 2},
		{"Eve", 2},
		{"Bob", 25},
	}
	// Sort by age (ascending), keeping the original order of equal elements.
	// Only sort.SliceStable guarantees that; sort.Slice is not stable and
	// may swap David and Eve, who have the same age.
	sort.SliceStable(family, func(i, j int) bool {
		return family[i].Age < family[j].Age
	})
	fmt.Println(family) // [{David 2} {Eve 2} {Alice 23} {Bob 25}]
}
实现sort.Interface写法
// Person holds a name and an age.
type Person struct {
	Name string
	Age  int
}

// Family is a slice of Person implementing sort.Interface, ordered by
// ascending Age.
type Family []Person

// Len reports how many people are in the family.
func (f Family) Len() int { return len(f) }

// Less reports whether f[i] is younger than f[j] (ascending by age).
func (f Family) Less(i, j int) bool { return f[i].Age < f[j].Age }

// Swap exchanges the people at positions i and j.
func (f Family) Swap(i, j int) { f[i], f[j] = f[j], f[i] }
func main() {
family := []Person{
{"Alice", 23},
{"David", 2},
{"Eve", 2},
{"Bob", 25},
}
sort.Sort(Family(family))
fmt.Println(family) //{David 2} {Eve 2} {Alice 23} {Bob 25}
}
常用数据类型写法
对于 []int, []float64, []string 基础类型的切片使用 sort 包提供的下面几个函数进行排序。
sort.Ints
sort.Float64s
sort.Strings
// main demonstrates the ready-made sort helpers for basic slice types.
func main() {
	nums := []int{4, 2, 3, 1}
	sort.Ints(nums)
	fmt.Println(nums) // [1 2 3 4]

	words := []string{"qwb", "qwa", "wwa"}
	sort.Strings(words)
	fmt.Println(words) // [qwa qwb wwa]

	values := []float64{12.31, 12.21, 13.54}
	sort.Float64s(values)
	fmt.Println(values) // [12.21 12.31 13.54]
}
附:
七大排序算法与Go实现
- 选择排序
// SelectionSort (selection sort) sorts arr in ascending order in place and
// returns arr. Each pass finds the smallest element of the unsorted suffix
// and swaps it to the front of that suffix.
// (Originally published as part of package sorts.)
func SelectionSort(arr []int) []int {
	for i := range arr {
		smallest := i
		for j := i + 1; j < len(arr); j++ {
			if arr[j] < arr[smallest] {
				smallest = j
			}
		}
		arr[i], arr[smallest] = arr[smallest], arr[i]
	}
	return arr
}
- 冒泡排序
// bubbleSort (bubble sort) sorts arr in ascending order in place and
// returns arr. It makes repeated passes, swapping adjacent out-of-order
// elements, and stops early after a pass with no swaps.
// (Originally published as part of package sorts.)
func bubbleSort(arr []int) []int {
	// After each pass the largest remaining element has bubbled to position
	// end, so the next pass can stop one element earlier.
	for end := len(arr) - 1; end > 0; end-- {
		swapped := false
		for i := 0; i < end; i++ {
			if arr[i+1] < arr[i] {
				arr[i+1], arr[i] = arr[i], arr[i+1]
				swapped = true
			}
		}
		if !swapped {
			break
		}
	}
	return arr
}
- 插入排序
// InsertionSort (insertion sort) sorts arr in ascending order in place and
// returns arr. Each element is shifted left past every strictly larger
// element before it; stopping at equal elements keeps the sort stable and
// avoids needless moves.
// (Originally published as part of package sorts.)
func InsertionSort(arr []int) []int {
	for i := 1; i < len(arr); i++ {
		current := arr[i]
		j := i
		// Strict '>' (not '>='): do not shift past an equal element.
		for ; j > 0 && arr[j-1] > current; j-- {
			arr[j] = arr[j-1]
		}
		arr[j] = current
	}
	return arr
}
- 希尔排序
// ShellSort (shell sort) sorts arr in ascending order in place and returns
// arr. It runs gapped insertion sorts with gaps len(arr)/2, len(arr)/4, ...,
// 1; the final gap-1 pass is an ordinary insertion sort.
// (Originally published as part of package sorts.)
func ShellSort(arr []int) []int {
	n := len(arr)
	for gap := n / 2; gap > 0; gap /= 2 {
		for i := gap; i < n; i++ {
			// Swap arr[i] backwards along its gap chain while its
			// predecessor in the chain is larger.
			for j := i; j >= gap; j -= gap {
				if arr[j-gap] <= arr[j] {
					break
				}
				arr[j], arr[j-gap] = arr[j-gap], arr[j]
			}
		}
	}
	return arr
}
- 归并排序
// merge returns a new slice holding the elements of the sorted slices a
// and b in ascending order. On ties the element from a is taken first, so
// the merge is stable.
// (Originally published as part of package sorts.)
func merge(a []int, b []int) []int {
	r := make([]int, 0, len(a)+len(b))
	i, j := 0, 0
	for i < len(a) && j < len(b) {
		if a[i] <= b[j] {
			r = append(r, a[i])
			i++
		} else {
			r = append(r, b[j])
			j++
		}
	}
	// At most one of the two tails is non-empty.
	r = append(r, a[i:]...)
	r = append(r, b[j:]...)
	return r
}

// Mergesort sorts items in ascending order with top-down merge sort and
// returns the result. Slices of length 0 or 1 are returned as-is; otherwise
// a newly allocated slice is returned and items itself is not reordered.
func Mergesort(items []int) []int {
	if len(items) < 2 {
		return items
	}
	middle := len(items) / 2
	left := Mergesort(items[:middle])
	right := Mergesort(items[middle:])
	return merge(left, right)
}
- 快速排序
// 三向切分快速排序 (quick sort)package sorts
import (
"math/rand"
)
// QuickSort returns the elements of arr in ascending order, using quicksort
// with a random pivot and three-way (less / equal / greater) partitioning.
// Inputs of length 0 or 1 are returned as-is; otherwise a new slice is
// returned and arr is left unmodified.
func QuickSort(arr []int) []int {
	if len(arr) <= 1 {
		return arr
	}
	pivot := arr[rand.Intn(len(arr))]
	// The partitions grow on demand. Pre-sizing each one to len(arr) would
	// reserve three times the input size at every recursion level.
	var lowPart, middlePart, highPart []int
	for _, item := range arr {
		switch {
		case item < pivot:
			lowPart = append(lowPart, item)
		case item > pivot:
			highPart = append(highPart, item)
		default:
			middlePart = append(middlePart, item)
		}
	}
	// The result size is known exactly, so allocate it once.
	sorted := make([]int, 0, len(arr))
	sorted = append(sorted, QuickSort(lowPart)...)
	sorted = append(sorted, middlePart...)
	sorted = append(sorted, QuickSort(highPart)...)
	return sorted
}
- 堆排序
// 堆排序 (heap sort)
package sorts
// maxHeap is a binary max-heap stored in slice. Only slice[:heapSize]
// belongs to the heap; during HeapSort the rest is the sorted tail.
type maxHeap struct {
	slice    []int
	heapSize int
}

// buildMaxHeap rearranges slice in place into a max-heap and returns a
// maxHeap view of it.
func buildMaxHeap(slice []int) maxHeap {
	h := maxHeap{slice: slice, heapSize: len(slice)}
	// Indices >= len/2 are leaves and already satisfy the heap property, so
	// start from the last internal node, len/2 - 1.
	for i := len(slice)/2 - 1; i >= 0; i-- {
		h.MaxHeapify(i)
	}
	return h
}

// MaxHeapify sifts h.slice[i] down until the subtree rooted at i is a
// max-heap, assuming both child subtrees already are.
func (h maxHeap) MaxHeapify(i int) {
	for {
		l, r := 2*i+1, 2*i+2
		largest := i
		if l < h.size() && h.slice[l] > h.slice[largest] {
			largest = l
		}
		if r < h.size() && h.slice[r] > h.slice[largest] {
			largest = r
		}
		if largest == i {
			return
		}
		h.slice[i], h.slice[largest] = h.slice[largest], h.slice[i]
		i = largest
	}
}

// size reports the number of elements currently in the heap.
func (h maxHeap) size() int { return h.heapSize }

// HeapSort sorts slice in ascending order in place and returns it.
func HeapSort(slice []int) []int {
	h := buildMaxHeap(slice)
	for i := len(h.slice) - 1; i >= 1; i-- {
		// Move the current maximum just behind the heap, shrink the heap by
		// one, and repair the root.
		h.slice[0], h.slice[i] = h.slice[i], h.slice[0]
		h.heapSize--
		h.MaxHeapify(0)
	}
	return h.slice
}
参考文章
Go 排序 sort.Slice 及其他方法
常见排序算法总结和 Go 标准库排序源码分析
Go语言中文网 sort -- 排序算法
本文正在参加技术专题18期-聊聊Go语言框架