📌 题目链接:4. 寻找两个正序数组的中位数 - 力扣(LeetCode)
🔍 难度:困难 | 🏷️ 标签:数组、二分查找、分治
⏱️ 目标时间复杂度:O(log(m + n))
💾 空间复杂度:O(1)
🧩 题目分析
给定两个已排序(升序)的整数数组 nums1 和 nums2,要求找出这两个数组合并后的中位数。关键限制是:时间复杂度必须为 O(log(m + n)) 。
💡 什么是中位数?
- 若总长度为奇数,则中位数是第
(m+n+1)/2小的数;- 若为偶数,则是第
(m+n)/2与(m+n)/2 + 1小的数的平均值。
最朴素的想法是归并排序(双指针合并),但时间复杂度为 O(m + n),不满足题目要求。因此,必须使用对数级别的算法——二分查找。
本题是 LeetCode 中少数要求 严格 O(log(m+n)) 的题目之一,也是面试高频难题,常被用于考察候选人对二分思想本质的理解。
🔍 核心算法及代码讲解
本题有两种主流解法,均能达到 O(log(min(m, n))) 或 O(log(m + n)) 的时间复杂度:
✅ 方法一:二分查找第 k 小元素(推荐掌握)
核心思想:将“找中位数”转化为“找第 k 小的数”,利用每次排除 k/2 个不可能的元素,实现对数级缩小搜索空间。
📌 关键洞察:
- 要找第
k小的数,比较nums1[k/2 - 1]与nums2[k/2 - 1]。 - 若
nums1[k/2 - 1] <= nums2[k/2 - 1],则nums1[0..k/2-1]都不可能是第 k 小(因为最多只有k-2个数比它小),可安全排除。 - 每次排除约
k/2个元素,k减少相应数量,递归或迭代继续。
🛑 边界处理(面试必问!):
- 某数组已遍历完 → 直接返回另一数组的第
k个元素; - k == 1 → 返回两数组当前首元素的较小值;
- k/2 越界 → 取该数组最后一个元素,并按实际排除数量更新
k。
💻 C++ 核心函数(带详细行注释):
int getKthElement(const vector<int>& nums1, const vector<int>& nums2, int k) {
/* 主要思路:要找到第 k (k>1) 小的元素,那么就取 pivot1 = nums1[k/2-1] 和 pivot2 = nums2[k/2-1] 进行比较
* 这里的 "/" 表示整除
* nums1 中小于等于 pivot1 的元素有 nums1[0 .. k/2-2] 共计 k/2-1 个
* nums2 中小于等于 pivot2 的元素有 nums2[0 .. k/2-2] 共计 k/2-1 个
* 取 pivot = min(pivot1, pivot2),两个数组中小于等于 pivot 的元素共计不会超过 (k/2-1) + (k/2-1) <= k-2 个
* 这样 pivot 本身最大也只能是第 k-1 小的元素
* 如果 pivot = pivot1,那么 nums1[0 .. k/2-1] 都不可能是第 k 小的元素。把这些元素全部 "删除",剩下的作为新的 nums1 数组
* 如果 pivot = pivot2,那么 nums2[0 .. k/2-1] 都不可能是第 k 小的元素。把这些元素全部 "删除",剩下的作为新的 nums2 数组
* 由于我们 "删除" 了一些元素(这些元素都比第 k 小的元素要小),因此需要修改 k 的值,减去删除的数的个数
*/
int m = nums1.size();
int n = nums2.size();
int index1 = 0, index2 = 0; // 当前有效起始下标
while (true) {
// 边界情况1:nums1 已用完,直接返回 nums2 的第 k 个
if (index1 == m) {
return nums2[index2 + k - 1];
}
// 边界情况2:nums2 已用完
if (index2 == n) {
return nums1[index1 + k - 1];
}
// 边界情况3:k=1,取当前最小值
if (k == 1) {
return min(nums1[index1], nums2[index2]);
}
// 正常情况:计算新下标,防止越界
int newIndex1 = min(index1 + k / 2 - 1, m - 1);
int newIndex2 = min(index2 + k / 2 - 1, n - 1);
int pivot1 = nums1[newIndex1];
int pivot2 = nums2[newIndex2];
if (pivot1 <= pivot2) {
// 排除 nums1[index1 .. newIndex1]
k -= newIndex1 - index1 + 1; // 实际排除的数量
index1 = newIndex1 + 1;
} else {
// 排除 nums2[index2 .. newIndex2]
k -= newIndex2 - index2 + 1;
index2 = newIndex2 + 1;
}
}
}
✅ 为什么时间复杂度是 O(log(m+n))?
每次循环至少排除k/2个元素,k初始为(m+n)/2,故最多log(k)次操作,即 O(log(m+n)) 。
✅ 方法二:划分数组(更优,O(log(min(m,n))))
核心思想:在较短数组上二分划分点
i,使得左半部分最大值 ≤ 右半部分最小值。
📌 关键等式:
设总长度 L = m + n,我们希望:
len(left_part) = (L + 1) / 2(奇数时左多1)max(left_part) <= min(right_part)
令 i 为 nums1 的划分点(0 <= i <= m),则 j = (m + n + 1)/2 - i 为 nums2 的划分点。
需满足:
nums1[i-1] <= nums2[j] && nums2[j-1] <= nums1[i]
由于 i 增大 → nums1[i-1] 增大,nums2[j] 减小,具有单调性,可用二分查找最大满足 nums1[i-1] <= nums2[j] 的 i。
💡 为什么交换数组保证 m <= n?
避免j为负数,确保j = (m+n+1)/2 - i >= 0。
此方法时间复杂度为 O(log(min(m, n))) ,略优于方法一,但理解门槛更高。
🧭 解题思路(分步骤)
方法一(第 k 小元素)步骤:
-
判断总长度奇偶:
- 奇数 → 找第
(m+n+1)/2小; - 偶数 → 找第
(m+n)/2和(m+n)/2 + 1小,取平均。
- 奇数 → 找第
-
设计
getKthElement函数:- 维护两个指针
index1,index2表示当前有效起始位置; - 循环直到满足边界条件;
- 每次比较
k/2位置的元素,排除较小者所在数组的前k/2个; - 更新
k和对应指针。
- 维护两个指针
-
处理越界和边界:如数组耗尽、k=1 等。
方法二(划分数组)步骤:
-
确保
nums1是较短数组,否则交换; -
在
[0, m]上二分i; -
计算
j = (m+n+1)/2 - i; -
检查划分是否合法(
left_max <= right_min); -
根据结果调整二分边界;
-
计算中位数:
- 奇数 →
max(left_part) - 偶数 →
(max(left) + min(right)) / 2.0
- 奇数 →
📊 算法分析
| 方法 | 时间复杂度 | 空间复杂度 | 面试推荐度 | 难度 |
|---|---|---|---|---|
| 归并(暴力) | O(m + n) | O(1) | ❌ 不满足要求 | ⭐ |
| 第 k 小(二分) | O(log(m + n)) | O(1) | ✅✅✅ 强烈推荐 | ⭐⭐⭐⭐ |
| 划分数组 | O(log(min(m, n))) | O(1) | ✅✅ 高阶技巧 | ⭐⭐⭐⭐⭐ |
🎯 面试建议:优先掌握方法一,逻辑清晰,易于解释;若时间充裕,可补充方法二展示深度。
💻 代码
C++ 完整代码
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
class Solution {
public:
int getKthElement(const vector<int>& nums1, const vector<int>& nums2, int k) {
/* 主要思路:要找到第 k (k>1) 小的元素,那么就取 pivot1 = nums1[k/2-1] 和 pivot2 = nums2[k/2-1] 进行比较
* 这里的 "/" 表示整除
* nums1 中小于等于 pivot1 的元素有 nums1[0 .. k/2-2] 共计 k/2-1 个
* nums2 中小于等于 pivot2 的元素有 nums2[0 .. k/2-2] 共计 k/2-1 个
* 取 pivot = min(pivot1, pivot2),两个数组中小于等于 pivot 的元素共计不会超过 (k/2-1) + (k/2-1) <= k-2 个
* 这样 pivot 本身最大也只能是第 k-1 小的元素
* 如果 pivot = pivot1,那么 nums1[0 .. k/2-1] 都不可能是第 k 小的元素。把这些元素全部 "删除",剩下的作为新的 nums1 数组
* 如果 pivot = pivot2,那么 nums2[0 .. k/2-1] 都不可能是第 k 小的元素。把这些元素全部 "删除",剩下的作为新的 nums2 数组
* 由于我们 "删除" 了一些元素(这些元素都比第 k 小的元素要小),因此需要修改 k 的值,减去删除的数的个数
*/
int m = nums1.size();
int n = nums2.size();
int index1 = 0, index2 = 0;
while (true) {
// 边界情况
if (index1 == m) {
return nums2[index2 + k - 1];
}
if (index2 == n) {
return nums1[index1 + k - 1];
}
if (k == 1) {
return min(nums1[index1], nums2[index2]);
}
// 正常情况
int newIndex1 = min(index1 + k / 2 - 1, m - 1);
int newIndex2 = min(index2 + k / 2 - 1, n - 1);
int pivot1 = nums1[newIndex1];
int pivot2 = nums2[newIndex2];
if (pivot1 <= pivot2) {
k -= newIndex1 - index1 + 1;
index1 = newIndex1 + 1;
}
else {
k -= newIndex2 - index2 + 1;
index2 = newIndex2 + 1;
}
}
}
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int totalLength = nums1.size() + nums2.size();
if (totalLength % 2 == 1) {
return getKthElement(nums1, nums2, (totalLength + 1) / 2);
}
else {
return (getKthElement(nums1, nums2, totalLength / 2) + getKthElement(nums1, nums2, totalLength / 2 + 1)) / 2.0;
}
}
};
// 测试
signed main(){
ios::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
Solution sol;
// 测试用例1
vector<int> nums1 = {1, 3};
vector<int> nums2 = {2};
cout << fixed << setprecision(5) << sol.findMedianSortedArrays(nums1, nums2) << "\n"; // 2.00000
// 测试用例2
nums1 = {1, 2};
nums2 = {3, 4};
cout << fixed << setprecision(5) << sol.findMedianSortedArrays(nums1, nums2) << "\n"; // 2.50000
return 0;
}
JavaScript 完整代码
/**
* @param {number[]} nums1
* @param {number[]} nums2
* @return {number}
*/
var findMedianSortedArrays = function(nums1, nums2) {
const getKthElement = (arr1, arr2, k) => {
let index1 = 0, index2 = 0;
const m = arr1.length, n = arr2.length;
while (true) {
if (index1 === m) return arr2[index2 + k - 1];
if (index2 === n) return arr1[index1 + k - 1];
if (k === 1) return Math.min(arr1[index1], arr2[index2]);
let newIndex1 = Math.min(index1 + Math.floor(k / 2) - 1, m - 1);
let newIndex2 = Math.min(index2 + Math.floor(k / 2) - 1, n - 1);
let pivot1 = arr1[newIndex1];
let pivot2 = arr2[newIndex2];
if (pivot1 <= pivot2) {
k -= newIndex1 - index1 + 1;
index1 = newIndex1 + 1;
} else {
k -= newIndex2 - index2 + 1;
index2 = newIndex2 + 1;
}
}
};
const total = nums1.length + nums2.length;
if (total % 2 === 1) {
return getKthElement(nums1, nums2, Math.floor((total + 1) / 2));
} else {
const left = getKthElement(nums1, nums2, total / 2);
const right = getKthElement(nums1, nums2, total / 2 + 1);
return (left + right) / 2.0;
}
};
// 测试
console.log(findMedianSortedArrays([1, 3], [2])); // 2
console.log(findMedianSortedArrays([1, 2], [3, 4])); // 2.5
🌟 本期完结,下期见!🔥
👉 点赞收藏加关注,新文更新不迷路。关注专栏【算法】LeetCode Hot100刷题日记,持续为你拆解每一道热题的底层逻辑与面试技巧!
💬 欢迎留言交流你的解法或疑问!一起进步,冲向 Offer!💪
📌 记住:当你在刷题时,不要只看答案,要像写这篇文章一样,深入思考每一步背后的原理、优化空间和面试价值。这才是真正提升算法能力的方式!