347. 前 K 个高频元素

440 阅读2分钟

思路

维护一个大小为K的小顶堆,堆内的排序原则是元素在原数组中出现的频率。

  • 堆内放的是“元素”,不重复。
  • 先扫一遍keySet,把前K个元素放进去。
  • 之后从第K+1个元素继续扫,遇到比堆顶大的就置换,没有遇到就继续扫。
  • 最后堆中的元素即为前K个高频元素

升华

  • topk (前k大)用小根堆,维护堆大小不超过 k 即可。每次压入堆前和堆顶元素比较,如果比堆顶元素还小,直接扔掉,否则压入堆。检查堆大小是否超过 k,如果超过,弹出堆顶。复杂度是 nlogk。避免使用大根堆,因为你得把所有元素压入堆,复杂度是 nlogn,而且还浪费内存。如果是海量元素,那就挂了。

  • 求前 k 大,用小根堆,求前 k 小,用大根堆。

0. 快速选择, 平均时间复杂度O(N) 最差O(N2)

class Solution {
    Map<Integer, Integer> map;
    public int[] topKFrequent(int[] nums, int k) {
        // 统计频率
        map = new HashMap<>();
        for (int x : nums) {
            map.put(x, map.getOrDefault(x, 0) + 1);
        }
        // 创建无重复元素的数组
        int[] newNums = new int[map.size()];
        int i = 0;
        for (int x : map.keySet()) {
            newNums[i] = x;
            i++;
        }
        // 快速选择,topKFNum是第k大频率的数字
        quickSelect(newNums, 0, newNums.length - 1, k);// 传入新的数组
        int topKFNum = newNums[newNums.length - k];
        // 填充结果
        int[] res = new int[k];
        i = 0;
        for (int x : newNums) {
            if (map.get(x) >= map.get(topKFNum)) {
                res[i] = x;
                i++;
            }
        }
        return res;
    }

    //快速选择
    public void quickSelect(int[] nums, int start, int end, int k) {
        if (start > end) {
            return;
        }
        int i = start, j = end, pivot = start;
        while (i != j) {
            while (map.get(nums[j]) > map.get(nums[pivot])) {// 比较的是频率
                j--;
            }
            while (i < j && map.get(nums[i]) <= map.get(nums[pivot])) {
                i++;
            }
            swap(nums, i, j);
        }
        swap(nums, i, pivot);
        // 按照词频规则排序后,这里判断的是index,而不是频率
        if (i == nums.length - k) {
            return;
        } else if (i < nums.length - k) {
            quickSelect(nums, i + 1, end, k);
        } else {
            quickSelect(nums, start, i - 1, k);
        }
    }

    public void swap(int[] nums, int m, int n) {
        int tmp = nums[m];
        nums[m] = nums[n];
        nums[n] = tmp;
    }
}

1. 使用 Java PriorityQueue 作为小顶堆 O(NlogK)

class Solution {
    public int[] topKFrequent(int[] nums, int k) {
        Map<Integer, Integer> map = new HashMap<>();
        //遍历数组,记录频率
        for (int x : nums) {
            map.put(x, map.getOrDefault(x, 0) + 1);
        }
        //(a, b) -> map.get(a) - map.get(b)表示:堆内元素排序原则按照原数组中他的出现频率
        Queue<Integer> pq = new PriorityQueue<>((a, b) -> map.get(a) - map.get(b));
        for (int x : map.keySet()) {// 注意不能遍历nums,会有重复
            if (pq.size() < k) {// 前k个
                pq.add(x);
            } else {//继续从K+1开始扫
                if (map.get(x) > map.get(pq.peek())) {
                    pq.poll();
                    pq.add(x);
                }
            }
        }
        //堆中元素即为前K个高频元素
        int[] res = new int[k];
        for (int i = 0; i < k; i++) {
            res[i] = pq.poll();
        }
        return res;
    }
}

2. 手写小顶堆的build

class Solution {
    Map<Integer, Integer> map = new HashMap<>();
    public int[] topKFrequent(int[] nums, int k) {
        //统计频率,建HashMap
        for (int x : nums) {
            map.put(x , map.getOrDefault(x, 0) + 1);
        }
        //初始化一个数组,大小为元素个数(不重复的),因为我们要按照频率来排序
        int[] arr = new int[map.keySet().size()];//buildHeap应该build去充值后的
        int j = 0;
        //把不重复的“元素”放到arr中
        for (int x : map.keySet()) {
            arr[j] = x;
            j++;
        }
        //先把前K个元素给build
        buildMinHeap(arr, k);
        //从第k+1个元素开始扫,遇到频率更大的元素就置换堆顶并重新heapify,没有就继续
        for (int i = k; i < arr.length; i++) {
            if (map.get(arr[i]) > map.get(arr[0])) {
                swap(arr, 0, i);
                minHeapify(arr, 0, k);
            } 
        }
        //堆中元素即为前K个高频元素
        int[] res = new int[k];
        for (int i = 0; i < k; i++) {
            res[i] = arr[i];
        }
        return res;
    }
    
    public void buildMinHeap(int[] nums, int size) {
        int lastNode = size - 1;
        int lastParent = (lastNode - 1) / 2;
        for (int i = lastParent; i >= 0; i--) {
            minHeapify(nums, i, size);
        }        
    }
    
    public void minHeapify(int[] nums, int i, int n) {
        int c1 = 2 * i + 1;
        int c2 = 2 * i + 2;
        int min = i;
        //这里排序标准是频率
        if (c1 < n && map.get(nums[c1]) < map.get(nums[min])) {
            min = c1;
        }
        //这里排序标准是频率
        if (c2 < n && map.get(nums[c2]) < map.get(nums[min])) {
            min = c2;
        }
        if (min != i) {
            swap(nums, min, i);
            minHeapify(nums, min, n);
        }
    }
    
    public void swap(int[] nums, int i, int j) {
        int tmp = nums[i];
        nums[i] = nums[j];
        nums[j] = tmp;
    }
}