归并排序优化版

176 阅读3分钟

优化临时内存分配

  • 初始化
  • 扩容

优化相邻区间合并

  • 先尝试缩小要合并的区间
  • 只复制较小的区间,减小临时内存

下面是使用方法:

arr := []int{5,4,3,2,1}
o := NewMergeSort(arr) // 先创建对象,传入待排序的数组
o.Sort() // 排序

下面是算法实现:

type MergeSort struct {
    arr []int
    tmp []int
}

const (
    tinySize = 17
    initTmpSize = 256
)

// 创建对象
func NewMergeSort(arr []int) *MergeSort {
    var (
        l    = len(arr)
        size = initTmpSize
    )
    if l < (initTmpSize << 1) {
        size = l >> 1
    }
    return &MergeSort{
        arr: arr,
        tmp: make([]int, size),
    }
}

/**
 * 保证临时数组的大小能够容纳所有的临时元素,在需要的时候要扩展临时数组的大小。
 * 数组的大小程指数增长,来保证线性的复杂度。
 *
 * 一次申请步长太小,申请的次数必然会增多,浪费时间;一次申请的空间足够大,必然会
 * 浪费空间。正常情况下,归并排序的临时空间每次大的合并都会 * 2,
 * 最大长度不会超过数组长度的1/2。 这个长度与2有着紧密的联系。
 *
 * @param minCapacity 临时数组需要的最小空间
 */
func (o *MergeSort) ensureCapacity(minCapacity int) {
    // 如果临时数组长度不够,那需要重新计算临时数组长度
    if len(o.tmp) < minCapacity {
        // 这里是计算最小的大于minCapacity的2的幂。方法不常见,这里分析一下。
        //
        // 假设有无符号整型 k,它的字节码如下:
        // 00000000 10000000 00000000 00000000  k
        // 00000000 11000000 00000000 00000000  k |= k >> 1;
        // 00000000 11110000 00000000 00000000  k |= k >> 2;
        // 00000000 11111111 00000000 00000000  k |= k >> 4;
        // 00000000 11111111 11111111 00000000  k |= k >> 8;
        // 00000000 11111111 11111111 11111111  k |= k >> 16
        // 上面的移位事实上只跟最高位有关系,移位的结果是最高位往后的bit全部变成了1
        // 最后 k++ 的结果 就是刚好是比 minCapacity 大的2的幂
        newSize := minCapacity
        newSize |= newSize >> 1
        newSize |= newSize >> 2
        newSize |= newSize >> 4
        newSize |= newSize >> 8
        newSize |= newSize >> 16
        newSize++
        if newSize >= 0 {
            newSize = min(newSize, len(o.arr)>>1)
        } else { // Not bloody likely!
            newSize = minCapacity
        }
        o.tmp = make([]int, newSize)
    }
}

// 排序功能入口
func (o *MergeSort) Sort() {
    l := len(o.arr)
    if l < 2 {
        return
    }
    for seg := 1; seg < l; seg <<= 1 {
        o.mergePass(seg)
    }
}

func (o *MergeSort) mergePass(seg int) {
    var (
        l = len(o.arr)
        c = (l / seg) >> 1 // 计算有多少对相邻segment需要合并
        b int
    )
    for i := 0; i < c; i++ {
        be := b + seg - 1 // 左区间的结束
        e := be + seg
        o.mergePre(b, be, e)
        b = e + 1
    }
    rest := l - (c<<1)*seg
    if rest > seg {
        o.mergePre(b, b+seg-1, l-1)
    }
}

// 优化一:只复制其中较小的区间,减少临时内存
// b 左区间的开始
// be 左区间的结束
// e 相连的右区间的结束
func (o *MergeSort) merge(b, be, e int) {
    if !(b <= be && be < e) {
        return
    }

    var (
        j  = be + 1 // 右区间的开始
        ll = j - b  // 左区间长度
        rl = e - be // 右区间长度
    )
    if ll <= rl { // 复制左区间到临时内存,并从小到大合并
        o.ensureCapacity(ll)
        copy(o.tmp, o.arr[b:j])
        var i int // o.tmp的索引
        k := b    // o.arr的索引
        for i < ll && j <= e {
            if o.tmp[i] <= o.arr[j] {
                o.arr[k] = o.tmp[i]
                i++
            } else {
                o.arr[k] = o.arr[j]
                j++
            }
            k++
        }
        for i < ll {
            o.arr[k] = o.tmp[i]
            i++
            k++
        }
        for j <= e {
            o.arr[k] = o.arr[j]
            j++
            k++
        }
    } else { // 复制右区间到临时内存,并从大到小合并
        o.ensureCapacity(rl)
        copy(o.tmp, o.arr[j:e+1])
        j = rl - 1 // o.tmp的索引
        i := be    // o.arr左区间的索引
        k := e     // o.arr的索引
        for i >= b && j >= 0 {
            if o.tmp[j] >= o.arr[i] {
                o.arr[k] = o.tmp[j]
                j--
            } else {
                o.arr[k] = o.arr[i]
                i--
            }
            k--
        }
        for i >= b {
            o.arr[k] = o.arr[i]
            i--
            k--
        }
        for j >= 0 {
            o.arr[k] = o.tmp[j]
            j--
            k--
        }
    }
}

// 优化二:在优化一的基础上,缩小要合并的区间
// b 左区间的开始
// be 左区间的结束
// e 相连的右区间的结束
func (o *MergeSort) mergePre(b, be, e int) {
    var (
        j  = be + 1 // 右区间的开始
        ll = j - b  // 左区间长度
        rl = e - be // 右区间长度
    )
    // 找到右区间的最小值在左区间按顺序插入的位置
    rightMin := o.arr[j]
    left := b
    if ll < tinySize {
        for left < ll {
            if o.arr[left] > rightMin {
                break
            }
            left++
        }
    } else {
        left = o.binarySearch(b, be, rightMin, true)
    }

    if left == j { // 左区间元素值全部小于等于右区间元素值
        return
    }

    // 找到左区间的最大值在右区间按顺序插入的位置
    leftMax := o.arr[be]
    right := e
    if rl < tinySize {
        for right >= j {
            if o.arr[right] < leftMax {
                break
            }
            right--
        }
    } else {
        right = o.binarySearch(j, e, leftMax, false)
    }

    if right > e { // 处理越界索引
        right = e
    }

    o.merge(left, be, right)
}

// 在arr的区间[low,high]找n按顺序插入的位置
// rightMost: true表示要尽量向右找
func (o *MergeSort) binarySearch(low, high, n int, rightMost bool) int {
    var (
        i = low
        j = high
    )
    for i <= j {
        m := i + (j-i)>>1
        v := o.arr[m]
        if rightMost {
            if v <= n {
                i = m + 1
            } else {
                j = m - 1
            }
        } else { // 尽量向左找
            if v >= n {
                j = m - 1
            } else {
                i = m + 1
            }
        }
    }
    return i
}

func min(a, b int) int {
    if a <= b {
        return a
    }
    return b
}