第 K 小的数(快排 + 堆排)

604 阅读2分钟

第 K 小的数(快排 + 堆排)

题目

leetcode-cn.com/problems/kt…

给定整数数组 nums 和整数 k ,请返回数组中第 k 个最大的元素。

请注意,你需要找的是数组排序后的第 k 个最大的元素,而不是第 k 个不同的元素。

快速排序

快速排序是基于分治思想的算法,时间复杂度:O(NlogN) ,空间复杂度:O(1)

快速排序代码

public void q(int[] nums, int l, int r) {
    // 本次排序数组只有一个元素,不需要排序了,直接退出
    if (l == r) return;
    int left = l - 1, right = r + 1;
    // 基准值为数组中间元素
    int p = nums[l + r >> 1];
    while (left < right) {
        // 找到左边第一个大于基准值的位置
        do left++; while (nums[left] < p);
        // 找到右边第一个小于基准值的位置
        do right--; while (nums[right] > p);
        // 交换 left 和 righ 位置元素
        if (left < right) {
            int tmp = nums[left];
            nums[left] = nums[right];
            nums[right] = tmp;
        }
    }
    // 每次 while 循环结束后,right <= left
    q(nums, l, right);     // 右边元素进行排序
    q(nums, right + 1, r); // 左边元素进行排序
}

快排倒序排序

我们可以直接利用快速排序(倒序),直接返回 nums[k - 1] 位置的元素,时间复杂度为 O(NlogN)

class Solution {
    public void q(int[] nums, int l, int r) {
        // 本次排序数组只有一个元素,不需要排序了,直接退出
        if (l == r) return;
        int left = l - 1, right = r + 1;
        // 基准值为数组中间元素
        int p = nums[l + r >> 1];
        while (left < right) {
            // 倒序排序,左边找大于基准的值,右边找小于基准的值
            do left++; while (nums[left] > p);
            do right--; while (nums[right] < p);
            // 交换 left 和 righ 位置元素
            if (left < right) {
                int tmp = nums[left];
                nums[left] = nums[right];
                nums[right] = tmp;
            }
        }
        // 每次 while 循环结束后,right <= left
        q(nums, l, right);     // 右边元素进行排序
        q(nums, right + 1, r); // 左边元素进行排序
    }
    public int findKthLargest(int[] nums, int k) {
        int n = nums.length;
        q(nums, 0, n - 1);
        return nums[k - 1];
    }
}

基于快排优化

每次排序一遍后,第 K 大数总是在 right 左边([l, right])或右边([right + 1, r]),因此我们只需要对左边或右边其中一个序列排序即可,依次递归即可

如何计算第 K 大数在左边区间还是右边区间呢?

每次我们只需要判断 K 是否小于等于 左边区间元素个数(right - l + 1)

  1. 如果 K <= 左边区间元素个数 说明第 K 大数在左边区域,下一次递归左边界为:l右边界为:right第 K 大数为K

  2. 如果 K > 左边区间元素个数 说明第 K 大数在右边区域,下一次递归左边界为:right + 1右边界为:r第 K 大数为K - 左边区间元素个数

注意:第二种情况计算 K 时,不再是原来的 K 了,需要计算出在新区间([right + 1, r])中是第几大的数,进行下一次递归

代码

class Solution {
    public void q(int nums[], int k, int l, int r) {
        if (l == r) return;
        int left = l - 1, right = r + 1, p = nums[l + r >> 1];
        while (left < right) {
            // 采用倒序排序
            do left++; while (nums[left] > p);
            do right--; while (nums[right] < p);
            if (left < right) {
                int tmp = nums[left];
                nums[left] = nums[right];
                nums[right] = tmp;
            }
        }
        int ln = right - l + 1; // 计算出左边区域元素个数
        if (k <= ln) q(nums, k, l, right);  // 说明 K 在左边区间
        else q(nums, k - ln, right + 1, r); // K 在右边区间,计算出新的 K
    }
    public int findKthLargest(int[] nums, int k) {
        q(nums, k, 0, nums.length - 1);
        return nums[k - 1];
    }
}

堆排序

堆排序是利用堆数据结构的算法,可以运用大顶堆进行排序,时间复杂度:O(NlogN) ,空间复杂度:O(NlogN)

堆排序代码

class Solution {
    public void swap(int[] nums, int a, int b) {
        int tmp = nums[a];
        nums[a] = nums[b];
        nums[b] = tmp;
    }

    public void down(int[] nums, int k, int n) {
        int t = k, l = k * 2 + 1, r = l + 1;
        // 从根、左、右三个节点中找到最小的一个节点
        if (l < n && nums[t] < nums[l]) t = l;
        if (r < n && nums[t] < nums[r]) t = r;
        if (t != k) { // 说明找到了最小的节点
            swap(nums, t, k);
            // 把最小的节点递归往下放
            down(nums, t, n);
        }
    }

    public int[] sortArray(int[] nums) {
        int n = nums.length;
        // 堆化,只需要从倒数第二层最后一个根节点开始
        for (int i = n / 2 - 1; i >= 0; i--) down(nums, i, n);
        for (int i = n - 1; i >= 0; i--) {
            // 把最大值交换数组末尾
            swap(nums, 0, i);
            // 此处 n 为 i ,相当于把堆中最大元素剔除
            down(nums, 0, i); 
        }

        return nums;
    }
}

基于堆排序

我们只需要使用堆排序,把 K 个数挪动到数组尾部即可,返回 nums.length - K 返回第 K 大的数

class Solution {
    public void swap(int[] nums, int a, int b) {
        int tmp = nums[a];
        nums[a] = nums[b];
        nums[b] = tmp;
    }
    public void down(int[] nums, int k, int n) {
        int t = k, l = k * 2 + 1, r = l + 1;
        if (l < n && nums[t] < nums[l]) t = l;
        if (r < n && nums[t] < nums[r]) t = r;
        if (t != k) {
            swap(nums, t, k);
            down(nums, t, n);
        }
    }
    public int findKthLargest(int[] nums, int k) {
        int n = nums.length;
        for (int i = n / 2 - 1; i >= 0; i--) {
            down(nums, i, n);
        }
        // 将堆中前 K 大的数挪动到队尾
        for (int i = 0; i < k; i++) {
            swap(nums, 0, --n);
            down(nums, 0, n);
        }
        return nums[nums.length - k];
    }
}