力扣4. 寻找两个正序数组的中位数

29 阅读4分钟

image.png

从暴力合并到对数级优化:深度拆解“寻找两个有序数组的中位数”

本文将带你从最直观的暴力解法出发,一步步推导至最优的“二分排除法”。


一、 暴力解法:直觉的合并(Merge)

面对两个有序数组,最直观的想法就是:既然它们都有序,我把它们合并成一个大数组,中位数不就手到擒来了吗?

1. 思路

  1. 开辟一个长度为 m+nm+n 的新数组。
  2. 使用双指针,依次比较两个数组的元素,从小到大填入新数组。
  3. 根据新数组长度的奇偶性,直接取中间位置的值。

2. 复杂度分析

  • 时间复杂度O(m+n)O(m+n)。需要遍历完两个数组。
  • 空间复杂度O(m+n)O(m+n)。需要额外空间存储合并后的数组。

结论:这种做法逻辑最简单,但无法满足题目 O(log(m+n))O(\log(m+n)) 的性能要求。


二、 核心进阶:寻找“第 kk 小”的数

要达到 O(log)O(\log) 级别,我们必须跳出“合并”的思维。

中位数的本质是:寻找合并后序列中排名第 kk 的数字。

  • 如果总长度 LL 是奇数,中位数就是第 (L+1)/2(L+1)/2 小的数。
  • 如果总长度 LL 是偶数,中位数就是第 L/2L/2 和第 (L/2+1)(L/2 + 1) 小两个数的平均值。

为了简化逻辑,我们可以使用一个小技巧:

无论奇偶,统一寻找第 (m+n+1)/2(m+n+1)/2 小和第 (m+n+2)/2(m+n+2)/2 小的数,最后求平均。


三、 最优方案:二分排除法 (Binary Search)

如何快速找到第 kk 小的数?核心在于:每次排除掉一部分肯定不是答案的数字。

1. 排除策略

假设我们要找第 kk 小的数。我们在两个数组中分别取第 k/2k/2 个元素进行比较:

  • 数组 A 的第 k/2k/2 个元素记为 midVal1
  • 数组 B 的第 k/2k/2 个元素记为 midVal2

如果 midVal1 < midVal2

这说明 A 数组的前 k/2k/2 个元素在合并后的总排名中,绝对不可能达到第 kk(它们最多排在第 k1k-1 名)。

操作:直接“扔掉” A 数组的前 k/2k/2 个元素,剩下的任务是在剩余的数字中找第 kk/2k - k/2 小的数。

2. 越界处理

如果某个数组非常短,根本没有第 k/2k/2 个元素怎么办?

技巧:给它赋一个正无穷大 Integer.MAX_VALUE。这样它在比较中一定会“输”,从而迫使程序去排除另一个长度足够的数组。


四、 实例追踪:手把手模拟过程

场景A = [1, 3, 4, 9]B = [1, 2, 3, 5, 6, 7, 8]

目标:总长度 11,寻找第 k=6k=6 小的数。

步骤剩余 k比较对象结果动作
1k=6k=6A[2]=4 vs B[2]=34>34 > 3排除 B 的前 3 个 [1,2,3]kk 变为 63=36-3=3
2k=3k=3A[0]=1 vs B[3]=51<51 < 5排除 A 的前 1 个 [1]kk 变为 31=23-1=2
3k=2k=2A[1]=3 vs B[3]=53<53 < 5排除 A 的前 1 个 [3]kk 变为 21=12-1=1
4k=1k=1A[2]=4 vs B[3]=5k=1 终止取两者最小值:4

验证:合并后为 [1, 1, 2, 3, 3, (4), 5, 6, 7, 8, 9],第 6 个确实是 4。


五、 代码实现 (Java)

Java

class Solution {
    public double findMedianSortedArrays(int[] nums1, int[] nums2) {
        int m = nums1.length;
        int n = nums2.length;
        // 技巧:统一处理奇偶数
        int left = (m + n + 1) / 2;
        int right = (m + n + 2) / 2;
        return (findKth(nums1, 0, nums2, 0, left) + 
                findKth(nums1, 0, nums2, 0, right)) / 2.0;
    }

    // i, j 分别为两个数组当前的有效起始下标
    private int findKth(int[] nums1, int i, int[] nums2, int j, int k) {
        // 边界情况:一个数组已排空
        if (i >= nums1.length) return nums2[j + k - 1];
        if (j >= nums2.length) return nums1[i + k - 1];
        
        // 递归终点:找最小的那个
        if (k == 1) return Math.min(nums1[i], nums2[j]);

        // 核心:二分排除逻辑
        int midVal1 = (i + k / 2 - 1 < nums1.length) ? nums1[i + k / 2 - 1] : Integer.MAX_VALUE;
        int midVal2 = (j + k / 2 - 1 < nums2.length) ? nums2[j + k / 2 - 1] : Integer.MAX_VALUE;

        if (midVal1 < midVal2) {
            // 排除 nums1 的前 k/2 个元素
            return findKth(nums1, i + k / 2, nums2, j, k - k / 2);
        } else {
            // 排除 nums2 的前 k/2 个元素
            return findKth(nums1, i, nums2, j + k / 2, k - k / 2);
        }
    }
}

六、 总结

这道题的精髓在于**“排除法”**。

  1. O(m+n)O(m+n)O(log(m+n))O(\log(m+n)) :利用有序特性,从“逐个扫描”进化为“成批剔除”。
  2. kk 的妙用:将寻找中位数转化为寻找第 kk 小元素,使逻辑更具通用性。
  3. 细节控:使用 Integer.MAX_VALUE 优雅处理数组越界,使用 (m+n+1)/2 统一奇偶处理。

在面试中,如果你能清晰地向面试官解释“为什么比较第 k/2k/2 个元素就能排除掉一半”,你就已经成功拿下了这道 Hard 题。