代码随想录算法训练营第二十二天 |回溯算法part01

134 阅读6分钟

代码随想录算法训练营第二十二天 |回溯算法part01

理论基础

回溯法的效率

回溯法的性能如何呢,这里要和大家说清楚了,虽然回溯法很难,很不好理解,但是回溯法并不是什么高效的算法

因为回溯的本质是穷举,穷举所有可能,然后选出我们想要的答案,如果想让回溯法高效一些,可以加一些剪枝的操作,但也改不了回溯法就是穷举的本质。

那么既然回溯法并不高效为什么还要用它呢?

因为没得选,一些问题能暴力搜出来就不错了,撑死了再剪枝一下,还没有更高效的解法。

此时大家应该好奇了,都什么问题,这么牛逼,只能暴力搜索。

回溯法解决的问题

回溯法,一般可以解决如下几种问题:

  • 组合问题:N个数里面按一定规则找出k个数的集合

    • 在(1,2,3,4)中找个数为2的集合
  • 切割问题:一个字符串按一定规则有几种切割方式

    • 给一个字符串,如何切割才能保证字串是回文子串,问有多少种切割方式
  • 子集问题:一个N个数的集合里有多少符合条件的子集

    • (1,2,3,4)中1 2 3 4 12 13 14 23 24 34 123 124 234把子集都列出来就是子集问题
  • 排列问题:N个数按一定规则全排列,有几种排列方式

  • 棋盘问题:N皇后,解数独等等

组合是不强调元素顺序的,排列是强调元素顺序

例如:{1, 2} 和 {2, 1} 在组合上,就是一个集合,因为不强调顺序,而要是排列的话,{1, 2} 和 {2, 1} 就是两个集合了。

记住组合无序,排列有序,就可以了。

如何理解回溯法

回溯是一个递归的过程,那么就一定是有终止的。

回溯法解决的问题都可以抽象为树形结构,是的,我指的是所有回溯法的问题都可以抽象为树形结构!

因为回溯法解决的都是在集合中递归查找子集,集合的大小就构成了树的宽度,递归的深度就构成了树的深度

回溯算法理论基础

递归就要有终止条件,所以必然是一棵高度有限的树(N叉树)。

这块可能初学者还不太理解,后面的回溯算法解决的所有题目中,我都会强调这一点并画图举相应的例子,现在有一个印象就行。

回溯法的模板

只有子集问题是在每一个节点收集结果,别的都是在叶子节点收集结果。

def backtracking(参数){
    # 终止条件
    if 终止条件:
        收集结果(比如说组合中的12,13)
        return
    # 进入单层搜索逻辑
    for 集合的元素集:
        处理节点(比如说1213是怎么来的)
        # 递归函数
        backtracking()
        # 回溯操作 撤销处理节点的情况 (1 -> 12 -> 1 -> 13 -> 1 -> 14.....)
    return
}

77 组合

image.png

77.组合3

需要startIndex来记录下一层递归,搜索的起始位置。

for循环每次从startIndex开始遍历,然后用path保存取到的节点i。

def combine(self, n: int, k: int) -> List[List[int]]:
    result = []
    self.backtracking(n,k,1,[],result)
    return result
def backtracking(self,n,k,startIndex,path,result):
    if len(path) == k:
        result.append(path[:])
        return
    for i in range(startIndex,n+1):
        path.append(i)
        self.backtracking(n,k,i+1,path,result)
        path.pop()

注意:为什么不是result.append(path)而是result.append(path[:])

path[:] 用于确保 result 中存储的是 path 的独立副本,而不是引用。

result.append(path)  # 直接append引用 相当于result = path path改变的话,result会发生对应的改变
result.append(path[:])  # append的是path的副本

剪枝操作

接下来看一下优化过程如下:

  1. 已经选择的元素个数:len(path)

  2. 还需要的元素个数为: k - len(path)

  3. 在集合n中至多要从该起始位置 : n - ( k - len(path)) + 2,开始遍历

    比如说:如果n=4,k=3,4-3+2 = 3,因为是右闭,所以至多从2开始

def combine(self, n: int, k: int) -> List[List[int]]:
    result = []  # 存放结果集
    self.backtracking(n, k, 1, [], result)
    return result
def backtracking(self, n, k, startIndex, path, result):
    if len(path) == k:
        result.append(path[:])
        return
    for i in range(startIndex, n - (k - len(path)) + 2):  # 优化的地方
        path.append(i)  # 处理节点
        self.backtracking(n, k, i + 1, path, result)
        path.pop()  # 回溯,撤销处理的节点

216 组合总和III

image.png

普通版本:

def combinationSum3(self, k: int, n: int) -> List[List[int]]:
    path = []
    result = []
    self.back(k,n,path,result,1,0)
    return result
def back(self,k,n,path,result,start,sum):
    if sum == n and len(path) == k:
        result.append(path[:])
        return
    for i in range(start,10):
        path.append(i)
        sum += i
        self.back(k,n,path,result,i+1,sum)
        path.pop()
        sum -= i
sum += i
self.back(k,n,path,result,i+1,sum)
sum -= i
# 可以简化成
self.back(k,n,path,result,i+1,sum+i)

剪枝操作:

如果没有

if sum > n: # 如果和已经比n大了,那么就直接return
    return

虽然有剪枝操作来限制循环的次数,但是没有对 sum > n 的情况进行处理。 这意味着即使 sum 超过了 n,程序仍然会继续进行递归,导致递归深度过大。

会报错 RecursionError: maximum recursion depth exceeded in comparison

    def combinationSum3(self, k: int, n: int) -> List[List[int]]:
        path = []
        result = []
        self.back(k,n,path,result,1,0)
        return result
    def back(self,k,n,path,result,start,sum):
        if sum > n: # 如果和已经比n大了,那么就直接return
            return
        if sum == n and len(path) == k:
            result.append(path[:])
            return
        for i in range(start,10):
            path.append(i)
            sum += i
            self.back(k,n,path,result,i+1,sum)
            path.pop()
            sum -= i

17 电话号码的字母组合

image.png

思路:所有的回溯的题都可以抽象到一颗树型结构上。对于23,首先我们需要选择2,然后2对应3个字母,我们选择第一个字母,然后选择3,3也对应3个字母,我们选择第一个字母,第一个字母完了,选择第二个,以此类推。

所以我们就需要一个Index来控制我们选择的第几个数字,还需要一个letter来控制我们选择的第几个字母,需要把第几个字母加进来。

如下代码所示,digit控制对应的数字,letter控制对应的字母。

    def backtracking(digits,index,[]):
        if len(digits) == index:
            result.append(s)
        digit = digits[index]
        letter = letterMap[digit]
        for i in range(len(letter)):
            ...........
    def letterCombinations(self, digits: str) -> List[str]:
        digits = list(digits)
        if not digits :
            return []
        letterMap = [
                    "",     # 0
                    "",     # 1
                    "abc",  # 2
                    "def",  # 3
                    "ghi",  # 4
                    "jkl",  # 5
                    "mno",  # 6
                    "pqrs", # 7
                    "tuv",  # 8
                    "wxyz"  # 9
                ]
        def backtracking(digits,index,s):
            if index == len(digits):
                result.append(''.join(s[:]))
                return
            digit = int(digits[index])
            letter = letterMap[digit]
            for i in range(len(letter)):
                s.append(letter[i])
                backtracking(digits,index + 1,s)
                s.pop()
    ​
        path = []
        result = []
        backtracking(digits,0,[])
        return result
    ​