【LeetCode】寻找两个正序数组的中位数

87 阅读6分钟

开启掘金成长之旅!这是我参与「掘金日新计划 · 2 月更文挑战」的第20天,点击查看活动详情

寻找两个正序数组的中位数

原题:leetcode-cn.com/problems/me…

描述

给定两个大小为 m 和 n 的正序(从小到大)数组 nums1 和 nums2。

请你找出这两个正序数组的中位数,并且要求算法的时间复杂度为 O(log(m + n))。

你可以假设 nums1 和 nums2 不会同时为空。

难度

困难

示例

nums1 = [1, 3]
nums2 = [2]
则中位数是 2.0
nums1 = [1, 2]
nums2 = [3, 4]
则中位数是 (2 + 3)/2 = 2.5

思路

使用归并的方式,合并两个有序数组,得到一个大的有序数组。大的有序数组长度只需要为 (m+n)/2+1即可,大的有序数组的中间位置的元素,即为中位数。但是这种方法的时间复杂度为 O(m+n),不符合题目要求。

要满足 O(log(m+n)) 的时间复杂度就要使用二分查找。官方给出的二分查找思路如下:

根据中位数的定义,当 m+n 是奇数时,中位数是两个有序数组中的第 (m+n)/2 个元素,当 m+n 是偶数时,中位数是两个有序数组中的第 (m+n)/2 个元素和第 (m+n)/2+1 个元素的平均值。这道题可以看出成寻找两个有序数组中的第 k 小的数,其中 k 为 (m+n)/2 或 (m+n)/2+1。

假设有两个有序数组分别为 nums1 和 nums2。要找到第 k 个元素,可以比较 nums1[k/2-1] 和 nums2[k/2-1]。分三种情况:

  • 如果 nums1[k/2-1] < nums2[k/2-1],那么比 nums1[k/2-1] 小的数最多有 nums1 的前 k/2-1 个和 nums2 的前 k/2-1 个,即比 nums1[k/2-1] 小的数有 k - 2 个,所以 nums1[k/2-1] 不可能是第 k 小的数,nums1[0] 到 nums1[k/2-1] 也不可能是第 k 个数,进行排除。
  • 如果 nums1[k/2-1] > nums2[k/2-1],排除 nums2[0] 到 nums2[k/2-1]。
  • 如果 nums1[k/2-1] = nums2[k/2-1],和第一种情况相同。

每次排除之后,需要减少 k 的值,根据减少排除元素的个数减少 k。有以下特殊情况需要处理:

  • 如果 nums[k/2-1] 或 nums2[k/2-1] 越界,就选取对应数组的最后一个元素。
  • 如果 一个数组为空,说明该数组中的所以元素被排除,直接返回另一个数组中的第 k 个元素。
  • 如果 k = 1,返回两个数组中首个最小的元素即可。

代码

Rust

pub struct Solution {}
​
impl Solution {
    pub fn find_median_sorted_arrays(nums1: Vec<i32>, nums2: Vec<i32>) -> f64 {
        // 归并查找,时间复杂度 O(m+n)
        // Solution::merge_search(nums1, nums2)
​
        // 二分查找,时间复杂度 O(log(m+n))
        Solution::binary_search(nums1, nums2)
    }
​
    fn binary_search(nums1: Vec<i32>, nums2: Vec<i32>) -> f64 {
        let (len1, len2) = (nums1.len(), nums2.len());
        let total_len = len1 + len2;
        if total_len % 2 != 0 {
            // 奇数
            Solution::get_kth_element(nums1, nums2, total_len / 2 + 1) as f64
        } else {
            // 偶数
            let median1 =
                Solution::get_kth_element(nums1.clone(), nums2.clone(), total_len / 2);
            let median2 =
                Solution::get_kth_element(nums1.clone(), nums2.clone(), total_len / 2 + 1);
            (median1 + median2) as f64 / 2.0
        }
    }
​
    /// 获取两个数组中第 k 个元素
    fn get_kth_element(nums1: Vec<i32>, nums2: Vec<i32>, k: usize) -> i32 {
        let (len1, len2) = (nums1.len(), nums2.len());
        let (mut index1, mut index2) = (0, 0);
        let mut mut_k = k;
        loop {
            // 数组中的所有元素都被排除,返回另一个数组中第 k 小的元素
            if index1 == len1 {
                return nums2[index2 + mut_k - 1];
            }
            if index2 == len2 {
                return nums1[index1 + mut_k - 1];
            }
            // 如果 k=1,返回两个数组首个元素的最小值
            if mut_k == 1 {
                return min(nums1[index1], nums2[index2]);
            }
​
            let half = mut_k / 2;
            let new_index1 = min(index1 + half, len1) - 1;
            let new_index2 = min(index2 + half, len2) - 1;
            if nums1[new_index1] <= nums2[new_index2] {
                // 排除了元素后,k 减少排除元素的个数
                mut_k -= new_index1 - index1 + 1;
                // 记录排除元素后数组的起始下标
                index1 = new_index1 + 1;
            } else {
                mut_k -= new_index2 - index2 + 1;
                index2 = new_index2 + 1;
            }
        }
    }
​
    #[allow(dead_code)]
    fn merge_search(nums1: Vec<i32>, nums2: Vec<i32>) -> f64 {
        let (mut index1, mut index2, mut merge_num_index) = (0, 0, 0);
        let (len1, len2) = (nums1.len(), nums2.len());
​
        // 创建一个向量用来保存合并的元素,只需要保存到中间的元素即可
        let mut merge_nums = vec![0; (len1 + len2) / 2 + 1];
        while merge_num_index < merge_nums.len() {
            let mut value1 = 1 << 31 - 1;
            let mut value2 = 1 << 31 - 1;
            if index1 < len1 {
                value1 = nums1[index1];
            }
            if index2 < len2 {
                value2 = nums2[index2];
            }
            if value1 <= value2 {
                merge_nums[merge_num_index] = value1;
                index1 += 1;
            } else {
                merge_nums[merge_num_index] = value2;
                index2 += 2;
            }
            merge_num_index += 1;
        }
​
        if (len1 + len2) & 2 != 0 {
            // 奇数取最后一个元素
            merge_nums[merge_nums.len() - 1] as f64
        } else {
            // 偶数取最后两个元素的平均值
            (merge_nums[merge_nums.len() - 2] + merge_nums[merge_nums.len() - 1]) as f64 / 2.0
        }
    }
}
#[test]
fn test_find_median_sorted_arrays() {
    let nums1: Vec<i32> = vec![1, 3];
    let nums2: Vec<i32> = vec![2];
    let median = Solution::find_median_sorted_arrays(nums1.clone(), nums2.clone());
​
    println!("{:?}", nums1);
    println!("{:?}", nums2);
    println!("中位数是 {}", median);
​
    let nums1: Vec<i32> = vec![1, 2];
    let nums2: Vec<i32> = vec![3, 4];
    let median = Solution::find_median_sorted_arrays(nums1.clone(), nums2.clone());
​
    println!("{:?}", nums1);
    println!("{:?}", nums2);
    println!("中位数是 {}", median);
}

运行结果:

[1, 3]
[2]
中位数是 2
[1, 2]
[3, 4]
中位数是 2.5

Go

func findMedianSortedArrays(nums1 []int, nums2 []int) float64 {
    // 归并查找,时间复杂度 O(m+n)
    // return mergeSearch(nums1, nums2)
​
    // 二分查找,时间复杂度 O(log(m+n))
    return binarySearch(nums1, nums2)
}
​
func binarySearch(nums1 []int, nums2 []int) float64 {
    len1, len2 := len(nums1), len(nums2)
    totalLen := len1 + len2
    if totalLen % 2 != 0 { // 奇数
        median := getKthElement(nums1, nums2, totalLen / 2 + 1)
        return float64(median)
    } else { // 偶数
        median1 := getKthElement(nums1, nums2, totalLen / 2)
        median2 := getKthElement(nums1, nums2, totalLen / 2 + 1)
        return float64(median1 + median2) / 2.0
    }
}
// 获取两个数组中第 k 个元素
func getKthElement(nums1 []int, nums2 []int, k int) int {
    len1, len2 := len(nums1), len(nums2)
    index1, index2 := 0, 0
    for {
​
        // 数组中的所有元素都被排除,返回另一个数组中第 k 小的元素
        if index1 == len1 {
            return nums2[index2 + k - 1]
        }
        if index2 == len2 {
            return nums1[index1 + k - 1]
        }
​
        // 如果 k=1,返回两个数组首个元素的最小值
        if k == 1 {
            return int(math.Min(float64(nums1[index1]), float64(nums2[index2])))
        }
​
        half := k / 2
        newIndex1 := int(math.Min(float64(index1 + half), float64(len1))) - 1
        newIndex2 := int(math.Min(float64(index2 + half), float64(len2))) - 1
        if nums1[newIndex1] <= nums2[newIndex2] {
            // 排除了元素后,k 减少排除元素的个数
            k -= newIndex1 - index1 + 1
            // 记录排除元素后数组的起始下标
            index1 = newIndex1 + 1
        } else {
            k -= newIndex2 - index2 + 1
            index2 = newIndex2 + 1
        }
    }
}
​
func mergeSearch(nums1 []int, nums2 []int) float64 {
    index1, index2, mergeNumIndex := 0, 0, 0
    len1, len2 := len(nums1), len(nums2)
    value1, value2 := 0, 0
    // 创建一个数组用来保存合并的元素,只需要保存到中间的元素即可
    mergeNums := make([]int, (len1 + len2) / 2 + 1)
    for mergeNumIndex < len(mergeNums) {
        value1 = 1 << 31 - 1
        value2 = 1 << 31 - 1
        if index1 < len1 {
            value1 = nums1[index1]
        }
        if index2 < len2 {
            value2 = nums2[index2]
        }
        if value1 <= value2 {
            mergeNums[mergeNumIndex] = nums1[index1]
            index1++
        } else {
            mergeNums[mergeNumIndex] = nums2[index2]
            index2++
        }
        mergeNumIndex++
    }
​
    // 奇数取最后一个元素
    if (len1 + len2) % 2 != 0 {
        return float64(mergeNums[len(mergeNums) - 1])
    }
    // 偶数取最后两个元素的平均值
    return (float64(mergeNums[len(mergeNums) - 2]) + float64(mergeNums[len(mergeNums) - 1])) / 2
}
func TestFindMedianSortedArrays(t *testing.T) {
    nums1 := []int{1, 3}
    nums2 := []int{2}
    median := findMedianSortedArrays(nums1, nums2)
​
    t.Log(nums1)
    t.Log(nums2)
    t.Logf("中位数是 %f\n", median)
​
    nums1 = []int{1, 2}
    nums2 = []int{3, 4}
    median = findMedianSortedArrays(nums1, nums2)
​
    t.Log(nums1)
    t.Log(nums2)
    t.Logf("中位数是 %f\n", median)
}

运行结果:

[1 3]
[2]
中位数是 2.000000
[1 2]
[3 4]
中位数是 2.500000

Java

public class Main {
​
    public static double findMedianSortedArrays(int[] nums1, int[] nums2) {
        // 归并查找,时间复杂度 O(m+n)
        // return mergeSearch(nums1, nums2);
​
        // 二分查找,时间复杂度 O(log(m+n))
        return binarySearch(nums1, nums2);
    }
​
    public static double binarySearch(int[] nums1, int[] nums2) {
        int len1 = nums1.length, len2 = nums2.length;
        int totalLen = len1 + len2;
        if (totalLen % 2 != 0) { // 奇数
            return getKthElement(nums1, nums2, totalLen / 2 + 1);
        } else { // 偶数
            double median1 = getKthElement(nums1, nums2, totalLen / 2);
            double median2 = getKthElement(nums1, nums2, totalLen / 2 + 1);
            return (median1 + median2) / 2.0;
        }
    }
​
    /**
     * 获取两个数组中第 k 个元素
     */
    private static int getKthElement(int[] nums1, int[] nums2, int k) {
        int len1 = nums1.length, len2 = nums1.length;
        int index1 = 0, index2 = 0;
        while (true) {
​
            // 数组中的所有元素都被排除,返回另一个数组中第 k 小的元素
            if (index1 == len1) {
                return nums2[index2 + k - 1];
            }
            if (index2 == len2) {
                return nums1[index1 + k - 1];
            }
​
            // 如果 k=1,返回两个数组首个元素的最小值
            if (k == 1) {
                return Math.min(nums1[index1], nums2[index2]);
            }
​
            int half = k / 2;
            int newIndex1 = Math.min(index1 + half, len1) - 1;
            int newIndex2 = Math.min(index2 + half, len2) - 1;
            if (nums1[newIndex1] <= nums2[newIndex2]) {
                // 排除了元素后,k 减少排除元素的个数
                k -= newIndex1 - index1 + 1;
                // 记录排除元素后数组的起始下标
                index1 = newIndex1 + 1;
            } else {
                k -= newIndex2 - index2 + 1;
                index2 = newIndex2 + 1;
            }
        }
    }
​
    public static double mergeSearch(int[] nums1, int[] nums2) {
        int index1 = 0, index2 = 0, mergeNumIndex = 0;
        int len1 = nums1.length, len2 = nums2.length;
        int value1 = 0, value2 = 0;
        // 创建一个数组用来保存合并的元素,只需要保存到中间的元素即可
        int[] mergeNums = new int[(len1 + len2) / 2 + 1];
        while (mergeNumIndex < mergeNums.length) {
            value1 = Integer.MAX_VALUE;
            value2 = Integer.MAX_VALUE;
            if (index1 < len1) {
                value1 = nums1[index1];
            }
            if (index2 < len2) {
                value2 = nums1[index2];
            }
            if (value1 <= value2) {
                mergeNums[mergeNumIndex++] = nums1[index1++];
            } else {
                mergeNums[mergeNumIndex++] = nums2[index2++];
            }
        }
​
        // 奇数取最后一个元素
        if ((len1 + len2) % 2 != 0) {
            return mergeNums[mergeNums.length - 1];
        }
        // 偶数取最后两个元素的平均值
        return (mergeNums[mergeNums.length - 2] + mergeNums[mergeNums.length - 1]) / 2.0;
    }
​
    public static void main(String[] args) {
        int[] nums1 = new int[]{1, 3};
        int[] nums2 = new int[]{2};
        double median = findMedianSortedArrays(nums1, nums2);
​
        System.out.println(Arrays.toString(nums1));
        System.out.println(Arrays.toString(nums2));
        System.out.printf("中位数是 %f\n", median);
​
        nums1 = new int[]{1, 2};
        nums2 = new int[]{3, 4};
        median = findMedianSortedArrays(nums1, nums2);
​
        System.out.println(Arrays.toString(nums1));
        System.out.println(Arrays.toString(nums2));
        System.out.printf("中位数是 %f\n", median);
    }
}

运行结果:

[1, 3]
[2]
中位数是 2.000000
[1, 2]
[3, 4]
中位数是 2.500000