双轴快排(dual pivot quicksort)

517 阅读2分钟

双轴快速排序(dual-pivot quicksort)的 Go 实现,结合了直接插入排序、堆排序和三向切分:小区间使用直接插入排序,划分轮数过多时改用堆排序,采样元素存在重复时改用单轴三向切分。

在牛客网提交通过,对应题目

原始论文在这里

JDK实现在这里

// Sort is the entry point: it sorts arr in place in ascending order.
// Slices with fewer than two elements are returned untouched.
func Sort(arr []int) {
	if n := len(arr); n >= 2 {
		// Start at depth 0 and cover the whole slice, bounds inclusive.
		sort(arr, 0, 0, n-1)
	}
}

const (
	// maxInsertionSortSize is the range-size threshold used by sort: a range
	// holding fewer elements than this is handed to insertionSort instead of
	// being partitioned.
	maxInsertionSortSize = 44
	// maxRecursionDepth bounds the depth counter kept by sort: once the
	// counter exceeds this value the current range is handed to heapSort.
	maxRecursionDepth    = 64
)

// sort orders the inclusive range arr[left..right] in ascending order.
//
// Each round of the loop either finishes the range with a fallback sort or
// partitions it around pivots picked from five sampled elements:
//   - a range with fewer than maxInsertionSortSize elements goes to
//     insertionSort;
//   - depth is incremented once per partitioning round, and when it exceeds
//     maxRecursionDepth the range goes to heapSort;
//   - if the five sorted samples are strictly increasing, the range is split
//     into three parts around two pivots; otherwise a single pivot and a
//     3-way (less / equal / greater) split is used.
//
// The non-left parts are sorted by recursive calls; the left part is handled
// by the next round of the loop.
func sort(arr []int, depth, left, right int) {
	for {
		n := right - left + 1

		// Small (or empty) ranges are finished by insertion sort.
		if n < maxInsertionSortSize {
			insertionSort(arr, left, right)
			return
		}

		// Too many partitioning rounds: fall back to heap sort.
		depth++
		if depth > maxRecursionDepth {
			heapSort(arr, left, right)
			return
		}

		// Pick five sample positions e1..e5: e1 and e5 sit gap elements in
		// from each end, e3 is their midpoint, e2 and e4 are the midpoints
		// of the two halves.
		gap := (n>>3)*3 + 3
		e1, e5 := left+gap, right-gap
		e3 := e1 + (e5-e1)>>1
		e2 := e1 + (e3-e1)>>1
		e4 := e3 + (e5-e3)>>1

		// Order the samples at e1, e2, e4, e5 with five compare-exchange
		// steps: (e2,e5), (e1,e4), (e4,e5), (e1,e2), (e2,e4).
		if arr[e5] < arr[e2] {
			arr[e2], arr[e5] = arr[e5], arr[e2]
		}
		if arr[e4] < arr[e1] {
			arr[e1], arr[e4] = arr[e4], arr[e1]
		}
		if arr[e5] < arr[e4] {
			arr[e4], arr[e5] = arr[e5], arr[e4]
		}
		if arr[e2] < arr[e1] {
			arr[e1], arr[e2] = arr[e2], arr[e1]
		}
		if arr[e4] < arr[e2] {
			arr[e2], arr[e4] = arr[e4], arr[e2]
		}

		// Insert the middle sample into the four already ordered ones by
		// shifting its neighbours over, insertion-sort style.
		mid := arr[e3]
		switch {
		case mid < arr[e2]:
			arr[e3] = arr[e2]
			if mid < arr[e1] {
				arr[e2] = arr[e1]
				arr[e1] = mid
			} else {
				arr[e2] = mid
			}
		case mid > arr[e4]:
			arr[e3] = arr[e4]
			if mid > arr[e5] {
				arr[e4] = arr[e5]
				arr[e5] = mid
			} else {
				arr[e4] = mid
			}
		}

		// lo is the boundary of the "less than pivot" part. After either
		// branch, the left part still to be sorted is arr[left..lo-2].
		var lo int

		if arr[e1] < arr[e2] && arr[e2] < arr[e3] && arr[e3] < arr[e4] && arr[e4] < arr[e5] {
			// All five samples differ: partition around two pivots.
			// The pivots are taken out of the array and their slots are
			// refilled with the end elements, which frees arr[left] and
			// arr[right] to receive the boundary elements afterwards.
			p1, p2 := arr[e1], arr[e5]
			arr[e1], arr[e5] = arr[left], arr[right]

			lo = left + 1
			hi := right - 1

			// Skip leading elements below p1 and trailing elements above p2.
			for arr[lo] < p1 {
				lo++
			}
			for arr[hi] > p2 {
				hi--
			}

			// Invariants while scanning with cur:
			//   arr(left, lo)   <  p1
			//   arr[lo, cur)    in [p1, p2]
			//   arr[cur, hi]    not yet examined
			//   arr(hi, right)  >  p2
			for cur := lo; cur <= hi; cur++ {
				v := arr[cur]
				switch {
				case v < p1:
					// Move v to the end of the left part.
					arr[cur], arr[lo] = arr[lo], v
					lo++
				case v > p2:
					// Find a slot at the right end that does not already
					// hold a large element, then exchange.
					for cur < hi && arr[hi] > p2 {
						hi--
					}
					arr[cur], arr[hi] = arr[hi], v
					hi--

					// The element brought in may belong to the left part.
					if arr[cur] < p1 {
						arr[lo], arr[cur] = arr[cur], arr[lo]
						lo++
					}
				}
			}

			// Drop the pivots into their final positions lo-1 and hi+1.
			arr[left], arr[lo-1] = arr[lo-1], p1
			arr[right], arr[hi+1] = arr[hi+1], p2

			// Middle and right parts are sorted recursively.
			sort(arr, depth, lo, hi)
			sort(arr, depth, hi+2, right)
		} else {
			// Some samples are equal: partition around the single pivot
			// taken from e3 into less / equal / greater parts.
			p := arr[e3]
			arr[e3] = arr[left]

			lo = left + 1
			hi := right

			// Invariants while scanning with cur:
			//   arr(left, lo)   <  p
			//   arr[lo, cur)    == p
			//   arr[cur, hi]    not yet examined
			//   arr(hi, right]  >  p
			for cur := lo; cur <= hi; cur++ {
				v := arr[cur]
				switch {
				case v < p:
					arr[cur], arr[lo] = arr[lo], v
					lo++
				case v > p:
					for cur < hi && arr[hi] > p {
						hi--
					}
					arr[cur], arr[hi] = arr[hi], v
					hi--

					if arr[cur] < p {
						arr[lo], arr[cur] = arr[cur], arr[lo]
						lo++
					}
				}
			}

			// Put the pivot at the start of the equal part.
			arr[left], arr[lo-1] = arr[lo-1], p

			// Only the right part needs a recursive call: every element of
			// the central part equals the pivot and is already in place.
			sort(arr, depth, hi+1, right)
		}

		// Continue with the left part in the next round of the loop.
		right = lo - 2
	}
}

// insertionSort sorts the inclusive range arr[left..right] in place in
// ascending order. Elements outside the range are not touched, and a range
// with right <= left is left unchanged.
func insertionSort(arr []int, left, right int) {
	for next := left + 1; next <= right; next++ {
		key := arr[next]

		// Shift larger elements one slot to the right until the slot for
		// key is found; slot never moves below left.
		slot := next
		for ; slot > left && arr[slot-1] > key; slot-- {
			arr[slot] = arr[slot-1]
		}
		arr[slot] = key
	}
}

// heapSort sorts the inclusive range a[left..right] in place in ascending
// order using a max-heap laid out over that range.
//
// Heap positions are zero-based; heap position p is stored at a[left+p].
func heapSort(a []int, left, right int) {
	// last is the heap position of the final element (size - 1).
	last := right - left

	// Build the heap by sifting down every position that can have a child,
	// from the deepest such position up to the root.
	for node := (last - 1) / 2; node >= 0; node-- {
		heapify(a, node, last, left)
	}

	// Repeatedly move the current maximum (the root) behind the shrinking
	// heap, then restore the heap property on what remains.
	for last > 0 {
		a[left], a[left+last] = a[left+last], a[left]
		last--
		heapify(a, 0, last, left)
	}
}

// heapify sifts the element at heap position b down until neither child is
// larger than it or it has no child within the heap, which ends at heap
// position e (inclusive).
//
// Heap positions are zero-based; heap position p is stored at a[left+p].
func heapify(a []int, b, e, left int) {
	hole := b
	val := a[left+hole]

	for kid := 2*hole + 1; kid <= e; kid = 2*hole + 1 {
		// Choose the larger of the two children when a right child exists.
		if kid < e && a[left+kid] < a[left+kid+1] {
			kid++
		}
		// Stop once the sifted value is at least as large as that child.
		if a[left+kid] <= val {
			break
		}
		// Pull the child up into the hole and descend.
		a[left+hole] = a[left+kid]
		hole = kid
	}

	a[left+hole] = val
}