[LeetCode] 4. Median of Two Sorted Arrays 题解

3,684 阅读3分钟

问题描述

两个有序的数组 nums1nums2 ,它们的数组长度分别为 m 和 n。要求找到这两个数组的中位数,且总体的时间复杂度必须为 O(\log(m+n))

假设 nums1nums2 都不为空。

例 1:

nums1 = [1, 3]
nums2 = [2]
中位数为 2.0

例 2:

nums1 = [1, 2]
nums2 = [3, 4]
中位数为 (2 + 3)/2 = 2.5

问题难度

Hard

解题思路

做第一遍时,使用的遍历的办法,即按照从小到大的顺序,分别遍历两个队列,一直遍历到 (m+n)/2 的位置,即是我们所要的中位数,当时心想,这也太简单了吧,Hard 级别也不过如此。

再仔细看题目,它还有个时间复杂度为 O(\log(m+n)) 的要求,而我的方案的时间复杂度却为 O(n),并不满足题目条件,可见,单这个条件就值 Hard 级别了。

解这道题的关键并不是高超的算法,而是心中要有一副这样的图:

              left side       |  right side 
nums1:   A(0),A(1),...,A(i-1) | A(i),...,A(m-1)
nums2:   B(0),B(1),...,B(j-1) | B(j),...,B(n-1)

我们把两个数组看成一个整体,有一根竖线将其中的元素从中间分开,且左边的部分和右边部分的元素相同(总数为奇数情况下,左边比右边多 1 个元素),那么当 m+n 为偶数时,中位数为 [max(A(i-1),B(j-1)) + min(A(i)+B(j))]\div2 ,当 m+n 为奇数时,中位数为 max(A(i-1),B(j-1))

我们都知道,左边的元素为 i+j 个,而左右两边元素相同,则

i + j = \frac{m+n+1}{2}

我们可以用 i 来表示 j,则

j = \frac{m+n+1}{2} - i

所以,该题就变成了,在数组 A 中寻找一个 i,使得 A(i) \ge B(j-1),且 B(j) \ge A(i-1)​ 成立,这两个不等式的含义是,竖线右边最小的数一定不比左边最大的数小,满足该条件,我们就可以说找到了这个竖线。

我们在找 i 的过程中,难免会碰到 A(i) < B(j-1) 时候,此时我们需要将 i 向右移,使 A(i) 增大,当 i 右移,i 增大的同时,j 会减少,即 B(j-1) 的值会变小,这样操作 i 之后,会让我们更接近目标;同理,当 B(j) < A(i-1) 时,我们需要将 i 向左移。

通过上面的分析,我们最终可以使用二分查找法来寻找这个 i 值,又由于二分查找的时间复杂度为 O(\log(n))​,这种方法可以满足题目的要求。

思路说完了,下面来说下该题目的边界条件,由于 j 是通过减去 i 算出来的,而 i 的最大值为 m(A 全在左边时),所以为了使 j 不为负数,数组 A 需要为两个数组中,元素数较少的那个。

当 i 为 0 时,数组 A 全在右边,我们只需要判断 A(i) \ge B(j-1)​ 成立;当 i 为 m 时,数组 A 全在左边,只需判断 B(j) \ge A(i-1)​ 成立

同理当 j 为 0 时,数组 B 全在右边,我们只需判断 B(j) \ge A(i-1) 成立;当 j 为 n 时,数组 B 全在左边,只需判断 A(i) \ge B(j-1) 即可

于是,我们可以写出下面的代码:

class Solution(object):
    def findMedianSortedArrays(self, nums1, nums2):
        m, n = len(nums1), len(nums2)
        if m > n:
            m, n, nums1, nums2 = n, m, nums2, nums1

        if m == 0 and n == 0:
            return None

        begin = 0
        end = m
        i = j = 0
        while True:
            i = (begin + end) / 2
            j = (m + n + 1) / 2 - i

            if (i == 0 or j == n or nums2[j] >= nums1[i-1]) and\
                    (i == m or j == 0 or nums1[i] >= nums2[j-1]):
                left_max = 0
                if i == 0: left_max = nums2[j-1]
                elif j == 0: left_max = nums1[i-1]
                else: left_max = max(nums1[i-1],nums2[j-1])
                
                if (m+n)%2 != 0:
                    return left_max

                right_min = 0
                if i == m: right_min = nums2[j]
                elif j == n: right_min = nums1[i]
                else: right_min = min(nums1[i], nums2[j])

                return (left_max + right_min)*1.0/2

            elif j < n and i > 0 and nums2[j] < nums1[i-1]:
                end = i - 1
            elif j > 0 and i < n and nums1[i] < nums2[j-1]:
                begin = i + 1

这道题直接看代码是很难理解的,但如果心中有前文说的那张图,便可以沿着思路慢慢化解,可见要达到题目的要求并不简单,Hard 难度名不虚传。

原题链接