基于三向切分的内省排序(IntroSort)

349 阅读2分钟

在牛客网提交通过,对应题目

// 入口
func Sort(a []int) {
    l := len(a)
    if l < 2 {
        return
    }
    sort(a, 0, l-1, getMaxDepth(l))
}

// 控制递归最大深度
func getMaxDepth(n int) int {
    var depth int
    for i := n; i > 0; i >>= 1 {
        depth++
    }
    return 2*depth
}

// 递归调用
func sort(a []int, low, high, depth int) {
    // high-low+1 > 12
    if high-low > 11 {
        if depth == 0 {
            // 达到最大递归深度改为使用堆排序
            heapSort(a, low, high)
        } else {
            depth--
            pl, pr := partition(a, low, high)
            sort(a, low, pl-1, depth)
            sort(a, pr+1, high, depth)
        }
    } else {
        // 小区间使用插入排序
        insertionSort(a, low, high)
    }
}

func heapSort(a []int, low, high int) {
    // 处理堆时使用的循环变量基于0计算
    // 读写数组a的元素时的索引要基于low计算
    // l := high-low+1
    // end := l-1
    end := high-low
    for i := (end-1)/2; i >= 0; i-- {
        heapify(a, i, end, low)
    }
    for i := end; i > 0; i-- {
        a[low], a[low+i] = a[low+i], a[low] // a[low]就是a[low+0]
        heapify(a, 0, i-1, low)
    }
}

// 向下调整
func heapify(a []int, b, e, low int) {
    // 处理堆时使用的循环变量基于b计算
    // 读写数组a的元素时的索引要基于low计算
    p := b
    pv := a[low+p]
    for {
        child := 2*p + 1
        if child > e {
            break
        }
        if child+1 <= e && a[low+child+1] > a[low+child] {
            child++
        }
        if pv >= a[low+child] {
            break
        }
        a[low+p] = a[low+child]
        p = child
    }
    a[low+p] = pv
}

func insertionSort(a []int, low, high int) {
    for i := low+1; i <= high; i++ {
        n := a[i]
        j := i-1
        for j >= low && a[j] > n {
            a[j+1] = a[j]
            j--
        }
        a[j+1] = n
    }
}

// 区间索引范围 [low, high]
// 三向切分。相关问题是“荷兰国旗问题”
// 返回值:pivot的值在切分完成后的最左索引和最右索引
func partition(a []int, low, high int) (int, int) {
	selectPivot(a, low, high)
	pivot := a[low]
	/*
	   [low, i)    <pivot
	   [i,   j)    ==pivot
	   [j,   k]    未处理
	   [k+1, high] >pivot
	*/
	var (
		i = low
		j = low
		k = high
	)
	for j <= k {
		n := a[j]
		if n < pivot {
			//a[i], a[j] = n, a[i]
			a[j] = a[i]
			a[i] = n
			i++
			j++
		} else if n > pivot {
			//a[j], a[k] = a[k], n
			a[j] = a[k]
			a[k] = n
			k--
		} else {
			j++
		}
	}
	return i, j-1
}

// 将选出的枢轴元素放在low索引上
func selectPivot(a []int, low, high int) {
    l := high-low+1
    mid := low + (high-low)>>1
    // l <= 40,三数取中
    // l > 40,采用九数取中
    if l > 40 {
        seg := l/8
        medianOfThree(a, low, low+seg, low+2*seg)
        medianOfThree(a, mid, mid-seg, mid+seg)
        medianOfThree(a, high, high-seg, high-2*seg)
    }
    medianOfThree(a, low, mid, high)
}

// 将i1,i0,i2三个索引位置上的数的中位数交换到i1索引上
func medianOfThree(a []int, i1, i0, i2 int) {
    if a[i1] > a[i2] {
        a[i1], a[i2] = a[i2], a[i1]
    }
    if a[i0] > a[i2] {
        a[i0], a[i2] = a[i2], a[i0]
    }
    if a[i1] < a[i0] {
        a[i1], a[i0] = a[i0], a[i1]
    }
}