看懂工业级快速排序,打动面试官

710 阅读13分钟

面试中,我们遇到很多排序算法的考核。

快速排序,堆排,希尔排序,插入排序,归并排序等等...

由于现在竞争激烈,面试官除了问这些经典的排序算法,偶尔也会考核一下候选人对很多工业排序的认知,比如MySQL的order by,PHP的array系列排序函数,Go的sort包,Java的Collections.sort()。

对排序这块,确实让人头疼,存在这么几个痛处:

  • 找不到比较规范的算法
  • 很难完全理解规范的算法
  • 好不容易学会了又会忘记
  • 和工作中的内容不搭边
  • 由于上面几点,面试的排序怎么都准备不踏实

这里我介绍一个比较实用的排序学习方向,也是今天要探讨的一部分:

算法知识 -> 普通练手 -> 工业源码 -> 工作场景排序

  • 算法知识:特定排序对应的知识,比如插入排序的规范写法(教材),原理等
  • 普通练手:理解了算法后,在编辑器中敲出来并运行,配合LeetCode之类
  • 工业源码:去了解你熟悉的语言包如何实现排序,因为它们经受住了工业级别的应用,显然非常有价值
  • 工作场景:现实场景中的排序,比如1000万用户的多维度排序,限定内存的多文件大文件排序等

只要踏踏实实掌握上面4点,排序算法这块就可以安心了,即使你会忘记一些细节,你在跟别人讲述排序的时候,别人都可以感受到你「由点及面」的能力。

今天我们就第三点「工业源码」的排序展开学习,由于我最近使用Go比较多,也比较生疏,正好就拿它的排序包sort来学习了。

读源码是一项脑力+体力活,难度确实不小。

作为工程师,「拆解问题」是我们的本能,对于今天这个特殊的主题:「Go语言Sort包中的工程级排序算法」,我们将它拆解成不同粒度的子问题,一直到可以理解。

我会从入口开始,一层一层分析,直接最后理解了整个排序,希望也能帮助到你。

开始:

从使用层面的粒度:

// Sort sorts data.
// It makes one call to data.Len to determine n, and O(n*log(n)) calls to
// data.Less and data.Swap. The sort is not guaranteed to be stable.
func Sort(data Interface) {
	n := data.Len()
	quickSort(data, 0, n, maxDepth(n))
}

注释里写到:Sort不是一个稳定排序算法,也就是说排序后有可能会破坏数据的稳定性,比如 [5, 1, 2, 3, 5],经过不稳定排序后,有可能最后一个元素5会跑到第一个5的前面,变成[1, 2, 3, 5(原来的最后一个), 5(原来的第一个)]

本篇关注的重点是工程级的排序算法如何实现,所以假设各位都清楚了排序相关的一些前置知识,包括:时间复杂度分析,插入排序,希尔排序,堆排序,快速排序和Go语言的基本语法,如果还不是很清楚这些知识点的话,建议先逐个击破,再回到该篇一起探讨这个算法,会更高效。

说明一点,源码的注释我都保留了,以免因为我的片面理解歪曲了作者的本意。

进入正题,先看一下Interface的结构:

// A type, typically a collection, that satisfies sort.Interface can be
// sorted by the routines in this package. The methods require that the
// elements of the collection be enumerated by an integer index.
type Interface interface {
	// Len is the number of elements in the collection.
	Len() int
	// Less reports whether the element with
	// index i should sort before the element with index j.
	Less(i, j int) bool
	// Swap swaps the elements with indexes i and j.
	Swap(i, j int)
}

经典三大配件,无需多言。

熟悉排序的我们知道,排序过程中,比较和交换是必要环节,比较可以判断是否需要交换,而交换会减少逆序度(人话:数据集合会变得有序一些)。

了解了data要求的结构之后,我们就可以进入Sort这个入口函数了,但在这提醒一点,Less()的实现完全是可以由我们来自定义的,所以对于复杂对象的比较,比如 注册时间+用户年龄的多维度排序,我们可以自定义Less()。

往下走:

入口函数告诉我们,会调用一次data.Len来获取数据的长度,并且以O(n*log(n))的规模来调用Less和Swap,像这种说明,先看一眼,不去细究。

重点来了:

quickSort(data, 0, n, maxDepth(n))

这个 maxDepth(n)是干什么用的? 我们先大概看一眼它的功能,不要太深究背后的数学原理。

根据数据规模,不断二分,累计深度。

// maxDepth returns a threshold at which quicksort should switch
// to heapsort. It returns 2*ceil(lg(n+1)).
func maxDepth(n int) int {
	var depth int
	for i := n; i > 0; i >>= 1 {
		depth++
	}
	//下面这行不要纠结,后面会解释
	return depth * 2
}

注释告诉我们:maxDepth返回一个适当的阈值,用于判断当前的排序算法应不应该由快排切换到堆排,这里maxDepth的作用就是限制递归的深度。递归的空间成本很高,面对巨大规模的数据,递归还有内存溢出风险,maxDepth的存在规避了风险。 这里我们可能会好奇最后return的时候为什么乘于2,没事,后面进入快排的时候,我会解释一下。

看完maxDepth,我们回到quickSort。

func quickSort(data Interface, a, b, maxDepth int) {
	for b-a > 12 { // Use ShellSort for slices <= 12 elements
		if maxDepth == 0 {
			heapSort(data, a, b)
			return
		}
		maxDepth--
		mlo, mhi := doPivot(data, a, b)
		// Avoiding recursion on the larger subproblem guarantees
		// a stack depth of at most lg(b-a).
		if mlo-a < b-mhi {
			quickSort(data, a, mlo, maxDepth)
			a = mhi // i.e., quickSort(data, mhi, b)
		} else {
			quickSort(data, mhi, b, maxDepth)
			b = mlo // i.e., quickSort(data, a, mlo)
		}
	}
	if b-a > 1 {
		// Do ShellSort pass with gap 6
		// It could be written in this simplified form cause b-a <= 12
		for i := a + 6; i < b; i++ {
			if data.Less(i, i-6) {
				data.Swap(i, i-6)
			}
		}
		insertionSort(data, a, b)
	}
}

这里代码比较多行,让我们先屏蔽一些细节,只关注思路,大概是这个样子,请留意注释:

func quickSort(data Interface, a, b, maxDepth int) {
    //a代表区间起点,b代表区间终点
    //在区间规模大于12个元素的时候,用快排
	for b-a > 12 { 
	    //每次循环,把递归深度阈值减1
		//当递归深度阈值被超过时,用堆排序来解决
		
		//快排操作:
		//进行分区,也就是快排的pivot操作
	    //分区后,得到[a, mlo],pivot,[mhi, b]
		//判断 [a, mlo]和[mhi, b],两者规模的大小
	
	    //小规模的区间,继续用quickSort
	    //大规模的区间,还需要继续分区,也就是继续被循环
	}
	
	//如果区间的规模小于等于12个元素的时候,会进入这个分支,进行希尔排序
	//必要的边界判断不能少
	if b-a > 1 {
        //希尔排序
	}
}

这里我们要把前面的疑问「maxDepth最后return的时候为什么乘于2」解决掉,因为快排是一个分治算法,它总是把区间一分为二(好吧其实是一分为三,但中间的区间不存在递归调用,可以忽略),也就是说每次递归的时候,一个quickSort会调用两次子quickSort,所以要乘于2。

好了,知道了整个排序算法的运行思路,我们可以选择性地深究某些细节。

源码中,希尔排序和堆排序是比较经典的实现(实现得非常好),没有很烧脑的地方,就不在这里展开。

下半篇的重点是剖析快排这一部分的实现,提前说明,先不要纠结细节,按一层一层的思路搞清楚就行。

接下来先看 数据规模大于12个元素的时候的情况,留意注释:

	for b-a > 12 { // Use ShellSort for slices <= 12 elements
		//忽略堆排
		//重点看doPivot的实现,工业级的分区厉害在哪
		mlo, mhi := doPivot(data, a, b)
		....
	}

好的,那我们去理解doPivo的实现。 高能预警,代码非常的长,所以看一眼后,我们来分块探讨,暂且不贴原注释,我们待会分块再贴回来。

func doPivot(data Interface, lo, hi int) (midlo, midhi int) {
	m := int(uint(lo+hi) >> 1) 
	//取pivot的一个算法技巧,不要纠结,先过
	if hi-lo > 40 {
		s := (hi - lo) / 8
		medianOfThree(data, lo, lo+s, lo+2*s)
		medianOfThree(data, m, m-s, m+s)
		medianOfThree(data, hi-1, hi-1-s, hi-1-2*s)
	}
	medianOfThree(data, lo, m, hi-1)

	pivot := lo

	a, c := lo+1, hi-1

	for ; a < c && data.Less(a, pivot); a++ {
	}
	b := a

	for {
		for ; b < c && !data.Less(pivot, b); b++ { // data[b] <= pivot
		}
		for ; b < c && data.Less(pivot, c-1); c-- { // data[c-1] > pivot
		}
		if b >= c {
			break
		}
		// data[b] > pivot; data[c-1] <= pivot
		data.Swap(b, c-1)
		b++
		c--
	}

    //这里可能会有很多问号??? 还是那句话,先不要纠结。
	protect := hi-c < 5
	if !protect && hi-c < (hi-lo)/4 {
		dups := 0
		if !data.Less(pivot, hi-1) { // data[hi-1] = pivot
			data.Swap(c, hi-1)
			c++
			dups++
		}
		if !data.Less(b-1, pivot) { // data[b-1] = pivot
			b--
			dups++
		}

		if !data.Less(m, pivot) { // data[m] = pivot
			data.Swap(m, b-1)
			b--
			dups++
		}
		// if at least 2 points are equal to pivot, assume skewed distribution
		protect = dups > 1
	}
	if protect {
		for {
			for ; a < b && !data.Less(b-1, pivot); b-- { // data[b] == pivot
			}
			for ; a < b && data.Less(a, pivot); a++ { // data[a] < pivot
			}
			if a >= b {
				break
			}
			// data[a] == pivot; data[b-1] < pivot
			data.Swap(a, b-1)
			a++
			b--
		}
	}
	data.Swap(pivot, b-1)
	return b - 1, c
}

有没有把你劝退了? 别走,我们先拆再说。

先看如何取中(开篇建议大家先了解快排的思想,不然读到这里的时候就很吃力了),请阅读源码注释和我的注释

func doPivot(data Interface, lo, hi int) (midlo, midhi int) {

    // 下标是非负整数,所以用uint可以容纳 2*int的正整数,也就避免了整形溢出(想象lo和hi都是巨大的整数)
    //lo代表low,低位, hi代表hign,高位
    //再仔细点解释,就是区间的起点和终点
    m := int(uint(lo+hi) >> 1) // Written like this to avoid integer overflow.
    
    //如果区间的规模大于40个元素,取中需要加一些技巧
    //这里用到了一个算法,叫John Tukey's median of medians
    //给个学习链接:https://www.johndcook.com/blog/2009/06/23/tukey-median-ninther/
    //这里注意,三中取中
    //老话,别纠结。
	if hi-lo > 40 {
		// Tukey's ``Ninther,'' median of three medians of three.
		s := (hi - lo) / 8
		medianOfThree(data, lo, lo+s, lo+2*s)
		medianOfThree(data, m, m-s, m+s)
		medianOfThree(data, hi-1, hi-1-s, hi-1-2*s)
	}
	medianOfThree(data, lo, m, hi-1)

	// Invariants are:
    //经过上面的调整,data[lo]此时就是pivot的值
	//	data[lo] = pivot (set up by ChoosePivot)

这块没问题之后,我们继续分析在下一个代码块,请留意注释:

// Invariants are:
     //经过上面的调整,data[lo]此时就是pivot的值
	//	data[lo] = pivot (set up by ChoosePivot)
	
	//这里告知了我们,分区后的数据集合的区间 a, b, c就是几个分割点,先不要纠结,在下面会有计算a, b, c
	//[lo],   data[lo] 是pivot
	//(lo, a), 这个区间的值 < pivot
	//[a, b),  这个区间的值 <= pivot
	//[b,  c),  这个区间的值 未知
	//[c, hi-1),  这个区间的值 > pivot
	//[hi-1], data[hi-1] > pivot
	
	//	data[lo < i < a] < pivot
	//	data[a <= i < b] <= pivot
	//	data[b <= i < c] unexamined
	//	data[c <= i < hi-1] > pivot
	//	data[hi-1] >= pivot
	
	
	pivot := lo
	//a, c目前是区间的起点和终点,也就两个游标
	a, c := lo+1, hi-1

    //把a游标,右移到第一个不小于pivot的下标位置上,然后停止
	for ; a < c && data.Less(a, pivot); a++ {
	}
	b := a

继续往下走:

//只要b游标和C游标没有重合,就需要执行pivot的过程
	for {
		//把b游标往右挪到大于data[pivot]的位置,准备好交换
		for ; b < c && !data.Less(pivot, b); b++ { 
		}
		//把c游标往左挪到小于等于data[pivot]的位置,准备好和b交换
		for ; b < c && data.Less(pivot, c-1); c-- { 
		}
		//万一挪出了幺蛾子,就作罢
		if b >= c {
			break
		}
		// data[b] > pivot; data[c-1] <= pivot
		//我们知道左边分区要放小于等于pivot的值,右边放大于pivot的值
		data.Swap(b, c-1)
		//交换后处理游标,继续按各自的原方向移动
		b++
		c--
	}

上面这段好理解,快排无非就是这个思路。但上面也有一个重点信息:遇到了和data[pivot]的值,会停下来并进行交换,如果数据中有大量重复值的元素(和data[povit]相等),那么大量元素会累积在左边的分区。这个分区的时间复杂度就很可能远大于O(nLogn)。

不好理解的在下面这一段代码,我暂时没能找到这段代码的算法来源,但不妨碍我们理解它的功能:

  • 测试分区的平衡性
  • 调整分区的平衡性

我们本来的想法是:分成 小 pivot 大, 三个区。 但是我们没有考虑到大量重复值的元素。

目前看来最好的方式,大量重复值应该放在pivot附近,最好形成一个「等于」的区间,这样递归的时候,就可以避开中间这个可能很长的区间了。

也就加速了速度。

接下来请着重留意源码注释和我的注解:

// If hi-c<3 then there are duplicates (by property of median of nine).
	// Let's be a bit more conservative, and set border to 5.
//发挥想象力!如果数据有大量等于pivot的重复元素,那么分区后,左区就会相对很大,右区反而很小,那么hi-c的值肯定会比较小。
//所以根据源码作者的实际经验?保守估计5是个值得注意的差值

	protect := hi-c < 5
	//右区间小于5个元素?不行不行,测一下有没有很多重复值
	//如果右区占整个区间的25%都不到,那就说明很有重复值太多了了,把protect标记一下,代表需要调整区间
	if !protect && hi-c < (hi-lo)/4 {
		// Lets test some points for equality to pivot
		dups := 0
		if !data.Less(pivot, hi-1) { // data[hi-1] = pivot
			data.Swap(c, hi-1)
			c++
			dups++
		}
		if !data.Less(b-1, pivot) { // data[b-1] = pivot
			b--
			dups++
		}
		// m-lo = (hi-lo)/2 > 6
		// b-lo > (hi-lo)*3/4-1 > 8
		// ==> m < b ==> data[m] <= pivot
		if !data.Less(m, pivot) { // data[m] = pivot
			data.Swap(m, b-1)
			b--
			dups++
		}
		// if at least 2 points are equal to pivot, assume skewed distribution
		protect = dups > 1
	}
	
    //如果需要调整区间,执行如下操作
    //把左区间重复的pivot元素,交换到中间位置,这样左区的长度就小了
	if protect {
		// Protect against a lot of duplicates
		// Add invariant:
		//	data[a <= i < b] unexamined
		//	data[b <= i < c] = pivot
		for {
			for ; a < b && !data.Less(b-1, pivot); b-- { // data[b] == pivot
			}
			for ; a < b && data.Less(a, pivot); a++ { // data[a] < pivot
			}
			if a >= b {
				break
			}
			// data[a] == pivot; data[b-1] < pivot
			data.Swap(a, b-1)
			a++
			b--
		}
	}

过了这一通厮杀,我们来到最后

// Swap pivot into middle
	data.Swap(pivot, b-1)
	return b - 1, c
}

终于分区结束,可以返回左区间的终点和右区间的起点,并结束这一个子递归了。

不容易,能坚持看到这的地方的同学一定是:

  • 很赏脸
  • 很有钻研精神
  • 真的找不到好的关于这个主题的文章 (比如我,只好自己分析)

其实说句题外话,这个排序算法也说明了几个现实规律:

  • 改善重点部分,能大幅度提高整体效率
  • 问题的规模不同,优选的处理方式也会不同
  • 前人的知识财产很宝贵,要加以利用

以上分析得不够全面仔细,也肯定存在纰漏瑕疵,希望抛砖引玉,带来帮助。