面试中,我们遇到很多排序算法的考核。
快速排序,堆排,希尔排序,插入排序,归并排序等等...
由于现在竞争激烈,面试官除了问这些经典的排序算法,偶尔也会考核一下候选人对很多工业排序的认知,比如MySQL的order by,PHP的array系列排序函数,Go的sort包,Java的Collections.sort()。
对排序这块,确实让人头疼,存在这么几个痛处:
- 找不到比较规范的算法
- 很难完全理解规范的算法
- 好不容易学会了又会忘记
- 和工作中的内容不搭边
- 由于上面几点,面试的排序怎么都准备不踏实
这里我介绍一个比较实用的排序学习方向,也是今天要探讨的一部分:
算法知识 -> 普通练手 -> 工业源码 -> 工作场景排序
- 算法知识:特定排序对应的知识,比如插入排序的规范写法(教材),原理等
- 普通练手:理解了算法后,在编辑器中敲出来并运行,配合LeetCode之类
- 工业源码:去了解你熟悉的语言包如何实现排序,因为它们经受住了工业级别的应用,显然非常有价值
- 工作场景:现实场景中的排序,比如1000万用户的多维度排序,限定内存的多文件大文件排序等
只要踏踏实实掌握上面4点,排序算法这块就可以安心了,即使你会忘记一些细节,你在跟别人讲述排序的时候,别人都可以感受到你「由点及面」的能力。
今天我们就第三点「工业源码」的排序展开学习,由于我最近使用Go比较多,也比较生疏,正好就拿它的排序包sort来学习了。
读源码是一项脑力+体力活,难度确实不小。
作为工程师,「拆解问题」是我们的本能,对于今天这个特殊的主题:「Go语言Sort包中的工程级排序算法」,我们将它拆解成不同粒度的子问题,一直到可以理解。
我会从入口开始,一层一层分析,直接最后理解了整个排序,希望也能帮助到你。
开始:
从使用层面的粒度:
// Sort sorts data.
// It makes one call to data.Len to determine n, and O(n*log(n)) calls to
// data.Less and data.Swap. The sort is not guaranteed to be stable.
func Sort(data Interface) {
n := data.Len()
quickSort(data, 0, n, maxDepth(n))
}
注释里写到:Sort不是一个稳定排序算法,也就是说排序后有可能会破坏数据的稳定性,比如 [5, 1, 2, 3, 5],经过不稳定排序后,有可能最后一个元素5会跑到第一个5的前面,变成[1, 2, 3, 5(原来的最后一个), 5(原来的第一个)]
本篇关注的重点是工程级的排序算法如何实现,所以假设各位都清楚了排序相关的一些前置知识,包括:时间复杂度分析,插入排序,希尔排序,堆排序,快速排序和Go语言的基本语法,如果还不是很清楚这些知识点的话,建议先逐个击破,再回到该篇一起探讨这个算法,会更高效。
说明一点,源码的注释我都保留了,以免因为我的片面理解歪曲了作者的本意。
进入正题,先看一下Interface的结构:
// A type, typically a collection, that satisfies sort.Interface can be
// sorted by the routines in this package. The methods require that the
// elements of the collection be enumerated by an integer index.
type Interface interface {
// Len is the number of elements in the collection.
Len() int
// Less reports whether the element with
// index i should sort before the element with index j.
Less(i, j int) bool
// Swap swaps the elements with indexes i and j.
Swap(i, j int)
}
经典三大配件,无需多言。
熟悉排序的我们知道,排序过程中,比较和交换是必要环节,比较可以判断是否需要交换,而交换会减少逆序度(人话:数据集合会变得有序一些)。
了解了data要求的结构之后,我们就可以进入Sort这个入口函数了,但在这提醒一点,Less()的实现完全是可以由我们来自定义的,所以对于复杂对象的比较,比如 注册时间+用户年龄的多维度排序,我们可以自定义Less()。
往下走:
入口函数告诉我们,会调用一次data.Len来获取数据的长度,并且以O(n*log(n))的规模来调用Less和Swap,像这种说明,先看一眼,不去细究。
重点来了:
quickSort(data, 0, n, maxDepth(n))
这个 maxDepth(n)是干什么用的? 我们先大概看一眼它的功能,不要太深究背后的数学原理。
根据数据规模,不断二分,累计深度。
// maxDepth returns a threshold at which quicksort should switch
// to heapsort. It returns 2*ceil(lg(n+1)).
func maxDepth(n int) int {
var depth int
for i := n; i > 0; i >>= 1 {
depth++
}
//下面这行不要纠结,后面会解释
return depth * 2
}
注释告诉我们:maxDepth返回一个适当的阈值,用于判断当前的排序算法应不应该由快排切换到堆排,这里maxDepth的作用就是限制递归的深度。递归的空间成本很高,面对巨大规模的数据,递归还有内存溢出风险,maxDepth的存在规避了风险。 这里我们可能会好奇最后return的时候为什么乘于2,没事,后面进入快排的时候,我会解释一下。
看完maxDepth,我们回到quickSort。
func quickSort(data Interface, a, b, maxDepth int) {
for b-a > 12 { // Use ShellSort for slices <= 12 elements
if maxDepth == 0 {
heapSort(data, a, b)
return
}
maxDepth--
mlo, mhi := doPivot(data, a, b)
// Avoiding recursion on the larger subproblem guarantees
// a stack depth of at most lg(b-a).
if mlo-a < b-mhi {
quickSort(data, a, mlo, maxDepth)
a = mhi // i.e., quickSort(data, mhi, b)
} else {
quickSort(data, mhi, b, maxDepth)
b = mlo // i.e., quickSort(data, a, mlo)
}
}
if b-a > 1 {
// Do ShellSort pass with gap 6
// It could be written in this simplified form cause b-a <= 12
for i := a + 6; i < b; i++ {
if data.Less(i, i-6) {
data.Swap(i, i-6)
}
}
insertionSort(data, a, b)
}
}
这里代码比较多行,让我们先屏蔽一些细节,只关注思路,大概是这个样子,请留意注释:
func quickSort(data Interface, a, b, maxDepth int) {
//a代表区间起点,b代表区间终点
//在区间规模大于12个元素的时候,用快排
for b-a > 12 {
//每次循环,把递归深度阈值减1
//当递归深度阈值被超过时,用堆排序来解决
//快排操作:
//进行分区,也就是快排的pivot操作
//分区后,得到[a, mlo],pivot,[mhi, b]
//判断 [a, mlo]和[mhi, b],两者规模的大小
//小规模的区间,继续用quickSort
//大规模的区间,还需要继续分区,也就是继续被循环
}
//如果区间的规模小于等于12个元素的时候,会进入这个分支,进行希尔排序
//必要的边界判断不能少
if b-a > 1 {
//希尔排序
}
}
这里我们要把前面的疑问「maxDepth最后return的时候为什么乘于2」解决掉,因为快排是一个分治算法,它总是把区间一分为二(好吧其实是一分为三,但中间的区间不存在递归调用,可以忽略),也就是说每次递归的时候,一个quickSort会调用两次子quickSort,所以要乘于2。
好了,知道了整个排序算法的运行思路,我们可以选择性地深究某些细节。
源码中,希尔排序和堆排序是比较经典的实现(实现得非常好),没有很烧脑的地方,就不在这里展开。
下半篇的重点是剖析快排这一部分的实现,提前说明,先不要纠结细节,按一层一层的思路搞清楚就行。
接下来先看 数据规模大于12个元素的时候的情况,留意注释:
for b-a > 12 { // Use ShellSort for slices <= 12 elements
//忽略堆排
//重点看doPivot的实现,工业级的分区厉害在哪
mlo, mhi := doPivot(data, a, b)
....
}
好的,那我们去理解doPivo的实现。 高能预警,代码非常的长,所以看一眼后,我们来分块探讨,暂且不贴原注释,我们待会分块再贴回来。
func doPivot(data Interface, lo, hi int) (midlo, midhi int) {
m := int(uint(lo+hi) >> 1)
//取pivot的一个算法技巧,不要纠结,先过
if hi-lo > 40 {
s := (hi - lo) / 8
medianOfThree(data, lo, lo+s, lo+2*s)
medianOfThree(data, m, m-s, m+s)
medianOfThree(data, hi-1, hi-1-s, hi-1-2*s)
}
medianOfThree(data, lo, m, hi-1)
pivot := lo
a, c := lo+1, hi-1
for ; a < c && data.Less(a, pivot); a++ {
}
b := a
for {
for ; b < c && !data.Less(pivot, b); b++ { // data[b] <= pivot
}
for ; b < c && data.Less(pivot, c-1); c-- { // data[c-1] > pivot
}
if b >= c {
break
}
// data[b] > pivot; data[c-1] <= pivot
data.Swap(b, c-1)
b++
c--
}
//这里可能会有很多问号??? 还是那句话,先不要纠结。
protect := hi-c < 5
if !protect && hi-c < (hi-lo)/4 {
dups := 0
if !data.Less(pivot, hi-1) { // data[hi-1] = pivot
data.Swap(c, hi-1)
c++
dups++
}
if !data.Less(b-1, pivot) { // data[b-1] = pivot
b--
dups++
}
if !data.Less(m, pivot) { // data[m] = pivot
data.Swap(m, b-1)
b--
dups++
}
// if at least 2 points are equal to pivot, assume skewed distribution
protect = dups > 1
}
if protect {
for {
for ; a < b && !data.Less(b-1, pivot); b-- { // data[b] == pivot
}
for ; a < b && data.Less(a, pivot); a++ { // data[a] < pivot
}
if a >= b {
break
}
// data[a] == pivot; data[b-1] < pivot
data.Swap(a, b-1)
a++
b--
}
}
data.Swap(pivot, b-1)
return b - 1, c
}
有没有把你劝退了? 别走,我们先拆再说。
先看如何取中(开篇建议大家先了解快排的思想,不然读到这里的时候就很吃力了),请阅读源码注释和我的注释
func doPivot(data Interface, lo, hi int) (midlo, midhi int) {
// 下标是非负整数,所以用uint可以容纳 2*int的正整数,也就避免了整形溢出(想象lo和hi都是巨大的整数)
//lo代表low,低位, hi代表hign,高位
//再仔细点解释,就是区间的起点和终点
m := int(uint(lo+hi) >> 1) // Written like this to avoid integer overflow.
//如果区间的规模大于40个元素,取中需要加一些技巧
//这里用到了一个算法,叫John Tukey's median of medians
//给个学习链接:https://www.johndcook.com/blog/2009/06/23/tukey-median-ninther/
//这里注意,三中取中
//老话,别纠结。
if hi-lo > 40 {
// Tukey's ``Ninther,'' median of three medians of three.
s := (hi - lo) / 8
medianOfThree(data, lo, lo+s, lo+2*s)
medianOfThree(data, m, m-s, m+s)
medianOfThree(data, hi-1, hi-1-s, hi-1-2*s)
}
medianOfThree(data, lo, m, hi-1)
// Invariants are:
//经过上面的调整,data[lo]此时就是pivot的值
// data[lo] = pivot (set up by ChoosePivot)
这块没问题之后,我们继续分析在下一个代码块,请留意注释:
// Invariants are:
//经过上面的调整,data[lo]此时就是pivot的值
// data[lo] = pivot (set up by ChoosePivot)
//这里告知了我们,分区后的数据集合的区间 a, b, c就是几个分割点,先不要纠结,在下面会有计算a, b, c
//[lo], data[lo] 是pivot
//(lo, a), 这个区间的值 < pivot
//[a, b), 这个区间的值 <= pivot
//[b, c), 这个区间的值 未知
//[c, hi-1), 这个区间的值 > pivot
//[hi-1], data[hi-1] > pivot
// data[lo < i < a] < pivot
// data[a <= i < b] <= pivot
// data[b <= i < c] unexamined
// data[c <= i < hi-1] > pivot
// data[hi-1] >= pivot
pivot := lo
//a, c目前是区间的起点和终点,也就两个游标
a, c := lo+1, hi-1
//把a游标,右移到第一个不小于pivot的下标位置上,然后停止
for ; a < c && data.Less(a, pivot); a++ {
}
b := a
继续往下走:
//只要b游标和C游标没有重合,就需要执行pivot的过程
for {
//把b游标往右挪到大于data[pivot]的位置,准备好交换
for ; b < c && !data.Less(pivot, b); b++ {
}
//把c游标往左挪到小于等于data[pivot]的位置,准备好和b交换
for ; b < c && data.Less(pivot, c-1); c-- {
}
//万一挪出了幺蛾子,就作罢
if b >= c {
break
}
// data[b] > pivot; data[c-1] <= pivot
//我们知道左边分区要放小于等于pivot的值,右边放大于pivot的值
data.Swap(b, c-1)
//交换后处理游标,继续按各自的原方向移动
b++
c--
}
上面这段好理解,快排无非就是这个思路。但上面也有一个重点信息:遇到了和data[pivot]的值,会停下来并进行交换,如果数据中有大量重复值的元素(和data[povit]相等),那么大量元素会累积在左边的分区。这个分区的时间复杂度就很可能远大于O(nLogn)。
不好理解的在下面这一段代码,我暂时没能找到这段代码的算法来源,但不妨碍我们理解它的功能:
- 测试分区的平衡性
- 调整分区的平衡性
我们本来的想法是:分成 小 pivot 大, 三个区。 但是我们没有考虑到大量重复值的元素。
目前看来最好的方式,大量重复值应该放在pivot附近,最好形成一个「等于」的区间,这样递归的时候,就可以避开中间这个可能很长的区间了。
也就加速了速度。
接下来请着重留意源码注释和我的注解:
// If hi-c<3 then there are duplicates (by property of median of nine).
// Let's be a bit more conservative, and set border to 5.
//发挥想象力!如果数据有大量等于pivot的重复元素,那么分区后,左区就会相对很大,右区反而很小,那么hi-c的值肯定会比较小。
//所以根据源码作者的实际经验?保守估计5是个值得注意的差值
protect := hi-c < 5
//右区间小于5个元素?不行不行,测一下有没有很多重复值
//如果右区占整个区间的25%都不到,那就说明很有重复值太多了了,把protect标记一下,代表需要调整区间
if !protect && hi-c < (hi-lo)/4 {
// Lets test some points for equality to pivot
dups := 0
if !data.Less(pivot, hi-1) { // data[hi-1] = pivot
data.Swap(c, hi-1)
c++
dups++
}
if !data.Less(b-1, pivot) { // data[b-1] = pivot
b--
dups++
}
// m-lo = (hi-lo)/2 > 6
// b-lo > (hi-lo)*3/4-1 > 8
// ==> m < b ==> data[m] <= pivot
if !data.Less(m, pivot) { // data[m] = pivot
data.Swap(m, b-1)
b--
dups++
}
// if at least 2 points are equal to pivot, assume skewed distribution
protect = dups > 1
}
//如果需要调整区间,执行如下操作
//把左区间重复的pivot元素,交换到中间位置,这样左区的长度就小了
if protect {
// Protect against a lot of duplicates
// Add invariant:
// data[a <= i < b] unexamined
// data[b <= i < c] = pivot
for {
for ; a < b && !data.Less(b-1, pivot); b-- { // data[b] == pivot
}
for ; a < b && data.Less(a, pivot); a++ { // data[a] < pivot
}
if a >= b {
break
}
// data[a] == pivot; data[b-1] < pivot
data.Swap(a, b-1)
a++
b--
}
}
过了这一通厮杀,我们来到最后
// Swap pivot into middle
data.Swap(pivot, b-1)
return b - 1, c
}
终于分区结束,可以返回左区间的终点和右区间的起点,并结束这一个子递归了。
不容易,能坚持看到这的地方的同学一定是:
- 很赏脸
- 很有钻研精神
- 真的找不到好的关于这个主题的文章 (比如我,只好自己分析)
其实说句题外话,这个排序算法也说明了几个现实规律:
- 改善重点部分,能大幅度提高整体效率
- 问题的规模不同,优选的处理方式也会不同
- 前人的知识财产很宝贵,要加以利用
以上分析得不够全面仔细,也肯定存在纰漏瑕疵,希望抛砖引玉,带来帮助。