递归和分治

145 阅读8分钟
  • 概述:

    • 三部曲:递归终止条件 -> 拆分 -> 合并;
    • 理论:主定理,随机化
  • 自顶向下与自底向上

    • 50. Pow(x, n)

      class Solution:
          def myPow(self, x: float, n: int) -> float:
              # 递归终止条件
              if n == 0:
                  return 1
              if n < 0:
                  return 1 / self.myPow(x, -n)
              # 拆分 + 合并
              if n % 2:
                  return x * self.myPow(x, n-1)
              return self.myPow(x*x, n//2)
      
    • 剑指 Offer 65. 不用加减乘除做加法

      class Solution:
          def add(self, a: int, b: int) -> int:
              # 因为python中特殊的表现形式,需要先将两个数转换成32位的补码进行计算
              a, b = a & 0xffffffff, b & 0xffffffff
              # 递归终止条件
              if b == 0:
                  # # 这种方式也可以
                  # return a if a <= 0x7fffffff else ~(a ^ 0xffffffff)
                  # 运算结束,如果a是正数,直接返回结果;否则需要获取32位负数补码的打印形式“-(~补码+1)”
                  return a if a <= 0x7fffffff else -((~(a&0x7fffffff) & 0x7fffffff) + 1)
              # 拆分 + 合并
              c = (a & b) << 1
              return self.add(a^b, c)
      
    • 面试题 08.05. 递归乘法

      class Solution:
          def multiply(self, A: int, B: int) -> int:
              if A > B:
                  return self.multiply(B, A)
              # 递归终止条件
              if A == 0:
                  return 0
              # 拆分 + 合并
              return self.multiply(A-1, B) + B
      
    • 53. 最大子数组和

      class Solution:
          def maxSubArray(self, nums: List[int]) -> int:
              # a: [lo, hi]中始于lo的最大和
              # m: [lo, hi]中区间的最大和
              # b: [lo, hi]中始于hi的最大和
              # s: [lo, hi]中区间和
              def div_n_conq(lo: int, hi: int) -> Tuple[int,int,int,int]:
                  if lo >= hi:
                      return (nums[lo], nums[lo], nums[lo], nums[lo])
                  mi = (lo + hi)//2
                  a1, b1, m1, s1 = div_n_conq(lo, mi)
                  a2, b2, m2, s2 = div_n_conq(mi+1, hi)
                  a = max(a1, s1+a2)
                  b = max(b2, s2+b1)
                  m = max(m1, m2, b1+a2)
                  s = s1 + s2
                  return (a, b, m, s)
      
              _, _, m, _ = div_n_conq(0, len(nums)-1)
              return m
      
  • 深入理解递归 - 1:

    • 剑指 Offer 51. 数组中的逆序对

      class Solution:
          def reversePairs(self, nums: List[int]) -> int:
              n = len(nums)
              tmp = [0] * n
              ans = 0
              def merge(lo: int, mi: int, hi: int):
                  nonlocal ans
                  i, j, p = lo, mi+1, lo
                  while i <= mi and j <= hi:
                      if nums[i] <= nums[j]:
                          tmp[p] = nums[i]
                          # 因为j指针所指向的组j下标之前的走已经先于i被合并了,
                          # 所以他们肯定比nums[i]小
                          ans += (j-1) - (mi+1) + 1
                          i, p = i + 1, p + 1
                      else:
                          tmp[p] = nums[j]
                          j, p = j + 1, p + 1
                  while i <= mi:
                      tmp[p] = nums[i]
                      # 即使j指向的组已经空了也是一样,因为j组在i组之后,所以他们都是逆序对
                      ans += (j-1) - (mi+1) + 1
                      i, p = i + 1, p + 1
                  while j <= hi:
                      tmp[p] = nums[j]
                      j, p = j + 1, p + 1
                  for k in range(lo, hi+1):
                      nums[k] = tmp[k]
      
              def mergeSort(lo: int, hi: int):
                  if lo >= hi:
                      return
                  mi = (lo + hi)//2
                  mergeSort(lo, mi)
                  mergeSort(mi+1, hi)
                  merge(lo, mi, hi)
      
              mergeSort(0, n-1)
              return ans
      
    • 315. 计算右侧小于当前元素的个数

      class Solution:
          def countSmaller(self, nums: List[int]) -> List[int]:
              n = len(nums)
              # index数组存放原来每个数所在的位置,方便在归并过程中每个数加上右侧比他小的数的个数
              # tmp数组存放两两归并数组的临时结果
              index, tmp, tmpIdx, ans = list(range(n)), [0]*n, [0]*n, [0]*n
              def merge(lo: int, mi: int, hi: int):
                  i, j, p = lo, mi+1, lo
                  while i <= mi and j <= hi:
                      if nums[i] <= nums[j]:
                          tmp[p] = nums[i]
                          tmpIdx[p] = index[i]
                          # 因为j指针所指向的组j下标之前的走已经先于i被合并了,
                          # 所以他们肯定比nums[i]小
                          ans[index[i]] += (j-1) - (mi+1) + 1
                          i, p = i + 1, p + 1
                      else:
                          tmp[p] = nums[j]
                          tmpIdx[p] = index[j]
                          j, p = j + 1, p + 1
                  while i <= mi:
                      tmp[p] = nums[i]
                      tmpIdx[p] = index[i]
                      # 即使j指向的组已经空了也是一样,因为j组在i组之后,所以他们都是逆序对
                      ans[index[i]] += (j-1) - (mi+1) + 1
                      i, p = i + 1, p + 1
                  while j <= hi:
                      tmp[p] = nums[j]
                      tmpIdx[p] = index[j]
                      j, p = j + 1, p + 1
                  for k in range(lo, hi+1):
                      nums[k] = tmp[k]
                      index[k] = tmpIdx[k]
      
              def mergeSort(lo: int, hi: int):
                  if lo >= hi:
                      return
                  mi = (lo + hi)//2
                  mergeSort(lo, mi)
                  mergeSort(mi+1, hi)
                  merge(lo, mi, hi)
      
              mergeSort(0, n-1)
              return ans
      
    • 493. 翻转对

      class Solution:
          def reversePairs(self, nums: List[int]) -> int:
              n = len(nums)
              tmp = [0] * n
              ans = 0
              def merge(lo: int, mi: int, hi: int):
                  nonlocal ans
                  # 区间[lo,mi]和[mi+1,hi]已经有序,可计算重要翻转对
                  ii, jj = lo, mi+1
                  while ii <= mi and jj <= hi:
                      if nums[ii] <= 2 * nums[jj]:
                          ii += 1
                      else:
                          # 如果num[ii]>2*nums[jj],那么[ii,mi]的数字也都符合
                          ans += mi - ii + 1
                          jj += 1
                  # 剩下的正常的归并排序的步骤还是要继续,保证之后的左右区间有序
                  i, j, p = lo, mi+1, lo
                  while i <= mi and j <= hi:
                      if nums[i] <= nums[j]:
                          tmp[p] = nums[i]
                          i, p = i + 1, p + 1
                      else:
                          tmp[p] = nums[j]
                          j, p = j + 1, p + 1
                  while i <= mi:
                      tmp[p] = nums[i]
                      i, p = i + 1, p + 1
                  while j <= hi:
                      tmp[p] = nums[j]
                      j, p = j + 1, p + 1
                  for k in range(lo, hi+1):
                      nums[k] = tmp[k]
      
              def mergeSort(lo: int, hi: int):
                  if lo >= hi:
                      return
                  mi = (lo + hi)//2
                  mergeSort(lo, mi)
                  mergeSort(mi+1, hi)
                  merge(lo, mi, hi)
      
              mergeSort(0, n-1)
              return ans
      
    • 215. 数组中的第K个最大元素

      class Solution:
          def findKthLargest(self, nums: List[int], k: int) -> int:
              # 返回lo, mi, hi三个位置都数字排列后的中间的数的坐标mi,以求尽可能的平均分成左右两部
              def median3(lo: int, hi: int) -> int:
                  mi = (lo + hi)//2
                  if nums[lo] > nums[mi]:
                      nums[lo], nums[mi] = nums[mi], nums[lo]
                  if nums[lo] > nums[hi]:
                      nums[lo], nums[hi] = nums[hi], nums[lo]
                  if nums[mi] > nums[hi]:
                      nums[mi], nums[hi] = nums[hi], nums[mi]
                  return mi
      
              # 返回主元pivot在数组中的的坐标,区间[lo,hi]根据pivot分成左右,
              # 左边的数都<pivot,右边的数都>=pivot,pivot的坐标已经最终确定
              def partition(lo: int, hi: int) -> int:
                  # pivotIdx = randint(lo, hi)
                  pivotIdx = median3(lo, hi)
                  nums[pivotIdx], nums[hi] = nums[hi], nums[pivotIdx]
                  i = lo - 1
                  for j in range(lo, hi):
                      if nums[j] < nums[hi]:
                          i += 1
                          nums[j], nums[i] = nums[i], nums[j]
                  i += 1
                  nums[i], nums[hi] = nums[hi], nums[i]
                  return i
      
              # 返回原数组中从小到大排第len(nums)-k个数
              def quickSelect(lo: int, hi: int):
                  if lo >= hi:
                      return
                  q = partition(lo, hi)
                  if q == len(nums)-k:
                      return
                  elif q < len(nums)-k:
                      quickSelect(q+1, hi)
                  else:
                      quickSelect(lo, q-1)
      
              quickSelect(0, len(nums)-1)
              return nums[-k]
      
    • 347. 前 K 个高频元素

      class Solution:
          def topKFrequent(self, nums: List[int], k: int) -> List[int]:
              count = Counter(nums)
              nums = list(count.keys())
              def median3(lo: int, hi: int) -> int:
                  mi = (lo + hi)//2
                  if count[nums[lo]] > count[nums[mi]]:
                      nums[lo], nums[mi] = nums[mi], nums[lo]
                  if count[nums[lo]] > count[nums[hi]]:
                      nums[lo], nums[hi] = nums[hi], nums[lo]
                  if count[nums[mi]] > count[nums[hi]]:
                      nums[mi], nums[hi] = nums[hi], nums[mi]
                  return mi
      
              def partition(lo: int, hi: int) -> int:
                  pivotIdx = median3(lo, hi)
                  nums[pivotIdx], nums[hi] = nums[hi], nums[pivotIdx]
                  i = lo - 1
                  for j in range(lo, hi):
                      if count[nums[j]] < count[nums[hi]]:
                          i += 1
                          nums[j], nums[i] = nums[i], nums[j]
                  i += 1
                  nums[i], nums[hi] = nums[hi], nums[i]
                  return i
      
              # 找到arr中元素从小到大的下标
              def quickSelect(lo: int, hi: int):
                  if lo >= hi:
                      return
                  q = partition(lo, hi)
                  if q == len(nums)-k:
                      return
                  elif q < len(nums)-k:
                      quickSelect(q+1, hi)
                  else:
                      quickSelect(lo, q-1)
      
              quickSelect(0, len(nums)-1)
              return nums[-k:]
      
    • 973. 最接近原点的 K 个点

      class Solution:
          def kClosest(self, points: List[List[int]], k: int) -> List[List[int]]:
              def partition(lo: int, hi: int):
                  pivotIdx = randint(lo, hi)
                  pivot = points[pivotIdx][0] ** 2 + points[pivotIdx][1] ** 2
                  points[pivotIdx], points[hi] = points[hi], points[pivotIdx]
                  i = lo - 1
                  for j in range(lo, hi):
                      if points[j][0] ** 2 + points[j][1] ** 2 <= pivot:
                          i += 1
                          points[j], points[i] = points[i], points[j]
                  i += 1
                  points[i], points[hi] = points[hi], points[i]
                  return i
      
              def quickSelect(lo: int, hi: int):
                  if lo >= hi:
                      return
                  q = partition(lo, hi)
                  if q == k-1:
                      return
                  elif q < k-1:
                      quickSelect(q+1, hi)
                  else:
                      quickSelect(lo, q-1)
      
              quickSelect(0, len(points)-1)
              return points[:k]
      
  • 深入理解递归 - 2

    • 21. 合并两个有序链表

      # Definition for singly-linked list.
      # class ListNode:
      #     def __init__(self, val=0, next=None):
      #         self.val = val
      #         self.next = next
      class Solution:
          def mergeTwoLists(self, list1: Optional[ListNode], list2: Optional[ListNode]) -> Optional[ListNode]:
              if not list1:
                  return list2
              if not list2:
                  return list1
              if list1.val <= list2.val:
                  list1.next = self.mergeTwoLists(list1.next, list2)
                  return list1
              list2.next = self.mergeTwoLists(list1, list2.next)
              return list2
      
    • 206. 反转链表

      # Definition for singly-linked list.
      # class ListNode:
      #     def __init__(self, val=0, next=None):
      #         self.val = val
      #         self.next = next
      class Solution:
          def reverseList(self, head: Optional[ListNode]) -> Optional[ListNode]:
              if not head:
                  return head
              if not head.next:
                  return head
              ret = self.reverseList(head.next)
              head.next.next = head
              head.next = None
              return ret
      
    • 203. 移除链表元素

      # Definition for singly-linked list.
      # class ListNode:
      #     def __init__(self, val=0, next=None):
      #         self.val = val
      #         self.next = next
      class Solution:
          def removeElements(self, head: Optional[ListNode], val: int) -> Optional[ListNode]:
              if not head:
                  return head
              if head.val == val:
                  return self.removeElements(head.next, val)
              head.next = self.removeElements(head.next, val)
              return head
      
    • 24. 两两交换链表中的节点

      # Definition for singly-linked list.
      # class ListNode:
      #     def __init__(self, val=0, next=None):
      #         self.val = val
      #         self.next = next
      class Solution:
          def swapPairs(self, head: Optional[ListNode]) -> Optional[ListNode]:
              if not head:
                  return head
              if not head.next:
                  return head
              rest = head.next.next
              ret = self.swapPairs(rest)
              newHead = head.next
              head.next.next = head
              head.next = ret
              return newHead
      
    • 143. 重排链表

      # Definition for singly-linked list.
      # class ListNode:
      #     def __init__(self, val=0, next=None):
      #         self.val = val
      #         self.next = next
      class Solution:
          def reorderList(self, head: ListNode) -> None:
              # FILO后进先出栈式的返回还没处理的链表的头节点
              def filo(head: ListNode, tail: ListNode) -> ListNode:
                  if not tail:
                      return head
                  ret = filo(head, tail.next)
                  # 递归终止条件一:如果所有链表节点都已经得到处理时
                  if not ret:
                      return None
                  # 递归终止条件二:如果只剩下最后一个或一对节点未处理时
                  if ret == tail or ret.next == tail:
                      tail.next = None
                      return None
                  # 此时ret,tail分别指向还未处理的链表头和链表尾
                  tail.next = ret.next
                  ret.next = tail
                  return tail.next
      
              filo(head, head)
      
    • 92. 反转链表 II

      # Definition for singly-linked list.
      # class ListNode:
      #     def __init__(self, val=0, next=None):
      #         self.val = val
      #         self.next = next
      class Solution:
          def reverseBetween(self, head: Optional[ListNode], left: int, right: int) -> Optional[ListNode]:
              # 如果链表中从第一个开始是蓝且只有一个蓝(蓝色节点:待反转;白色节点:不需反转)
              if left == right:
                  return head
              # 如果链表中从第一个开始是白,递归一
              if left > 1:
                  newHead = head
                  newHead.next = self.reverseBetween(head.next, left-1, right-1)
                  return newHead
              # 如果链表中第一个开始是蓝且不只有一个蓝,递归二
              else:
                  nxt = head.next
                  newHead = self.reverseBetween(nxt, 1, right-1)
                  nxtnxt = nxt.next
                  nxt.next = head
                  head.next = nxtnxt
                  return newHead
      
  • 深入理解递归 - 3

    • 105. 从前序与中序遍历序列构造二叉树

      # Definition for a binary tree node.
      # class TreeNode:
      #     def __init__(self, val=0, left=None, right=None):
      #         self.val = val
      #         self.left = left
      #         self.right = right
      class Solution:
          def buildTree(self, preorder: List[int], inorder: List[int]) -> Optional[TreeNode]:
              val2idx = {v: i for i, v in enumerate(inorder)}
              def div_n_conq(inL: int, inR: int, preL: int, preR: int) -> Optional[TreeNode]:
                  if inL > inR:
                      return None
                  rootVal = preorder[preL]
                  idx = val2idx[rootVal]
                  leftCnt = idx - inL
                  leftRet = div_n_conq(inL, idx-1, preL+1, preL+1+leftCnt-1)
                  rightRet = div_n_conq(idx+1, inR, preL+leftCnt+1, preR)
                  return TreeNode(rootVal, leftRet, rightRet)
      
              n = len(preorder)
              return div_n_conq(0, n-1, 0, n-1)
      
    • 106. 从中序与后序遍历序列构造二叉树

      # Definition for a binary tree node.
      # class TreeNode:
      #     def __init__(self, val=0, left=None, right=None):
      #         self.val = val
      #         self.left = left
      #         self.right = right
      class Solution:
          def buildTree(self, inorder: List[int], postorder: List[int]) -> Optional[TreeNode]:
              val2idx = {v: i for i, v in enumerate(inorder)}
              def div_n_conq(inL: int, inR: int, postL: int, postR: int) -> Optional[TreeNode]:
                  if inL > inR:
                      return None
                  rootVal = postorder[postR]
                  idx = val2idx[rootVal]
                  leftCnt = idx - inL
                  leftRet = div_n_conq(inL, idx-1, postL, postL+leftCnt-1)
                  rightRet = div_n_conq(idx+1, inR, postL+leftCnt, postR-1)
                  return TreeNode(rootVal, leftRet, rightRet)
      
              n = len(inorder)
              return div_n_conq(0, n-1, 0, n-1)
      
    • 108. 将有序数组转换为二叉搜索树

      # Definition for a binary tree node.
      # class TreeNode:
      #     def __init__(self, val=0, left=None, right=None):
      #         self.val = val
      #         self.left = left
      #         self.right = right
      class Solution:
          def sortedArrayToBST(self, nums: List[int]) -> Optional[TreeNode]:
              def div_n_conq(left: int, right: int) -> Optional[TreeNode]:
                  if left > right:
                      return None
                  mid = (left + right)//2
                  rootVal = nums[mid]
                  leftRet = div_n_conq(left, mid-1)
                  rightRet = div_n_conq(mid+1, right)
                  return TreeNode(rootVal, leftRet, rightRet)
      
              return div_n_conq(0, len(nums)-1)
      
    • 98. 验证二叉搜索树

      # Definition for a binary tree node.
      # class TreeNode:
      #     def __init__(self, val=0, left=None, right=None):
      #         self.val = val
      #         self.left = left
      #         self.right = right
      class Solution:
          def isValidBST(self, root: Optional[TreeNode]) -> bool:
              ans = []
              def dfs(node: Optional[TreeNode]):
                  if not node:
                      return
                  dfs(node.left)
                  ans.append(node.val)
                  dfs(node.right)
      
              dfs(root)
              for i in range(len(ans)-1):
                  if ans[i] >= ans[i+1]:
                      return False
              return True
      
    • 104. 二叉树的最大深度

      # Definition for a binary tree node.
      # class TreeNode:
      #     def __init__(self, val=0, left=None, right=None):
      #         self.val = val
      #         self.left = left
      #         self.right = right
      class Solution:
          def maxDepth(self, root: Optional[TreeNode]) -> int:
              # 可以提前确定的情况,先定下来,算一种剪枝
              if not root:
                  return 0
              leftRet = self.maxDepth(root.left)
              rightRet = self.maxDepth(root.right)
              return max(leftRet, rightRet) + 1
      
    • 110. 平衡二叉树

      # Definition for a binary tree node.
      # class TreeNode:
      #     def __init__(self, val=0, left=None, right=None):
      #         self.val = val
      #         self.left = left
      #         self.right = right
      class Solution:
          def isBalanced(self, root: Optional[TreeNode]) -> bool:
              # 若以node为根的子树是平衡树,返回其高度,否则返回-1
              def dfs(node: Optional[TreeNode]) -> int:
                  if not node:
                      return 0
                  leftRet = dfs(node.left)
                  rightRet = dfs(node.right)
                  if leftRet == -1 or rightRet == -1 or abs(leftRet - rightRet) > 1:
                      return -1
                  return max(leftRet, rightRet) + 1
      
              return dfs(root) >= 0
      
    • 124. 二叉树中的最大路径和

      # Definition for a binary tree node.
      # class TreeNode:
      #     def __init__(self, val=0, left=None, right=None):
      #         self.val = val
      #         self.left = left
      #         self.right = right
      class Solution:
          def maxPathSum(self, root: Optional[TreeNode]) -> int:
              ans = -inf
              # dfs返回必定经过node且node为端点的最大路径和
              def dfs(node: Optional[TreeNode]) -> int:
                  if not node:
                      return 0
                  leftRet = max(dfs(node.left), 0)
                  rightRet = max(dfs(node.right), 0)
                  nonlocal ans
                  ans = max(ans, leftRet + rightRet + node.val)
                  return max(leftRet, rightRet) + node.val
      
              dfs(root)
              return ans
      
    • 199. 二叉树的右视图

      # Definition for a binary tree node.
      # class TreeNode:
      #     def __init__(self, val=0, left=None, right=None):
      #         self.val = val
      #         self.left = left
      #         self.right = right
      class Solution:
          def rightSideView(self, root: Optional[TreeNode]) -> List[int]:
              depth1st = {}
              ans = []
              # 遍历顺序:根 - 右 - 左
              def dfs(node: Optional[TreeNode], depth: int):
                  if not node:
                      return
                  if depth not in depth1st:
                      ans.append(node.val)
                      depth1st[depth] = node.val
                  dfs(node.right, depth + 1)
                  dfs(node.left, depth + 1)
      
              dfs(root, 0)
              return ans
      
  • 总结

    • 10. 正则表达式匹配

      class Solution:
          @cache
          def isMatch(self, s: str, p: str) -> bool:
              if not s and not p:
                  return True
              if not p:
                  return False
              if not s:
                  return self.isMatch(s, p[2:]) if len(p)>=2 and p[1]=='*' else False
              if p[0] == '.' or p[0] == s[0]:
                  # 递归的条件就是只要有在前进就好,这样保证不会出现死循环
                  return self.isMatch(s[1:], p) or self.isMatch(s, p[2:]) if len(p)>=2 and p[1]=='*' else self.isMatch(s[1:], p[1:])
              else:
                  return self.isMatch(s, p[2:]) if len(p)>=2 and p[1]=='*' else False
      
    • 44. 通配符匹配

      class Solution:
          @cache
          def isMatch(self, s: str, p: str) -> bool:
              if not s and not p:
                  return True
              if not p:
                  return False
              if not s:
                  return self.isMatch(s, p[1:]) if p[0] == '*' else False
              if p[0] == '?' or p[0] == s[0]:
                  # 递归的条件就是只要有在前进就好,这样保证不会出现死循环
                  return self.isMatch(s[1:], p[1:])
              else:
                  return self.isMatch(s, p[1:]) or self.isMatch(s[1:], p) if p[0]=='*' else False
      
    • 剑指 Offer II 096. 字符串交织

      class Solution:
          @cache
          def isInterleave(self, s1: str, s2: str, s3: str) -> bool:
              if not s1:
                  return s2 == s3
              if not s2:
                  return s1 == s3
              if len(s1) + len(s2) != len(s3):
                  return False
              if s1[0] == s3[0]:
                  if self.isInterleave(s1[1:], s2[:], s3[1:]):
                      return True
              if s2[0] == s3[0]:
                  return self.isInterleave(s1[:], s2[1:], s3[1:])
              return False
      
    • 38. 外观数列

      class Solution:
          def countAndSay(self, n: int) -> str:
              if n == 1:
                  return "1"
              ret = self.countAndSay(n-1)
              ans = []
              i, j = 0, 0
              while j < len(ret):
                  if ret[j] == ret[i]:
                      j += 1
                  else:
                      ans.extend([str(j-1-i+1), ret[i]])
                      i = j
              ans.extend([str(j-1-i+1), ret[i]])
              return "".join(ans)