开启掘金成长之旅!这是我参与「掘金日新计划 · 2 月更文挑战」的第20天,点击查看活动详情。
寻找两个正序数组的中位数
描述
给定两个大小为 m 和 n 的正序(从小到大)数组 nums1 和 nums2。
请你找出这两个正序数组的中位数,并且要求算法的时间复杂度为 O(log(m + n))。
你可以假设 nums1 和 nums2 不会同时为空。
难度
困难
示例
nums1 = [1, 3]
nums2 = [2]
则中位数是 2.0
nums1 = [1, 2]
nums2 = [3, 4]
则中位数是 (2 + 3)/2 = 2.5
思路
使用归并的方式,合并两个有序数组,得到一个大的有序数组。大的有序数组长度只需要为 (m+n)/2+1即可,大的有序数组的中间位置的元素,即为中位数。但是这种方法的时间复杂度为 O(m+n),不符合题目要求。
要满足 O(log(m+n)) 的时间复杂度就要使用二分查找。官方给出的二分查找思路如下:
根据中位数的定义,当 m+n 是奇数时,中位数是两个有序数组中的第 (m+n)/2 个元素,当 m+n 是偶数时,中位数是两个有序数组中的第 (m+n)/2 个元素和第 (m+n)/2+1 个元素的平均值。这道题可以看出成寻找两个有序数组中的第 k 小的数,其中 k 为 (m+n)/2 或 (m+n)/2+1。
假设有两个有序数组分别为 nums1 和 nums2。要找到第 k 个元素,可以比较 nums1[k/2-1] 和 nums2[k/2-1]。分三种情况:
- 如果 nums1[k/2-1] < nums2[k/2-1],那么比 nums1[k/2-1] 小的数最多有 nums1 的前 k/2-1 个和 nums2 的前 k/2-1 个,即比 nums1[k/2-1] 小的数有 k - 2 个,所以 nums1[k/2-1] 不可能是第 k 小的数,nums1[0] 到 nums1[k/2-1] 也不可能是第 k 个数,进行排除。
- 如果 nums1[k/2-1] > nums2[k/2-1],排除 nums2[0] 到 nums2[k/2-1]。
- 如果 nums1[k/2-1] = nums2[k/2-1],和第一种情况相同。
每次排除之后,需要减少 k 的值,根据减少排除元素的个数减少 k。有以下特殊情况需要处理:
- 如果 nums[k/2-1] 或 nums2[k/2-1] 越界,就选取对应数组的最后一个元素。
- 如果 一个数组为空,说明该数组中的所以元素被排除,直接返回另一个数组中的第 k 个元素。
- 如果 k = 1,返回两个数组中首个最小的元素即可。
代码
Rust
pub struct Solution {}
impl Solution {
pub fn find_median_sorted_arrays(nums1: Vec<i32>, nums2: Vec<i32>) -> f64 {
// 归并查找,时间复杂度 O(m+n)
// Solution::merge_search(nums1, nums2)
// 二分查找,时间复杂度 O(log(m+n))
Solution::binary_search(nums1, nums2)
}
fn binary_search(nums1: Vec<i32>, nums2: Vec<i32>) -> f64 {
let (len1, len2) = (nums1.len(), nums2.len());
let total_len = len1 + len2;
if total_len % 2 != 0 {
// 奇数
Solution::get_kth_element(nums1, nums2, total_len / 2 + 1) as f64
} else {
// 偶数
let median1 =
Solution::get_kth_element(nums1.clone(), nums2.clone(), total_len / 2);
let median2 =
Solution::get_kth_element(nums1.clone(), nums2.clone(), total_len / 2 + 1);
(median1 + median2) as f64 / 2.0
}
}
/// 获取两个数组中第 k 个元素
fn get_kth_element(nums1: Vec<i32>, nums2: Vec<i32>, k: usize) -> i32 {
let (len1, len2) = (nums1.len(), nums2.len());
let (mut index1, mut index2) = (0, 0);
let mut mut_k = k;
loop {
// 数组中的所有元素都被排除,返回另一个数组中第 k 小的元素
if index1 == len1 {
return nums2[index2 + mut_k - 1];
}
if index2 == len2 {
return nums1[index1 + mut_k - 1];
}
// 如果 k=1,返回两个数组首个元素的最小值
if mut_k == 1 {
return min(nums1[index1], nums2[index2]);
}
let half = mut_k / 2;
let new_index1 = min(index1 + half, len1) - 1;
let new_index2 = min(index2 + half, len2) - 1;
if nums1[new_index1] <= nums2[new_index2] {
// 排除了元素后,k 减少排除元素的个数
mut_k -= new_index1 - index1 + 1;
// 记录排除元素后数组的起始下标
index1 = new_index1 + 1;
} else {
mut_k -= new_index2 - index2 + 1;
index2 = new_index2 + 1;
}
}
}
#[allow(dead_code)]
fn merge_search(nums1: Vec<i32>, nums2: Vec<i32>) -> f64 {
let (mut index1, mut index2, mut merge_num_index) = (0, 0, 0);
let (len1, len2) = (nums1.len(), nums2.len());
// 创建一个向量用来保存合并的元素,只需要保存到中间的元素即可
let mut merge_nums = vec![0; (len1 + len2) / 2 + 1];
while merge_num_index < merge_nums.len() {
let mut value1 = 1 << 31 - 1;
let mut value2 = 1 << 31 - 1;
if index1 < len1 {
value1 = nums1[index1];
}
if index2 < len2 {
value2 = nums2[index2];
}
if value1 <= value2 {
merge_nums[merge_num_index] = value1;
index1 += 1;
} else {
merge_nums[merge_num_index] = value2;
index2 += 2;
}
merge_num_index += 1;
}
if (len1 + len2) & 2 != 0 {
// 奇数取最后一个元素
merge_nums[merge_nums.len() - 1] as f64
} else {
// 偶数取最后两个元素的平均值
(merge_nums[merge_nums.len() - 2] + merge_nums[merge_nums.len() - 1]) as f64 / 2.0
}
}
}
#[test]
fn test_find_median_sorted_arrays() {
let nums1: Vec<i32> = vec![1, 3];
let nums2: Vec<i32> = vec![2];
let median = Solution::find_median_sorted_arrays(nums1.clone(), nums2.clone());
println!("{:?}", nums1);
println!("{:?}", nums2);
println!("中位数是 {}", median);
let nums1: Vec<i32> = vec![1, 2];
let nums2: Vec<i32> = vec![3, 4];
let median = Solution::find_median_sorted_arrays(nums1.clone(), nums2.clone());
println!("{:?}", nums1);
println!("{:?}", nums2);
println!("中位数是 {}", median);
}
运行结果:
[1, 3]
[2]
中位数是 2
[1, 2]
[3, 4]
中位数是 2.5
Go
func findMedianSortedArrays(nums1 []int, nums2 []int) float64 {
// 归并查找,时间复杂度 O(m+n)
// return mergeSearch(nums1, nums2)
// 二分查找,时间复杂度 O(log(m+n))
return binarySearch(nums1, nums2)
}
func binarySearch(nums1 []int, nums2 []int) float64 {
len1, len2 := len(nums1), len(nums2)
totalLen := len1 + len2
if totalLen % 2 != 0 { // 奇数
median := getKthElement(nums1, nums2, totalLen / 2 + 1)
return float64(median)
} else { // 偶数
median1 := getKthElement(nums1, nums2, totalLen / 2)
median2 := getKthElement(nums1, nums2, totalLen / 2 + 1)
return float64(median1 + median2) / 2.0
}
}
// 获取两个数组中第 k 个元素
func getKthElement(nums1 []int, nums2 []int, k int) int {
len1, len2 := len(nums1), len(nums2)
index1, index2 := 0, 0
for {
// 数组中的所有元素都被排除,返回另一个数组中第 k 小的元素
if index1 == len1 {
return nums2[index2 + k - 1]
}
if index2 == len2 {
return nums1[index1 + k - 1]
}
// 如果 k=1,返回两个数组首个元素的最小值
if k == 1 {
return int(math.Min(float64(nums1[index1]), float64(nums2[index2])))
}
half := k / 2
newIndex1 := int(math.Min(float64(index1 + half), float64(len1))) - 1
newIndex2 := int(math.Min(float64(index2 + half), float64(len2))) - 1
if nums1[newIndex1] <= nums2[newIndex2] {
// 排除了元素后,k 减少排除元素的个数
k -= newIndex1 - index1 + 1
// 记录排除元素后数组的起始下标
index1 = newIndex1 + 1
} else {
k -= newIndex2 - index2 + 1
index2 = newIndex2 + 1
}
}
}
func mergeSearch(nums1 []int, nums2 []int) float64 {
index1, index2, mergeNumIndex := 0, 0, 0
len1, len2 := len(nums1), len(nums2)
value1, value2 := 0, 0
// 创建一个数组用来保存合并的元素,只需要保存到中间的元素即可
mergeNums := make([]int, (len1 + len2) / 2 + 1)
for mergeNumIndex < len(mergeNums) {
value1 = 1 << 31 - 1
value2 = 1 << 31 - 1
if index1 < len1 {
value1 = nums1[index1]
}
if index2 < len2 {
value2 = nums2[index2]
}
if value1 <= value2 {
mergeNums[mergeNumIndex] = nums1[index1]
index1++
} else {
mergeNums[mergeNumIndex] = nums2[index2]
index2++
}
mergeNumIndex++
}
// 奇数取最后一个元素
if (len1 + len2) % 2 != 0 {
return float64(mergeNums[len(mergeNums) - 1])
}
// 偶数取最后两个元素的平均值
return (float64(mergeNums[len(mergeNums) - 2]) + float64(mergeNums[len(mergeNums) - 1])) / 2
}
func TestFindMedianSortedArrays(t *testing.T) {
nums1 := []int{1, 3}
nums2 := []int{2}
median := findMedianSortedArrays(nums1, nums2)
t.Log(nums1)
t.Log(nums2)
t.Logf("中位数是 %f\n", median)
nums1 = []int{1, 2}
nums2 = []int{3, 4}
median = findMedianSortedArrays(nums1, nums2)
t.Log(nums1)
t.Log(nums2)
t.Logf("中位数是 %f\n", median)
}
运行结果:
[1 3]
[2]
中位数是 2.000000
[1 2]
[3 4]
中位数是 2.500000
Java
public class Main {
public static double findMedianSortedArrays(int[] nums1, int[] nums2) {
// 归并查找,时间复杂度 O(m+n)
// return mergeSearch(nums1, nums2);
// 二分查找,时间复杂度 O(log(m+n))
return binarySearch(nums1, nums2);
}
public static double binarySearch(int[] nums1, int[] nums2) {
int len1 = nums1.length, len2 = nums2.length;
int totalLen = len1 + len2;
if (totalLen % 2 != 0) { // 奇数
return getKthElement(nums1, nums2, totalLen / 2 + 1);
} else { // 偶数
double median1 = getKthElement(nums1, nums2, totalLen / 2);
double median2 = getKthElement(nums1, nums2, totalLen / 2 + 1);
return (median1 + median2) / 2.0;
}
}
/**
* 获取两个数组中第 k 个元素
*/
private static int getKthElement(int[] nums1, int[] nums2, int k) {
int len1 = nums1.length, len2 = nums1.length;
int index1 = 0, index2 = 0;
while (true) {
// 数组中的所有元素都被排除,返回另一个数组中第 k 小的元素
if (index1 == len1) {
return nums2[index2 + k - 1];
}
if (index2 == len2) {
return nums1[index1 + k - 1];
}
// 如果 k=1,返回两个数组首个元素的最小值
if (k == 1) {
return Math.min(nums1[index1], nums2[index2]);
}
int half = k / 2;
int newIndex1 = Math.min(index1 + half, len1) - 1;
int newIndex2 = Math.min(index2 + half, len2) - 1;
if (nums1[newIndex1] <= nums2[newIndex2]) {
// 排除了元素后,k 减少排除元素的个数
k -= newIndex1 - index1 + 1;
// 记录排除元素后数组的起始下标
index1 = newIndex1 + 1;
} else {
k -= newIndex2 - index2 + 1;
index2 = newIndex2 + 1;
}
}
}
public static double mergeSearch(int[] nums1, int[] nums2) {
int index1 = 0, index2 = 0, mergeNumIndex = 0;
int len1 = nums1.length, len2 = nums2.length;
int value1 = 0, value2 = 0;
// 创建一个数组用来保存合并的元素,只需要保存到中间的元素即可
int[] mergeNums = new int[(len1 + len2) / 2 + 1];
while (mergeNumIndex < mergeNums.length) {
value1 = Integer.MAX_VALUE;
value2 = Integer.MAX_VALUE;
if (index1 < len1) {
value1 = nums1[index1];
}
if (index2 < len2) {
value2 = nums1[index2];
}
if (value1 <= value2) {
mergeNums[mergeNumIndex++] = nums1[index1++];
} else {
mergeNums[mergeNumIndex++] = nums2[index2++];
}
}
// 奇数取最后一个元素
if ((len1 + len2) % 2 != 0) {
return mergeNums[mergeNums.length - 1];
}
// 偶数取最后两个元素的平均值
return (mergeNums[mergeNums.length - 2] + mergeNums[mergeNums.length - 1]) / 2.0;
}
public static void main(String[] args) {
int[] nums1 = new int[]{1, 3};
int[] nums2 = new int[]{2};
double median = findMedianSortedArrays(nums1, nums2);
System.out.println(Arrays.toString(nums1));
System.out.println(Arrays.toString(nums2));
System.out.printf("中位数是 %f\n", median);
nums1 = new int[]{1, 2};
nums2 = new int[]{3, 4};
median = findMedianSortedArrays(nums1, nums2);
System.out.println(Arrays.toString(nums1));
System.out.println(Arrays.toString(nums2));
System.out.printf("中位数是 %f\n", median);
}
}
运行结果:
[1, 3]
[2]
中位数是 2.000000
[1, 2]
[3, 4]
中位数是 2.500000