【算法与数据结构】算法系列:Top-K问题(Leetcode 215)

314 阅读3分钟

一、题目

leetcode.com/problems/kt…

在未排序的数组中找到第 k 个最大的元素。请注意,你需要找的是数组排序后的第 k 个最大的元素,而不是第 k 个不同的元素。

即,找出数组中的第k大元素,重复的元素算多个。

二、解决方案

1. 直接排序

排序后取第k项即可。 时间复杂度O(NlogN)O(N*logN)

    public int findKthLargest(int[] nums, int k) {
        Arrays.sort(nums);
        return nums[nums.length - k];
    }

2. 使用堆

维护一个大小为k的最小堆,在遍历数组的过程中更新堆,最终根节点即是第k大的数。 这里为了方便阅读,堆使用了面向对象的方式,具体实现参考堆排序。 时间复杂度O(NlogK)O(N*logK)

    public int findKthLargest(int[] nums, int k) {
        TopKHeap heap = new TopKHeap(nums, k);
        heap.buildHeap(); // 建堆
        for (int i = k; i < nums.length; i++) {
            heap.update(nums[i]); // 更新
        }
        return heap.toArray()[0];
    }

其中update方法实现如下:

    public void update(int value) {
        // 如果小于堆中的最小值,那么就丢弃掉
        if (value < heap[0]) return;
        // 从顶部下沉
        heap[0] = value;
        int heapSize = heap.length;
        int pos = 0, left, right;
        while (true) {
            left = left(pos);
            right = left + 1;
            // ---------------------------------
            // 如果pos是三者中最小的,那么退出循环
            // 否则,与left、right之间较小的交换下沉
            // ---------------------------------
            if (left >= heapSize) break;
            if (right >= heapSize) { // 只比较left
                if (heap[pos] < heap[left]) break;
                swap(pos, left);
                pos = left;
            } else { // left、right都比较
                int min = heap[left] > heap[right] ? right : left;
                if (heap[pos] < heap[min]) break;
                swap(pos, min);
                pos = min;
            }
        }
    }

3. 使用快排分区

根据快排分区算法partition,对于一个元素,每一次分区之后,其左侧的所有数小于它,其右侧元素大于它。

    private int partition(int[] nums, int p, int q) {
        int i, j;
        for (i = p, j = p + 1; i <= q && j <= q; j++) {
            if (nums[i] < nums[j]) {
                swap(nums, i + 1, j);
                swap(nums, i, i + 1);
                i++;
            }
        }
        return i;
    }

    public static void swap(int[] arr, int p, int q) {
        int t = arr[p];
        arr[p] = arr[q];
        arr[q] = t;
    }

partition结果

这里第k大是自然语言习惯,是从1开始的;为了符合编码习惯,首先使k = k - 1,以符合编程习惯。

当每次partition结束之后,有:

  1. 如果i恰好等于k,那么它就是第k大的元素;
  2. 如果i小于k,那么就在它右侧的数中寻找;
  3. 如果i大于k,那么就在它左侧的数中寻找;

以上应该很好理解。 根据这个思路可以得到代码:

    public int findKthLargest(int[] nums, int k) {
        k = k - 1; // 将第k大转换为序号第k项(1起始 -> 0起始)
        int p = 0, q = nums.length - 1;
        while (true) {
            int i = partition(nums, p, q);
            if (i - p == k) {
                return nums[i];
            } else if (i - p > k) { // find left
                q = i - 1;
            } else { // find right
                k = k + p - i - 1;
                p = i + 1;
            }
        }
    }

实际表现不佳,主要原因是这个算法虽然最优时间复杂度为O(N)O(N),但是最坏时间复杂度为O(N2)O(N^2)

在参考了官方解答之后,发现了通过随机化降低时间复杂度的方式。虽然理论上最坏实际复杂度仍是O(N2)O(N^2),但是可以通过随机的方式平摊风险,达到实际时间复杂度的降低,将O(N2)O(N^2)的可能性降低到理论上存在。

加入了随机化的代码如下:

    private int partition(int[] nums, int p, int q) {
        randomSwap(nums, p, q);
        int i, j;
        // ……
        return i;
    }

    private Random random = new Random();

    private void randomSwap(int[] nums, int p, int q) {
        int offset = random.nextInt(q - p + 1);
        swap(nums, p, p + offset);
    }

其余部分不变。

完整代码如下,含注释:

    /**
     * 使用从大到小的快排分区之后,对于nums[i]来讲,
     * 其左侧所有数大于它,而右侧所有数小于它。
     * 那么,如果i=k,那么它正好是第k大的;
     * 如果i<k,那么在右侧寻找第k-i大的数;
     * 如果i>k,那么在左侧寻找第k大的数。
     */
    public int findKthLargest(int[] nums, int k) {
        k = k - 1; // 将第k大转换为排序
        int p = 0, q = nums.length - 1;
        while (true) {
            int i = partition(nums, p, q);
            if (i - p == k) {
                return nums[i];
            } else if (i - p > k) {
                // find left
                q = i - 1;
            } else {
                // find right
                k = k + p - i - 1;
                p = i + 1;
            }
        }
    }

    private int partition(int[] nums, int p, int q) {
        randomSwap(nums, p, q);
        int i, j;
        for (i = p, j = p + 1; i <= q && j <= q; j++) {
            if (nums[i] < nums[j]) {
                swap(nums, i + 1, j);
                swap(nums, i, i + 1);
                i++;
            }
        }
        return i;
    }

    private Random random = new Random();

    private void randomSwap(int[] nums, int p, int q) {
        int offset = random.nextInt(q - p + 1);
        swap(nums, p, p + offset);
    }

    public static void swap(int[] arr, int p, int q) {
        int t = arr[p];
        arr[p] = arr[q];
        arr[q] = t;
    }