二分查找

182 阅读3分钟
  • 概述

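    • 这个模板没有像 leetcode book 中那样分三种情况处理，而是统一成一种样式：搜索空间始终是闭区间 [lo, hi]，循环条件为 lo < hi，退出循环后再根据题意自行做后处理。退出时一般是 lo == hi；如果用了 hi = mi - 1 的更新，也可能出现 lo == hi + 1，不过由于 lo 只会被更新为 mi + 1 ≤ hi，此时 lo 仍是合法下标。另外，遇到 lo = mi 的更新时会麻烦一些：mi = (lo + hi) // 2 是向下取整的，当 hi == lo + 1 时 mi == lo，lo = mi 不会缩小区间，于是陷入死循环。幸运的是，出现 lo == mi 时搜索空间里只剩下两个数（lo 和 hi），单独处理一下即可。这样面对各种二分问题时，我们都可以一以贯之。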
  • 704. 二分查找

    class Solution:
        def search(self, nums: List[int], target: int) -> int:
            """Return the index of ``target`` in ascending ``nums``, or -1.

            Closed-interval template: the search space is ``[lo, hi]`` and the
            loop runs while it holds more than one candidate. ``lo`` is only
            ever moved to ``mi + 1 <= hi``, so it is still a valid index when
            the loop exits and a single post-check finishes the search.
            """
            if not nums:  # nothing to search; nums[lo] below would raise
                return -1
            lo, hi = 0, len(nums) - 1
            while lo < hi:
                mi = (lo + hi) // 2
                if nums[mi] == target:
                    return mi
                if target < nums[mi]:
                    hi = mi - 1
                else:
                    lo = mi + 1
            # lo == hi, or lo == hi + 1 after a `hi = mi - 1` step; lo is valid.
            return lo if nums[lo] == target else -1
    
  • 35. 搜索插入位置

    class Solution:
        def searchInsert(self, nums: List[int], target: int) -> int:
            """Return the index of ``target`` in ascending ``nums``, or the index
            at which it would have to be inserted to keep ``nums`` sorted.
            """
            if not nums:  # the only insertion point of an empty list is 0
                return 0
            lo, hi = 0, len(nums) - 1
            while lo < hi:
                mi = (lo + hi) // 2
                if nums[mi] == target:
                    return mi
                if target < nums[mi]:
                    hi = mi - 1
                else:
                    lo = mi + 1
            # Everything left of lo is < target and everything right of hi is
            # > target, so the answer is lo or lo + 1 depending on nums[lo].
            return lo + 1 if target > nums[lo] else lo
    
  • 852. 山脉数组的峰顶索引

    class Solution:
        def peakIndexInMountainArray(self, arr: List[int]) -> int:
            """Return the index of the peak of a mountain array.

            If ``arr[i] < arr[i + 1]``, index ``i`` is still on the ascending
            slope and the peak lies strictly to its right. Otherwise ``i``
            itself may be the peak and must stay in the search space.
            """
            left, right = 0, len(arr) - 1
            while left < right:
                middle = left + (right - left) // 2
                ascending = arr[middle] < arr[middle + 1]
                left, right = (middle + 1, right) if ascending else (left, middle)
            return left
    
  • 1385. 两个数组间的距离值

    class Solution:
        def findTheDistanceValue(self, arr1: List[int], arr2: List[int], d: int) -> int:
            """Count elements ``v`` of ``arr1`` with ``|v - w| > d`` for every ``w`` in ``arr2``.

            ``arr2`` is sorted into a copy, so the caller's list is not
            reordered. For each ``v``, only its nearest neighbours in the sorted
            copy need checking: the insertion point and the element before it.
            """
            arr2 = sorted(arr2)

            def bisect_left(target: int) -> int:
                """Return the first index ``i`` with ``arr2[i] >= target`` (``len(arr2)`` if none)."""
                if not arr2:
                    return 0
                lo, hi = 0, len(arr2) - 1
                while lo < hi:
                    mi = (lo + hi) // 2
                    if arr2[mi] == target:
                        hi = mi  # mi may be the leftmost match; keep it
                    elif target < arr2[mi]:
                        hi = mi - 1
                    else:
                        lo = mi + 1
                return lo if target <= arr2[lo] else lo + 1

            # If even the closest element of arr2 is more than d away, all others are too.
            ans = 0
            for v in arr1:
                p = bisect_left(v)
                if p == len(arr2) or abs(v - arr2[p]) > d:
                    if p == 0 or abs(v - arr2[p - 1]) > d:
                        ans += 1
            return ans
    
  • 69. x 的平方根

    class Solution:
        def mySqrt(self, x: int) -> int:
            """Return floor(sqrt(x)) for a non-negative integer ``x``.

            Invariant: the answer always lies in ``[low - 1, high]``. Once the
            interval collapses, one comparison picks ``low`` or ``low - 1``.
            """
            low, high = 1, x
            while low < high:
                guess = (low + high) // 2
                diff = guess * guess - x
                if not diff:
                    return guess
                if diff > 0:  # guess is too big
                    high = guess - 1
                else:  # guess is too small
                    low = guess + 1
            # Subtracting the bool steps back by one when low overshoots.
            return low - (low * low > x)
    
  • 374. 猜数字大小

    # The guess API is already defined for you.
    # @param num, your guess
    # @return -1 if num is higher than the picked number
    #          1 if num is lower than the picked number
    #          otherwise return 0
    # def guess(num: int) -> int:
    
    class Solution:
        def guessNumber(self, n: int) -> int:
            """Return the picked number in ``[1, n]``, or -1 if ``guess`` never answers 0.

            ``guess(num)`` returns -1 when ``num`` is too high, 1 when it is too
            low, and 0 on a hit. It is queried once per probe and the answer
            is reused, because the API call is the expensive part.
            """
            lo, hi = 1, n
            while lo < hi:
                mi = (lo + hi) // 2
                res = guess(mi)
                if res == 0:
                    return mi
                if res < 0:  # mi is higher than the pick
                    hi = mi - 1
                else:  # mi is lower than the pick
                    lo = mi + 1
            return lo if guess(lo) == 0 else -1
    
  • 441. 排列硬币

    class Solution:
        def arrangeCoins(self, n: int) -> int:
            """Return the number of complete staircase rows buildable from ``n`` coins.

            This is the largest ``k`` with ``k * (k + 1) // 2 <= n``. The search
            starts at ``lo = 0`` (always feasible), so ``n == 0`` correctly
            yields 0. Starting at 1 used to return 1 for ``n == 0``.
            """
            lo, hi = 0, n
            while lo < hi:
                mi = (lo + hi) // 2
                coins = mi * (mi + 1) // 2
                if coins == n:
                    return mi
                if n < coins:
                    hi = mi - 1
                elif lo != mi:
                    lo = mi  # mi rows fit; keep it as a candidate
                else:
                    # lo == mi means hi == lo + 1: lo = mi would loop forever,
                    # so decide between the two remaining candidates directly.
                    return hi if n >= hi * (hi + 1) // 2 else lo
            return lo
    
  • 1539. 第 k 个缺失的正整数

    class Solution:
        def findKthPositive(self, arr: List[int], k: int) -> int:
            """Return the k-th positive integer missing from strictly increasing ``arr``.

            Up to and including ``arr[i]``, exactly ``arr[i] - i - 1`` positive
            integers are missing. The search finds the first index ``first``
            whose missing-count reaches ``k``. Exactly ``first`` elements of
            ``arr`` are below the answer, so the answer is ``k + first``.
            """
            if k < arr[0]:
                return k

            def missing(i: int) -> int:
                return arr[i] - i - 1

            left, right = 0, len(arr) - 1
            while left < right:
                mid = (left + right) // 2
                gap = missing(mid)
                if gap < k:
                    left = mid + 1
                elif gap == k:
                    right = mid
                else:
                    right = mid - 1
            first = left + 1 if missing(left) < k else left
            return k + first
    
  • 1855. 下标对中的最大距离

    class Solution:
        def maxDistance(self, nums1: List[int], nums2: List[int]) -> int:
            """Return the max ``j - i`` over pairs with ``i <= j`` and ``nums1[i] <= nums2[j]``, or 0.

            This relies on ``nums2`` being non-increasing. Then, for each ``i``,
            the indices ``j >= i`` with ``nums2[j] >= nums1[i]`` form a prefix
            of ``nums2[i:]``, and the search finds that prefix's last index.
            """
            best = 0
            last = len(nums2) - 1
            for i, value in enumerate(nums1):
                if i > last:  # no j >= i exists for this or any later i
                    break
                lo, hi = i, last
                while lo < hi:
                    mi = (lo + hi) // 2
                    if nums2[mi] < value:
                        hi = mi - 1
                    elif lo == mi:
                        # Two candidates left and lo is valid; lo = mi would not
                        # shrink the range, so step onto hi and re-check below.
                        lo += 1
                    else:
                        lo = mi
                j = lo if value <= nums2[lo] else lo - 1
                best = max(best, j - i)
            return best
    
  • 33. 搜索旋转排序数组

    class Solution:
        def search(self, nums: List[int], target: int) -> int:
            """Return the index of ``target`` in a rotated ascending array, or -1.

            At every step at least one of ``[lo, mi]`` and ``[mi, hi]`` is
            sorted. Keep the half that can contain ``target``.
            """
            lo, hi = 0, len(nums) - 1
            while lo < hi:
                mi = (lo + hi) // 2
                pivot = nums[mi]
                if pivot == target:
                    return mi
                first = nums[lo]
                if first == pivot:
                    # With distinct values this only happens when lo == mi,
                    # and nums[lo] was just shown not to be the target.
                    lo += 1
                elif first < pivot:
                    # Left half [lo, mi] is sorted.
                    if first <= target < pivot:
                        hi = mi - 1
                    else:
                        lo = mi + 1
                elif pivot < target <= nums[hi]:
                    # Right half [mi, hi] is sorted and holds the target.
                    lo = mi + 1
                else:
                    hi = mi - 1
            return lo if nums[lo] == target else -1
    
  • 278. 第一个错误的版本

    # The isBadVersion API is already defined for you.
    # def isBadVersion(version: int) -> bool:
    
    class Solution:
        def firstBadVersion(self, n: int) -> int:
            """Return the first bad version in ``1..n``, or -1 if none is bad.

            This assumes ``isBadVersion`` is monotone (all good versions come
            before all bad ones). A good ``middle`` is discarded. A bad one may
            itself be the answer, so it stays in range.
            """
            left, right = 1, n
            while left < right:
                middle = left + (right - left) // 2
                if not isBadVersion(middle):
                    left = middle + 1
                else:
                    right = middle
            return left if isBadVersion(left) else -1
    
  • 162. 寻找峰值

    class Solution:
        def findPeakElement(self, nums: List[int]) -> int:
            """Return the index of a peak element (greater than its neighbours), or -1.

            Climb towards the higher neighbour. If ``nums[mi + 1] > nums[mi]``,
            a peak exists right of ``mi``; otherwise one exists in ``[lo, mi]``.
            """
            lo, hi = 0, len(nums) - 1
            while lo < hi:
                mi = (lo + hi) // 2
                if nums[mi + 1] > nums[mi]:
                    lo = mi + 1
                else:
                    hi = mi
            last = len(nums) - 1
            # Sanity check on the result: endpoints are accepted as-is, and
            # interior positions must be strictly above both neighbours.
            if lo == 0 or lo == last:
                return lo
            is_peak = nums[lo - 1] < nums[lo] > nums[lo + 1]
            return lo if is_peak else -1
    
  • 153. 寻找旋转排序数组中的最小值

    class Solution:
        def findMin(self, nums: List[int]) -> int:
            """Return the minimum of an ascending array of distinct values, rotated zero or more times.

            Compare the middle element with the right end of the range:
            - If ``nums[mi] > nums[hi]``, the rotation drop (and so the
              minimum) lies in ``(mi, hi]``.
            - Otherwise ``[mi, hi]`` is sorted and the minimum lies in ``[lo, mi]``.

            The range always contains the minimum, so ``nums[lo]`` is returned
            once it collapses. The previous version could fall through to
            ``return lo``, which is an index rather than a value.
            """
            lo, hi = 0, len(nums) - 1
            while lo < hi:
                mi = (lo + hi) // 2
                if nums[mi] > nums[hi]:
                    lo = mi + 1
                else:
                    hi = mi
            return nums[lo]
    
  • 34. 在排序数组中查找元素的第一个和最后一个位置

    class Solution:
        def searchRange(self, nums: List[int], target: int) -> List[int]:
            """Return ``[first, last]`` indices of ``target`` in ascending ``nums``, or ``[-1, -1]``."""
            if not nums:
                return [-1, -1]
            n = len(nums)

            def leftmost() -> int:
                """First index holding ``target``, or -1."""
                lo, hi = 0, n - 1
                while lo < hi:
                    mi = (lo + hi) // 2
                    if nums[mi] < target:
                        lo = mi + 1
                    elif nums[mi] > target:
                        hi = mi - 1
                    else:
                        hi = mi  # a match; an earlier one may still exist
                return lo if nums[lo] == target else -1

            def rightmost() -> int:
                """Last index holding ``target``, or -1."""
                lo, hi = 0, n - 1
                while lo < hi:
                    mi = (lo + hi) // 2
                    if nums[mi] < target:
                        lo = mi + 1
                    elif nums[mi] > target:
                        hi = mi - 1
                    elif lo < mi:
                        lo = mi  # a match; a later one may still exist
                    else:
                        # lo == mi: only lo and hi == lo + 1 remain, and lo matches.
                        return hi if nums[hi] == target else lo
                return lo if nums[lo] == target else -1

            return [leftmost(), rightmost()]
    
  • 658. 找到 K 个最接近的元素

        class Solution:
            def findClosestElements(self, arr: List[int], k: int, x: int) -> List[int]:
                """Return the ``k`` elements of ascending ``arr`` closest to ``x``, in ascending order.

                Ties prefer the smaller element. The search looks for the left
                edge ``lo`` of the answer window ``arr[lo:lo + k]``. Window
                ``mi`` must slide right exactly when
                ``x - arr[mi] > arr[mi + k] - x``, which happens when either:
                - ``arr[mi + k]`` is strictly closer to ``x``, or
                - both window ends lie below ``x``.
                """
                if x <= arr[0]:
                    return arr[:k]
                if x >= arr[-1]:
                    return arr[-k:]
                lo, hi = 0, len(arr) - k
                while lo < hi:
                    mi = (lo + hi) // 2
                    # Signed distances, not abs(). With duplicates, abs() ties when
                    # arr[mi] == arr[mi + k] < x and wrongly keeps the left window.
                    if x - arr[mi] > arr[mi + k] - x:
                        lo = mi + 1
                    else:
                        hi = mi
                return arr[lo:lo + k]
    
  • 702. 搜索长度未知的有序数组

    # """
    # This is ArrayReader's API interface.
    # You should not implement it, or speculate about its implementation
    # """
    #class ArrayReader:
    #    def get(self, index: int) -> int:
    
    class Solution:
        def search(self, reader: 'ArrayReader', target: int) -> int:
            """Return the index of ``target`` in a sorted array of unknown length, or -1.

            Out-of-range reads return ``2**31 - 1``. The array length is found
            first:
            1. Double an upper bound until it reads past the end.
            2. Binary-search the last in-range index.

            ``target`` is then binary-searched within ``[0, length - 1]``.
            """
            OVERFLOW = 2**31 - 1

            def length() -> int:
                bound = 1
                while reader.get(bound) != OVERFLOW:
                    bound <<= 1
                lo, hi = 0, bound
                while lo < hi:
                    mi = (lo + hi) // 2
                    if reader.get(mi) == OVERFLOW:
                        hi = mi - 1
                    elif mi != lo:
                        lo = mi  # mi is in range; it may be the last index
                    else:
                        # lo == mi: only lo and hi == lo + 1 remain.
                        return hi + 1 if reader.get(hi) != OVERFLOW else lo + 1
                return lo + 1

            lo, hi = 0, length() - 1
            while lo < hi:
                mi = (lo + hi) // 2
                value = reader.get(mi)
                if value == target:
                    return mi
                if value < target:
                    lo = mi + 1
                else:
                    hi = mi - 1
            return lo if reader.get(lo) == target else -1
    
  • 367. 有效的完全平方数

    class Solution:
        def isPerfectSquare(self, num: int) -> bool:
            """Return True iff ``num`` (expected ``>= 1``) is a perfect square, without using sqrt."""
            low, high = 1, num
            while low < high:
                root = (low + high) // 2
                square = root * root
                if square == num:
                    return True
                if square < num:
                    low = root + 1
                else:
                    high = root - 1
            return low * low == num
    
  • 744. 寻找比目标字母大的最小字母

    class Solution:
        def nextGreatestLetter(self, letters: List[str], target: str) -> str:
            """Return the smallest letter in sorted ``letters`` strictly greater than ``target``.

            If no letter is greater, wrap around to ``letters[0]``.
            """
            lo, hi = 0, len(letters) - 1
            while lo < hi:
                mi = (lo + hi) // 2
                if letters[mi] > target:
                    hi = mi  # candidate; a smaller one may be further left
                else:
                    lo = mi + 1
            candidate = letters[lo]
            return candidate if candidate > target else letters[0]
    
  • 154. 寻找旋转排序数组中的最小值 II

    class Solution:
        def findMin(self, nums: List[int]) -> int:
            """Return the minimum of a rotated ascending array that may contain duplicates.

            ``ans`` keeps the smallest ``nums[lo]`` seen at the start of any
            iteration. The ``lo = mi + 1`` and ``lo += 1`` updates move ``lo``
            past elements that are never revisited, so their values are
            remembered here.
            """
            lo, hi = 0, len(nums)-1
            ans = nums[lo]
            while lo < hi:
                mi = (lo + hi)//2
                ans = min(ans, nums[lo])
                # A descent between neighbours marks the rotation point; assuming
                # nums is a rotated ascending array, nums[mi+1] is then the minimum.
                if nums[mi] > nums[mi+1]:
                    return nums[mi+1]
                # nums[lo] < nums[mi]: no descent inside [lo, mi], so its smallest
                # value is nums[lo] (already folded into ans); search right of mi.
                if nums[lo] < nums[mi]:
                    lo = mi + 1
                # nums[lo] > nums[mi]: the descent lies in (lo, mi].
                elif nums[lo] > nums[mi]:
                    hi = mi
                # Equal values give no direction; drop lo (its value is already in ans).
                else:
                    lo += 1
            # If the loop ended before the last index, nums[lo] is returned
            # directly. If lo ran to the end of the array (e.g. an unrotated
            # input), fall back to the running minimum.
            if lo < len(nums)-1:
                return nums[lo]
            else:
                return ans
    
  • 167. 两数之和 II - 输入有序数组

    class Solution:
        def twoSum(self, numbers: List[int], target: int) -> List[int]:
            """Return 1-based indices ``[i, j]`` (i < j) of two values in ascending
            ``numbers`` that sum to ``target``, or ``[-1, -1]`` if none exist.

            For each element, binary-search its complement to the right of it.
            """
            def find(lo: int, hi: int, wanted: int) -> int:
                """Index of ``wanted`` in ``numbers[lo:hi + 1]`` (requires lo <= hi), or -1."""
                while lo < hi:
                    mi = (lo + hi) // 2
                    if numbers[mi] < wanted:
                        lo = mi + 1
                    elif numbers[mi] > wanted:
                        hi = mi - 1
                    else:
                        return mi
                return lo if numbers[lo] == wanted else -1

            last = len(numbers) - 1
            for i, first in enumerate(numbers[:-1]):
                j = find(i + 1, last, target - first)
                if j != -1:
                    return [i + 1, j + 1]
            return [-1, -1]
    
  • 1608. 特殊数组的特征值

    class Solution:
        def specialArray(self, nums: List[int]) -> int:
            """Return ``x`` such that exactly ``x`` elements of ``nums`` are ``>= x``, or -1.

            ``count_ge(x) - x`` strictly decreases as ``x`` grows, so at most
            one such ``x`` exists in ``[0, len(nums)]``.
            """
            def count_ge(bound: int) -> int:
                """Number of elements greater than or equal to ``bound``."""
                return sum(num >= bound for num in nums)

            lo, hi = 0, len(nums)
            while lo < hi:
                mi = (lo + hi) // 2
                cnt = count_ge(mi)
                if cnt == mi:
                    return mi
                if cnt > mi:
                    lo = mi + 1
                else:
                    hi = mi - 1
            return lo if count_ge(lo) == lo else -1
    
  • 287. 寻找重复数

    class Solution:
        def findDuplicate(self, nums: List[int]) -> int:
            """Return the repeated value in ``nums``, assumed to be n + 1 values drawn from ``1..n``.

            Pigeonhole on values: if more than ``v`` elements are ``<= v``, the
            duplicate is ``<= v``. Binary-search the smallest such ``v``.
            """
            low, high = 1, len(nums) - 1
            while low < high:
                pivot = (low + high) // 2
                at_most = sum(1 for value in nums if value <= pivot)
                if at_most <= pivot:
                    low = pivot + 1
                else:
                    high = pivot
            return low
    
  • 719. 找出第 K 小的数对距离

    class Solution:
        def smallestDistancePair(self, nums: List[int], k: int) -> int:
            """Return the k-th smallest ``|nums[i] - nums[j]|`` over all pairs ``i < j``.

            The answer ``d`` is binary-searched in ``[0, max - min]``: it is the
            smallest ``d`` such that at least ``k`` pairs have distance
            ``<= d``. Pairs are counted with a sliding window over a sorted
            copy, so the caller's list is left untouched.
            """
            nums = sorted(nums)

            def possible(guess: int) -> bool:
                """True if at least ``k`` pairs have distance ``<= guess``."""
                count = left = 0
                for right, num in enumerate(nums):
                    # Shrink until nums[left..right-1] are all within guess of num.
                    while num - nums[left] > guess:
                        left += 1
                    count += right - left
                return count >= k

            lo, hi = 0, nums[-1] - nums[0]
            while lo < hi:
                mi = (lo + hi) // 2
                if possible(mi):
                    hi = mi
                else:
                    lo = mi + 1
            return lo
    
  • 410. 分割数组的最大值

    class Solution:
        def splitArray(self, nums: List[int], m: int) -> int:
            """Return the smallest achievable largest-part sum when splitting ``nums`` into ``m`` contiguous parts.

            The answer lies between ``max(nums)`` (no part can be smaller) and
            ``sum(nums)`` (one part). A cap is feasible if greedy left-to-right
            packing needs at most ``m`` parts.
            """
            def fits(cap: int) -> bool:
                parts, running = 1, 0
                for value in nums:
                    running += value
                    if running > cap:  # start a new part with this value
                        parts += 1
                        running = value
                return parts <= m

            low, high = max(nums), sum(nums)
            while low < high:
                mid = (low + high) // 2
                if fits(mid):
                    high = mid
                else:
                    low = mid + 1
            return low
    
  • 1351. 统计有序矩阵中的负数

    class Solution:
        def countNegatives(self, grid: List[List[int]]) -> int:
            """Count negatives in a grid sorted non-increasingly along rows and columns.

            Staircase walk from the top-right corner:
            - A negative cell means the rest of its column downwards is
              negative, so count it and move left.
            - A non-negative cell means the rest of its row leftwards is
              non-negative, so move down.
            """
            rows, cols = len(grid), len(grid[0])
            total = 0
            row, col = 0, cols - 1
            while row < rows and col >= 0:
                if grid[row][col] < 0:
                    total += rows - row
                    col -= 1
                else:
                    row += 1
            return total
    
  • 74. 搜索二维矩阵

    class Solution:
        def searchMatrix(self, matrix: List[List[int]], target: int) -> bool:
            """Return whether ``target`` occurs in ``matrix`` whose rows, read in order, are ascending.

            The ``m x n`` matrix is searched as a virtual flat array of length
            ``m * n``. Flat index ``i`` maps to ``matrix[i // n][i % n]``.
            """
            rows, cols = len(matrix), len(matrix[0])

            def at(i: int) -> int:
                r, c = divmod(i, cols)
                return matrix[r][c]

            lo, hi = 0, rows * cols - 1
            while lo < hi:
                mi = (lo + hi) // 2
                value = at(mi)
                if value == target:
                    return True
                if value < target:
                    lo = mi + 1
                else:
                    hi = mi - 1
            return at(lo) == target
    
  • 240. 搜索二维矩阵 II

    class Solution:
        def searchMatrix(self, matrix: List[List[int]], target: int) -> bool:
            """Return whether ``target`` occurs in a matrix with ascending rows and columns.

            Start at the top-right corner; each comparison discards a row or a column.
            """
            height, width = len(matrix), len(matrix[0])
            r, c = 0, width - 1
            while r < height and c >= 0:
                cell = matrix[r][c]
                if cell < target:
                    r += 1  # everything left of cell in this row is smaller too
                elif cell > target:
                    c -= 1  # everything below cell in this column is larger too
                else:
                    return True
            return False
    
  • 1337. 矩阵中战斗力最弱的 K 行

    class Solution:
        def kWeakestRows(self, mat: List[List[int]], k: int) -> List[int]:
            """Return the indices of the ``k`` weakest rows, weakest first.

            A row's strength is its count of leading 1s (soldiers). Ties are
            broken by the smaller row index, via ``(count, index)`` tuples.
            """
            n = len(mat[0])

            def soldiers(row: List[int]) -> int:
                """Count the leading 1s of ``row`` (1s followed by 0s) by binary search."""
                lo, hi = 0, n - 1
                while lo < hi:
                    mi = (lo + hi) // 2
                    if row[mi] == 0:
                        hi = mi - 1
                    elif lo == mi:
                        lo += 1  # lo = mi would not shrink a two-element range
                    else:
                        lo = mi
                return lo + 1 if row[lo] == 1 else lo

            strengths = [(soldiers(row), i) for i, row in enumerate(mat)]
            heapq.heapify(strengths)
            return [heapq.heappop(strengths)[1] for _ in range(k)]