LeetCode 4-寻找两个正序数组的中位数

60 阅读4分钟

题目

给定两个大小分别为 m 和 n 的正序(从小到大)数组 nums1 和 nums2。请你找出并返回这两个正序数组的 中位数 。

算法的时间复杂度应该为 O(log (m+n)) 。

示例 1:

输入: nums1 = [1,3], nums2 = [2]
输出: 2.00000
解释: 合并数组 = [1,2,3] ,中位数 2

示例 2:

输入: nums1 = [1,2], nums2 = [3,4]
输出: 2.50000
解释: 合并数组 = [1,2,3,4] ,中位数 (2 + 3) / 2 = 2.5

提示:

  • nums1.length == m
  • nums2.length == n
  • 0 <= m <= 1000
  • 0 <= n <= 1000
  • 1 <= m + n <= 2000
  • -106 <= nums1[i], nums2[i] <= 106

思路

解法一: 归并排序合并+直接查找中位值

归并排序合并数组,O(m+n),在新的数组中,判断数组的奇偶,查找中位值。

代码一:归并排序合并+直接查找

class Solution:
    def findMedianSortedArrays(self, nums1: List[int], nums2: List[int]) -> float:
        nums = merge(nums1, nums2)
        n = len(nums)
        if n % 2 == 1:
            return nums[n>>1]
        return (nums[n>>1] + nums[(n>>1)-1]) * 0.5
    

def merge(n1, n2):
    i = 0
    j = 0
    res = []
    while i < len(n1) and j < len(n2):
        if n1[i] <= n2[j]:
            res.append(n1[i])
            i += 1
        else:
            res.append(n2[j])
            j += 1
    while i < len(n1):
        res.append(n1[i])
        i += 1
    while j < len(n2):
        res.append(n2[j])
        j += 1
    return res

解法二:双指针直接查找中位数

双指针i,j分别从nums1、nums2的头开始,不断比较大小,查找(n+m)>>1和((n+m)>>1)-1的位置。

查询时候,边界应该是 (n+m)>>1,i1 i2 分别从nums1 nums2开始

  • 当nums[i1] <= nums2[i2]时,i1++
  • 当nums[i1] <= nums2[i2]时,i2++
  • 考虑i1和i2可能存在越界的情况
    • 只有当i1<len(nums1)并且nums[i1] <= nums2[i2]时,i1++;
    • 如果i2>=len(nums2),此时nums2空了,只能在nums1中取值。

整理一下,条件应该为(i1 < len(nums1) and i2 >= len(nums2)) or (i1 < len(nums1) and nums[i1] <= nums2[i2]),根据逻辑运算的分配性,i1 < len(nums1) and (i2 >= len(nums2) or nums[i1] <= nums2[i2]),注意条件判断是否越界在前,值判断在后,避免出现索引越界情况。

class Solution:
    def findMedianSortedArrays(self, nums1: List[int], nums2: List[int]) -> float:
        n = len(nums1)
        m = len(nums2)
        left  = None
        right = None
        i1 = 0
        i2 = 0
        l = n + m
        # 最多查询中位数次
        for i in range((l>>1)+1):
            left = right
            # i1 < n 且 (i2越界 or i1值较小)
            if i1 < n and (i2 >= m or nums1[i1] <= nums2[i2]):
                right = nums1[i1]
                i1 += 1
            else:
                right = nums2[i2]
                i2 += 1
        if  l % 2 == 1:
            return right
        return (left + right) * 0.5

解法三:二分法排除k/2

解法一和解法二的时间复杂度都是O(m+n),想要时间O(log(m+n)),就需要用到二分法。

题意可以理解为找到第k位小的数。解法二中的一次遍历去掉不可能是中位数的一个,一个一个排除。基于数组序列是有序的,可以实现一半一半排除。

假设我们要找第 k 小数,我们可以每次循环排除掉 k/2 个数。具体步骤如下:

  • k/2 进行(向下)取整,比较 nums1[k/2] 和 nums2[k/2]的值
  • 不妨设 nums1[k/2] < nums2[k/2],那么说明 nums1[k/2] 肯定不是第k位小的数,一定在nums1[k/2+1,...]和nums2[...]中
  • 排除掉nums1[k/2]后,需要寻找的是k-nums1[k/2]位的数,然后nums1区间nums1[k/2+1,...]
  • ...直到k==1,此时只需要比较nums1和nums2的所在区间的第一位数字就是所求的答案。
class Solution:
    def findMedianSortedArrays(self, nums1: List[int], nums2: List[int]) -> float:
        n = len(nums1)
        m = len(nums2)
        left = (n+m+1) >> 1
        right = (n+m+2) >> 1
        return (getKthNum(nums1, 0, n-1, nums2, 0, m-1, left) +  getKthNum(nums1, 0, n-1, nums2, 0, m-1, right)) * 0.5


def getKthNum(n1, start1, end1, n2, start2, end2, k):
    l1 = end1 - start1 + 1
    l2 = end2 - start2 + 1
    if l1 > l2:
        return getKthNum(n2, start2, end2, n1, start1, end1, k)
    if l1 == 0:
        return n2[start2 + k -1]
    if k == 1:
        return min(n1[start1], n2[start2])
    
    i = start1 + min(l1, k>>1) - 1
    j = start2 + min(l2, k>>1) - 1
    if n1[i] < n2[j]:
        return getKthNum(n1, i+1, end1, n2, start2, end2, k - (i-start1+1))
    else:
        return getKthNum(n1, start1, end1, n2, j+1, end2, k - (j-start2+1))