求数组和结果最大的前k个元素(携程)

80 阅读4分钟
给定两个数组,两个数组的元素求和,输出求和结果最大的前k个元素
例如:
int[] a = {1, 2, 3, 4, 5}
int[] b = {2, 3, 4, 5, 6}
k = 3
结果为:[11,10,10]

解法一:小顶堆

第一思路是想到topK问题,那就需要用到优先级队列即堆这种数据结构了,golang实现堆需要借助container/heap这个包,具体可参考该解析

这里借助小顶堆维护TopK元素和

func main() {
    a := []int{1, 2, 3, 4, 5}
    b := []int{2, 3, 4, 5, 6}
    fmt.Println(FindTopKSum(a, b, 3))
}

type MinHeap []int

func (h MinHeap) Len() int {
    return len(h)
}

func (h MinHeap) Less(i, j int) bool { // 小顶堆,每个节点比孩子节点小
    return h[i] < h[j]
}

func (h MinHeap) Swap(i, j int) {
    h[i], h[j] = h[j], h[i]
}

func (h *MinHeap) Push(x interface{}) {
    *h = append(*h, x.(int))
}

func (h *MinHeap) Pop() interface{} {
    x := (*h)[len(*h)-1]
    *h = (*h)[0 : len(*h)-1]
    return x
}

func FindTopKSum(nums1, nums2 []int, k int) []int {
    if len(nums1) == 0 || len(nums2) == 0 || k <= 0 {
       return nil
    }
    // 初始化小顶堆
    h := &MinHeap{}
    heap.Init(h)

    for i := 0; i < len(nums1); i++ {
       for j := 0; j < len(nums2); j++ {
          sum := nums1[i] + nums2[j]
          // 维护大小为k的小顶堆
          if h.Len() < k { // 堆大小不足k个,直接插入
             heap.Push(h, sum)
          } else if sum > (*h)[0] { // 堆满时,大于堆顶元素(TopK中的最小值)的值才插入
             heap.Pop(h)
             heap.Push(h, sum)
          }
       }
    }

    // 注意小顶堆,堆顶是K中最小值,应该放在结果数组的最后(也可以改用头插法)
    res := make([]int, h.Len())
    for i := len(res) - 1; i >= 0; i-- {
       top := heap.Pop(h)
       res[i] = top.(int)
    }
    return res
}
  • 时间复杂度:O(n^2*logk),双层for循环遍历两个数组所有元素的组合,每次还要若判断要入堆,还需要堆调整的成本O(logk)
  • 空间复杂度:O(k),堆中最多存储 k 个元素

解法二:排序 + 大顶堆

  1. 将两个数组按降序排好序,确保最大的元素在前面
  2. 使用大顶堆,当前堆顶记录的就是当前最大和,堆中存储三元组 (元素和, 元素在数组a的索引, 元素在数组b的索引)
  3. 从最大的和开始(即 a[0]+b[0]),每次弹出堆顶元素(即当前最大和)后,加入结果数组res,然后将候选元素加入堆中(即 (i+1, j) 和 (i, j+1)
  4. 使用哈希表避免重复元素组合(通过下标判断)
func main() {
    a := []int{1, 2, 3, 4, 5}
    b := []int{2, 3, 4, 5, 6}
    fmt.Println(FindTopKSum(a, b, 3))
}

type item struct {
    idx1 int
    idx2 int
    sum  int
}
type MinHeap []item // 堆里存放三元组(两个元素在数组1,2中的下标,以及元素和)

func (h MinHeap) Len() int {
    return len(h)
}

func (h MinHeap) Less(i, j int) bool { // 大顶堆,每个节点比孩子节点大
    return h[i].sum > h[j].sum
}

func (h MinHeap) Swap(i, j int) {
    h[i], h[j] = h[j], h[i]
}

func (h *MinHeap) Push(x interface{}) {
    *h = append(*h, x.(item))
}

func (h *MinHeap) Pop() interface{} {
    x := (*h)[len(*h)-1]
    *h = (*h)[0 : len(*h)-1]
    return x
}

func FindTopKSum(nums1, nums2 []int, k int) []int {
    if len(nums1) == 0 || len(nums2) == 0 || k <= 0 {
       return nil
    }
    // 先对两个数组从大到小排序
    sort.Slice(nums1, func(i, j int) bool {
       return nums1[i] > nums1[j]
    })
    sort.Slice(nums2, func(i, j int) bool {
       return nums2[i] > nums2[j]
    })

    // 初始化小顶堆
    h := &MinHeap{}
    heap.Init(h)

    // 排好序后,最大元素和肯定是nums1[0] + nums2[0]
    heap.Push(h, item{
       idx1: 0,
       idx2: 0,
       sum:  nums1[0] + nums2[0],
    })

    // 备忘录判重
    visited := make(map[[2]int]struct{})
    visited[[2]int{0, 0}] = struct{}{} // nums1[0]和nums2[0]这两个元素已经入堆,记录下标

    res := make([]int, 0)
    for len(res) < k && h.Len() > 0 { // 循环直到找够前K个元素和,或者堆已经为空了(元素和不够K个)
       //fmt.Println("cur heap:", h)
       top := (heap.Pop(h)).(item)
       //fmt.Println("heap pop: ", top)
       res = append(res, top.sum)
       i, j := top.idx1, top.idx2
       // 扩展候选 (i+1, j)
       if i+1 < len(nums1) {
          if _, ok := visited[[2]int{i + 1, j}]; !ok { // 避免重复的索引组合多次加入堆
             newItem := item{
                idx1: i + 1,
                idx2: j,
                sum:  nums1[i+1] + nums2[j],
             }
             //fmt.Println("heap push:", newItem)
             heap.Push(h, newItem)
             visited[[2]int{i + 1, j}] = struct{}{}
          } else {
             //fmt.Printf("duplicate, skip {%d, %d}\n", i+1, j)
          }
       }
       // 扩展候选 (i, j+1)
       if j+1 < len(nums2) {
          if _, ok := visited[[2]int{i, j + 1}]; !ok {
             newItem := item{
                idx1: i,
                idx2: j + 1,
                sum:  nums1[i] + nums2[j+1],
             }
             heap.Push(h, newItem)
             //fmt.Println("heap push:", newItem)
             visited[[2]int{i, j + 1}] = struct{}{}
          } else {
             //fmt.Printf("duplicate, skip {%d, %d}\n", i+1, j)
          }
       }
    }
    return res
}
  • 时间复杂度:O(k*logk),每次堆操作(Push/Pop)为 O(logk),最多执行 k 次
  • 空间复杂度:O(k),堆大小