BFPRT算法(Median of medians)

885 阅读1分钟

在一大堆数中求其前k大或前k小的问题,简称TOP-K问题。目前解决TOP-K问题最有效的算法是BFPRT算法,又称为中位数的中位数算法,该算法由Blum、Floyd、Pratt、Rivest、Tarjan提出,最坏时间复杂度为O(N)。

// 找到第k小的数
// 参数:
//     k >= 1
// 返回值:
//     第k小的数
// panic:
//     k无效时,panic
func SearchKthSmallNum(a []int, k int) int {
    if k < 1 {
        panic("k < 1")
    }
    l := len(a)
    if l < k {
        panic("len(a) < k")
    }
    i := searchKth(a, 0, l-1, k)
    return a[i]
}

// 参数:
//    k >= 1
//    数组索引区间[low,high]
// 返回值:
//     pivotIndex:第k小元素的索引
func searchKth(a []int, low, high, k int) int {
    for low < high {
        pivotIndex := selectPivot(a, low, high)
        pivotIndex = partition(a, low, high, pivotIndex, k)
        if k-1 < pivotIndex {
            high = pivotIndex - 1
        } else if k-1 == pivotIndex {
            return pivotIndex
        } else {
            low = pivotIndex + 1
        }
    }
    return low
}

// 选取枢轴。median of medians
// 参数:
//     数组索引区间[low,high]
// 返回值:
//     枢轴所在索引
func selectPivot(a []int, low, high int) int {
    // 小于或等于5的区间简化处理
    // high-low+1 <= 5
    // high-low <= 4
    if high-low < 5 {
        return partition5(a, low, high)
    }
    t := low
    for i := low; i <= high; i += 5 {
        j := i + 4
        if j > high {
            j = high
        }
        m := partition5(a, i, j)
        t = low + (i-low)/5
        a[m], a[t] = a[t], a[m]
    }
    mid := low + (t-low)/2
    if (t-low+1)&1 == 0 { // 上面得到的一组中位数所在的区间长度是偶数
        mid++ // 目的是取下中位数
    }
    return searchKth(a, low, t, mid)
}

// 返回索引区间[low,high]的上中位数的索引
func partition5(a []int, low, high int) int {
    // 基于InsertionSort
    for i := low + 1; i <= high; i++ {
        n := a[i]
        j := i - 1
        for j >= low && a[j] > n {
            a[j+1] = a[j]
            j--
        }
        a[j+1] = n
    }
    return low + (high-low)>>1
}

// 三向切分。a three-way partition
// 根据pivotIndex索引上是数来切分索引区间[low, high]。
// 根据kth返回不同的值,供调用方判断在哪个区间继续查找。
func partition(a []int, low, high, pivotIndex, kth int) int {
    var (
        i     = low
        j     = low
        k     = high
        pivot = a[pivotIndex]
    )
    // 数组a有3个索引i,j,k, 且i<=j<=k,
    // [low, i)区间的数组元素都小于pivotVal
    // [i, j)区间的数组元素都等于pivotVal
    // [j, k]区间的数组元素是未处理的
    // [k+1, high]区间的元素都大于pivotVal
    for j <= k {
        n := a[j]
        if n < pivot {
            //a[i], a[j] = a[j], a[i]
            a[j] = a[i]
            a[i] = n
            i++
            j++
        } else if n > pivot {
            //a[k], a[j] = a[j], a[k]
            a[j] = a[k]
            a[k] = n
            k--
        } else {
            j++
        }
    }
    if kth-1 < i {
        return i
    } else if kth-1 < j {
        return kth - 1
    } else {
        return k
    }
}