代码随想录算法训练营第二十二天 |回溯算法part01
理论基础
回溯法的效率
回溯法的性能如何呢,这里要和大家说清楚了,虽然回溯法很难,很不好理解,但是回溯法并不是什么高效的算法。
因为回溯的本质是穷举,穷举所有可能,然后选出我们想要的答案,如果想让回溯法高效一些,可以加一些剪枝的操作,但也改不了回溯法就是穷举的本质。
那么既然回溯法并不高效为什么还要用它呢?
因为没得选,一些问题能暴力搜出来就不错了,撑死了再剪枝一下,还没有更高效的解法。
此时大家应该好奇了,都什么问题,这么牛逼,只能暴力搜索。
回溯法解决的问题
回溯法,一般可以解决如下几种问题:
-
组合问题:N个数里面按一定规则找出k个数的集合
- 在(1,2,3,4)中找个数为2的集合
-
切割问题:一个字符串按一定规则有几种切割方式
- 给一个字符串,如何切割才能保证字串是回文子串,问有多少种切割方式
-
子集问题:一个N个数的集合里有多少符合条件的子集
- (1,2,3,4)中1 2 3 4 12 13 14 23 24 34 123 124 234把子集都列出来就是子集问题
-
排列问题:N个数按一定规则全排列,有几种排列方式
-
棋盘问题:N皇后,解数独等等
组合是不强调元素顺序的,排列是强调元素顺序。
例如:{1, 2} 和 {2, 1} 在组合上,就是一个集合,因为不强调顺序,而要是排列的话,{1, 2} 和 {2, 1} 就是两个集合了。
记住组合无序,排列有序,就可以了。
如何理解回溯法
回溯是一个递归的过程,那么就一定是有终止的。
回溯法解决的问题都可以抽象为树形结构,是的,我指的是所有回溯法的问题都可以抽象为树形结构!
因为回溯法解决的都是在集合中递归查找子集,集合的大小就构成了树的宽度,递归的深度就构成了树的深度。
递归就要有终止条件,所以必然是一棵高度有限的树(N叉树)。
这块可能初学者还不太理解,后面的回溯算法解决的所有题目中,我都会强调这一点并画图举相应的例子,现在有一个印象就行。
回溯法的模板
只有子集问题是在每一个节点收集结果,别的都是在叶子节点收集结果。
def backtracking(参数){
# 终止条件
if 终止条件:
收集结果(比如说组合中的12,13)
return
# 进入单层搜索逻辑
for 集合的元素集:
处理节点(比如说12,13是怎么来的)
# 递归函数
backtracking()
# 回溯操作 撤销处理节点的情况 (1 -> 12 -> 1 -> 13 -> 1 -> 14.....)
return
}
77 组合
需要startIndex来记录下一层递归,搜索的起始位置。
for循环每次从startIndex开始遍历,然后用path保存取到的节点i。
def combine(self, n: int, k: int) -> List[List[int]]:
result = []
self.backtracking(n,k,1,[],result)
return result
def backtracking(self,n,k,startIndex,path,result):
if len(path) == k:
result.append(path[:])
return
for i in range(startIndex,n+1):
path.append(i)
self.backtracking(n,k,i+1,path,result)
path.pop()
注意:为什么不是result.append(path)而是result.append(path[:])
path[:] 用于确保 result 中存储的是 path 的独立副本,而不是引用。
result.append(path) # 直接append引用 相当于result = path path改变的话,result会发生对应的改变
result.append(path[:]) # append的是path的副本
剪枝操作
接下来看一下优化过程如下:
-
已经选择的元素个数:len(path)
-
还需要的元素个数为: k - len(path)
-
在集合n中至多要从该起始位置 : n - ( k - len(path)) + 2,开始遍历
比如说:如果n=4,k=3,4-3+2 = 3,因为是右闭,所以至多从2开始
def combine(self, n: int, k: int) -> List[List[int]]:
result = [] # 存放结果集
self.backtracking(n, k, 1, [], result)
return result
def backtracking(self, n, k, startIndex, path, result):
if len(path) == k:
result.append(path[:])
return
for i in range(startIndex, n - (k - len(path)) + 2): # 优化的地方
path.append(i) # 处理节点
self.backtracking(n, k, i + 1, path, result)
path.pop() # 回溯,撤销处理的节点
216 组合总和III
普通版本:
def combinationSum3(self, k: int, n: int) -> List[List[int]]:
path = []
result = []
self.back(k,n,path,result,1,0)
return result
def back(self,k,n,path,result,start,sum):
if sum == n and len(path) == k:
result.append(path[:])
return
for i in range(start,10):
path.append(i)
sum += i
self.back(k,n,path,result,i+1,sum)
path.pop()
sum -= i
sum += i
self.back(k,n,path,result,i+1,sum)
sum -= i
# 可以简化成
self.back(k,n,path,result,i+1,sum+i)
剪枝操作:
如果没有
if sum > n: # 如果和已经比n大了,那么就直接return
return
虽然有剪枝操作来限制循环的次数,但是没有对 sum > n 的情况进行处理。 这意味着即使 sum 超过了 n,程序仍然会继续进行递归,导致递归深度过大。
会报错 RecursionError: maximum recursion depth exceeded in comparison
def combinationSum3(self, k: int, n: int) -> List[List[int]]:
path = []
result = []
self.back(k,n,path,result,1,0)
return result
def back(self,k,n,path,result,start,sum):
if sum > n: # 如果和已经比n大了,那么就直接return
return
if sum == n and len(path) == k:
result.append(path[:])
return
for i in range(start,10):
path.append(i)
sum += i
self.back(k,n,path,result,i+1,sum)
path.pop()
sum -= i
17 电话号码的字母组合
思路:所有的回溯的题都可以抽象到一颗树型结构上。对于23,首先我们需要选择2,然后2对应3个字母,我们选择第一个字母,然后选择3,3也对应3个字母,我们选择第一个字母,第一个字母完了,选择第二个,以此类推。
所以我们就需要一个Index来控制我们选择的第几个数字,还需要一个letter来控制我们选择的第几个字母,需要把第几个字母加进来。
如下代码所示,digit控制对应的数字,letter控制对应的字母。
def backtracking(digits,index,[]):
if len(digits) == index:
result.append(s)
digit = digits[index]
letter = letterMap[digit]
for i in range(len(letter)):
...........
def letterCombinations(self, digits: str) -> List[str]:
digits = list(digits)
if not digits :
return []
letterMap = [
"", # 0
"", # 1
"abc", # 2
"def", # 3
"ghi", # 4
"jkl", # 5
"mno", # 6
"pqrs", # 7
"tuv", # 8
"wxyz" # 9
]
def backtracking(digits,index,s):
if index == len(digits):
result.append(''.join(s[:]))
return
digit = int(digits[index])
letter = letterMap[digit]
for i in range(len(letter)):
s.append(letter[i])
backtracking(digits,index + 1,s)
s.pop()
path = []
result = []
backtracking(digits,0,[])
return result