优化临时内存分配
- 初始化
- 扩容
优化相邻区间合并
- 先尝试缩小要合并的区间
- 只复制较小的区间,减小临时内存
下面是使用方法:
arr := []int{5,4,3,2,1}
o := NewMergeSort(arr) // 先创建对象,传入待排序的数组
o.Sort() // 排序
下面是算法实现:
type MergeSort struct {
arr []int
tmp []int
}
const (
tinySize = 17
initTmpSize = 256
)
// 创建对象
func NewMergeSort(arr []int) *MergeSort {
var (
l = len(arr)
size = initTmpSize
)
if l < (initTmpSize << 1) {
size = l >> 1
}
return &MergeSort{
arr: arr,
tmp: make([]int, size),
}
}
/**
* 保证临时数组的大小能够容纳所有的临时元素,在需要的时候要扩展临时数组的大小。
* 数组的大小程指数增长,来保证线性的复杂度。
*
* 一次申请步长太小,申请的次数必然会增多,浪费时间;一次申请的空间足够大,必然会
* 浪费空间。正常情况下,归并排序的临时空间每次大的合并都会 * 2,
* 最大长度不会超过数组长度的1/2。 这个长度与2有着紧密的联系。
*
* @param minCapacity 临时数组需要的最小空间
*/
func (o *MergeSort) ensureCapacity(minCapacity int) {
// 如果临时数组长度不够,那需要重新计算临时数组长度
if len(o.tmp) < minCapacity {
// 这里是计算最小的大于minCapacity的2的幂。方法不常见,这里分析一下。
//
// 假设有无符号整型 k,它的字节码如下:
// 00000000 10000000 00000000 00000000 k
// 00000000 11000000 00000000 00000000 k |= k >> 1;
// 00000000 11110000 00000000 00000000 k |= k >> 2;
// 00000000 11111111 00000000 00000000 k |= k >> 4;
// 00000000 11111111 11111111 00000000 k |= k >> 8;
// 00000000 11111111 11111111 11111111 k |= k >> 16
// 上面的移位事实上只跟最高位有关系,移位的结果是最高位往后的bit全部变成了1
// 最后 k++ 的结果 就是刚好是比 minCapacity 大的2的幂
newSize := minCapacity
newSize |= newSize >> 1
newSize |= newSize >> 2
newSize |= newSize >> 4
newSize |= newSize >> 8
newSize |= newSize >> 16
newSize++
if newSize >= 0 {
newSize = min(newSize, len(o.arr)>>1)
} else { // Not bloody likely!
newSize = minCapacity
}
o.tmp = make([]int, newSize)
}
}
// 排序功能入口
func (o *MergeSort) Sort() {
l := len(o.arr)
if l < 2 {
return
}
for seg := 1; seg < l; seg <<= 1 {
o.mergePass(seg)
}
}
func (o *MergeSort) mergePass(seg int) {
var (
l = len(o.arr)
c = (l / seg) >> 1 // 计算有多少对相邻segment需要合并
b int
)
for i := 0; i < c; i++ {
be := b + seg - 1 // 左区间的结束
e := be + seg
o.mergePre(b, be, e)
b = e + 1
}
rest := l - (c<<1)*seg
if rest > seg {
o.mergePre(b, b+seg-1, l-1)
}
}
// 优化一:只复制其中较小的区间,减少临时内存
// b 左区间的开始
// be 左区间的结束
// e 相连的右区间的结束
func (o *MergeSort) merge(b, be, e int) {
if !(b <= be && be < e) {
return
}
var (
j = be + 1 // 右区间的开始
ll = j - b // 左区间长度
rl = e - be // 右区间长度
)
if ll <= rl { // 复制左区间到临时内存,并从小到大合并
o.ensureCapacity(ll)
copy(o.tmp, o.arr[b:j])
var i int // o.tmp的索引
k := b // o.arr的索引
for i < ll && j <= e {
if o.tmp[i] <= o.arr[j] {
o.arr[k] = o.tmp[i]
i++
} else {
o.arr[k] = o.arr[j]
j++
}
k++
}
for i < ll {
o.arr[k] = o.tmp[i]
i++
k++
}
for j <= e {
o.arr[k] = o.arr[j]
j++
k++
}
} else { // 复制右区间到临时内存,并从大到小合并
o.ensureCapacity(rl)
copy(o.tmp, o.arr[j:e+1])
j = rl - 1 // o.tmp的索引
i := be // o.arr左区间的索引
k := e // o.arr的索引
for i >= b && j >= 0 {
if o.tmp[j] >= o.arr[i] {
o.arr[k] = o.tmp[j]
j--
} else {
o.arr[k] = o.arr[i]
i--
}
k--
}
for i >= b {
o.arr[k] = o.arr[i]
i--
k--
}
for j >= 0 {
o.arr[k] = o.tmp[j]
j--
k--
}
}
}
// 优化二:在优化一的基础上,缩小要合并的区间
// b 左区间的开始
// be 左区间的结束
// e 相连的右区间的结束
func (o *MergeSort) mergePre(b, be, e int) {
var (
j = be + 1 // 右区间的开始
ll = j - b // 左区间长度
rl = e - be // 右区间长度
)
// 找到右区间的最小值在左区间按顺序插入的位置
rightMin := o.arr[j]
left := b
if ll < tinySize {
for left < ll {
if o.arr[left] > rightMin {
break
}
left++
}
} else {
left = o.binarySearch(b, be, rightMin, true)
}
if left == j { // 左区间元素值全部小于等于右区间元素值
return
}
// 找到左区间的最大值在右区间按顺序插入的位置
leftMax := o.arr[be]
right := e
if rl < tinySize {
for right >= j {
if o.arr[right] < leftMax {
break
}
right--
}
} else {
right = o.binarySearch(j, e, leftMax, false)
}
if right > e { // 处理越界索引
right = e
}
o.merge(left, be, right)
}
// 在arr的区间[low,high]找n按顺序插入的位置
// rightMost: true表示要尽量向右找
func (o *MergeSort) binarySearch(low, high, n int, rightMost bool) int {
var (
i = low
j = high
)
for i <= j {
m := i + (j-i)>>1
v := o.arr[m]
if rightMost {
if v <= n {
i = m + 1
} else {
j = m - 1
}
} else { // 尽量向左找
if v >= n {
j = m - 1
} else {
i = m + 1
}
}
}
return i
}
func min(a, b int) int {
if a <= b {
return a
}
return b
}