初识回溯,感觉特别有意思,这里简单介绍下数组的子集排列组合三种种情况。自己写的代码,可能不够规范,供初学理解:)
1. 求子集
这类题上来即使没什么思路,也应该先把path和res给初始化,自定义backtrack函数如果不知道要传什么参数没关系,先把跟踪位置的idx给写上,path和res这些如果作为全局变量可以不用写到回溯函数里。甚至这两题的参数可以只有idx就足够了,一般都是跟踪索引。
该题递归开始的时候先把当前元素加入res。
def subsets(nums):
def backtrack(path, idx):
res.append(list(path)) # 先把当前的加入 一定要先list再加入!不然加的是引用啊
for i in range(idx,len(nums)):
path.append(nums[i])
backtrack(path, i+1)
path.pop()
path=[]
res = []
backtrack(path,0) #该函数的作用是:更新res 无返回
return res #返回更新完的res
# 测试
subsets([1,2,3])
>
[[], [1], [1, 2], [1, 2, 3], [1, 3], [2], [2, 3], [3]]
2. 求长度为K的所有组合(不重复)
这两个例子代码几乎长得一样,只是这个限制了回溯树的深度。当前path(也就是回溯树的某个分支)的深度达到K就返回,停止继续往下递归。返回之前更新res结果列表。
def combine(n, k):
def backtrack(path, idx):
if len(path)==k:
res.append(list(path)) #符合题意的解,return的意思是不能再往深了走
return
for i in range(idx,len(nums)):
path.append(nums[i])
backtrack(path, i+1)
path.pop()
nums = [i for i in range(1,n+1)]
path = [] #每个长度都是k
res = [] #添加path的时候一定要先列表化
backtrack(path, 0)
return res
# 测试
combine(4,2)
>
[[1, 2], [1, 3], [1, 4], [2, 3], [2, 4], [3, 4]]
3. 求全排列(不重复)
该类问题不需要跟踪idx, 而且不是单向的遍历(前两种都是从前往后遍历),所以用continue来排除已存在于path里的元素!
def permute(nums):
def backtrack(path):
if len(path) == n:
res.append(list(path))
return
for i in range(n):
if nums[i] in path: #当前加入过path的元素,都要排除,在剩下的里面选
continue
path.append(nums[i])
backtrack(path) # 这类不需要跟踪idx!
path.pop()
n = len(nums)
path = []
res = []
backtrack(path)
return res
4.求全排列(含有重复元素)
重点就是
not used[i-1]时跳过,才能保证按原来的顺序输出。122' => [1,2,2']
def permute(nums):
def backtrack(path):
if len(path) == n:
res.append(list(path))
return
for i in range(n):
if used[i] or (i>0 and nums[i]== nums[i-1] and not used[i-1]): #一定是!used[i-1]才能保证按原序
continue
path.append(nums[i])
used[i] = True
backtrack(path)
path.pop()
used[i] = False
n = len(nums)
path,res = [],[]
nums = sorted(nums)
used = [False for i in range(n)] #用过标记为True
backtrack(path)
return res