回溯算法总结

185 阅读17分钟

image.png

我们如同迷宫中的探索者,在前进的道路上可能会遇到困难。 回溯的力量让我们能够重新开始,不断尝试,最终找到通往光明的出口。

理论基础

什么是回溯法

回溯法也可以叫做回溯搜索法,它是一种搜索的方式。

回溯是递归的副产品,只要有递归就会有回溯。

所以以下讲解中,回溯函数也就是递归函数,指的都是一个函数

回溯法的效率

回溯法的性能如何呢,这里要和大家说清楚了,虽然回溯法很难,很不好理解,但是回溯法并不是什么高效的算法

因为回溯的本质是穷举,穷举所有可能,然后选出我们想要的答案,如果想让回溯法高效一些,可以加一些剪枝的操作,但也改不了回溯法就是穷举的本质。

那么既然回溯法并不高效为什么还要用它呢?

因为没得选,一些问题能暴力搜出来就不错了,撑死了再剪枝一下,还没有更高效的解法。

此时大家应该好奇了,都什么问题,这么牛逼,只能暴力搜索。

回溯法解决的问题

回溯法,一般可以解决如下几种问题:

  • 组合问题:N个数里面按一定规则找出k个数的集合
  • 切割问题:一个字符串按一定规则有几种切割方式
  • 子集问题:一个N个数的集合里有多少符合条件的子集
  • 排列问题:N个数按一定规则全排列,有几种排列方式
  • 棋盘问题:N皇后,解数独等等

相信大家看着这些之后会发现,每个问题,都不简单!

另外,会有一些同学可能分不清什么是组合,什么是排列?

组合是不强调元素顺序的,排列是强调元素顺序

例如:{1, 2} 和 {2, 1} 在组合上,就是一个集合,因为不强调顺序,而要是排列的话,{1, 2} 和 {2, 1} 就是两个集合了。

记住组合无序,排列有序,就可以了。

如何理解回溯法

回溯法解决的问题都可以抽象为树形结构,是的,我指的是所有回溯法的问题都可以抽象为树形结构!

因为回溯法解决的都是在集合中递归查找子集,集合的大小就构成了树的宽度,递归的深度就构成了树的深度

递归就要有终止条件,所以必然是一棵高度有限的树(N叉树)。

回溯算法模板框架如下:

void backtracking(参数) {
    if (终止条件) {
        存放结果;
        return;
    }

    for (选择:本层集合中元素(树中节点孩子的数量就是集合的大小)) {
        处理节点;
        backtracking(路径,选择列表); // 递归
        回溯,撤销处理结果
    }
}

组合

第77题. 组合

力扣题目链接

给定两个整数 n 和 k,返回 1 ... n 中所有可能的 k 个数的组合。

示例1

输入: n = 4, k = 2 输出: [ [2,4], [3,4], [2,3], [1,2], [1,3], [1,4], ]

示例 2:

输入: n = 1, k = 1 输出: [[1]]

image.png

class Solution:
    def combine(self, n: int, k: int) -> List[List[int]]:
        ans=[]
        path=[]
        self.backtracking(n,k,1,path,ans)
        return ans

    def backtracking(self,n,k,startIndex,path,ans):
        if len(path)==k:
            ans.append(path[:])
            return
            
        #for i in range(startIndex,n+1):
        #剪枝优化
        for i in range(startIndex,n-(k-len(path))+2):
            path.append(i)
            self.backtracking(n,k,i+1,path,ans)
            path.pop()

优化过程如下:

  1. 已经选择的元素个数:path.size();
  2. 所需需要的元素个数为: k - path.size();
  3. 列表中剩余元素(n-i) >= 所需需要的元素个数(k - path.size())
  4. 在集合n中至多要从该起始位置 : i <= n - (k - path.size()) + 1,开始遍历

216.组合总和III

力扣题目链接(opens new window)

找出所有相加之和为 n 的 k 个数的组合。组合中只允许含有 1 - 9 的正整数,并且每种组合中不存在重复的数字。

说明:

  • 所有数字都是正整数。
  • 解集不能包含重复的组合。

示例 1: 输入: k = 3, n = 7 输出: [[1,2,4]]

示例 2: 输入: k = 3, n = 9 输出: [[1,2,6], [1,3,5], [2,3,4]]

image.png

class Solution:
    def combinationSum3(self, k: int, n: int) -> List[List[int]]:
        ans=[]
        path=[]
        sum=0

        self.backtracking(k,n,1,path,ans,sum)

        return ans

    def backtracking(self,k,n,startIndex,path,ans,sum):
        if sum>n:
            return

        if len(path)==k:
            if sum==n:
                ans.append(path[:])
            return

        for i in range(startIndex,9-(k-len(path))+2):
            sum+=i
            path.append(i)

            self.backtracking(k,n,i+1,path,ans,sum)
            sum-=i
            path.pop()


'''
剪枝:
剩余需要选取的数的个数:k - len(path)
从i到n的数的个数:n - i + 1
为了保证能够选取到足够的数,需要满足:n - i + 1 >= k - len(path)
通过上述不等式,可以推导出:i <= n - (k - len(path)) + 1
进一步简化为:i <= n - k + len(path) + 1'''

17.电话号码的字母组合

力扣题目链接(opens new window)

给定一个仅包含数字 2-9 的字符串,返回所有它能表示的字母组合。

给出数字到字母的映射如下(与电话按键相同)。注意 1 不对应任何字母。

17.电话号码的字母组合

示例:

  • 输入:"23"
  • 输出:["ad", "ae", "af", "bd", "be", "bf", "cd", "ce", "cf"].

说明:尽管上面的答案是按字典序排列的,但是你可以任意选择答案输出的顺序。

例如:输入:"23",抽象为树形结构,如图所示:

17. 电话号码的字母组合

class Solution:
    def letterCombinations(self, digits: str) -> List[str]:
        digitMap=[
            '',
            '',
            'abc',
            'def',
            'ghi',
            'jkl',
            'mno',
            'pqrs',
            'tuv',
            'wxyz'
        ]
        ans=[]
        path=[]
        if len(digits)==0:
            return ans

        self.backtracking(digits,path,ans,0,digitMap)
        return ans

    def backtracking(self,digits,path,ans,index,digitMap):
        if len(path)==len(digits):
            ans.append(''.join(path))
            return

        digit=int(digits[index])
        letters=digitMap[digit]

        for letter in letters:
            path.append(letter)
            self.backtracking(digits,path,ans,index+1,digitMap)
            path.pop()

39. 组合总和

力扣题目链接(opens new window)

给定一个无重复元素的数组 candidates 和一个目标数 target ,找出 candidates 中所有可以使数字和为 target 的组合。

candidates 中的数字可以无限制重复被选取。

说明:

  • 所有数字(包括 target)都是正整数。
  • 解集不能包含重复的组合。

示例 1:

  • 输入:candidates = [2,3,6,7], target = 7,
  • 所求解集为: [ [7], [2,2,3] ]

示例 2:

  • 输入:candidates = [2,3,5], target = 8,
  • 所求解集为: [ [2,2,2,2], [2,3,3], [3,5] ]

image.png

class Solution:
    def combinationSum(self, candidates: List[int], target: int) -> List[List[int]]:
        path=[]
        ans=[]
        sum=0
        candidates.sort()
        self.backtracking(candidates,target,0,sum,path,ans)
        return ans

    def backtracking(self,candidates,target,startIndex,sum,path,ans):
        if sum==target:
            ans.append(path[:])

        for i in range(startIndex,len(candidates)):
            if sum+candidates[i]>target:
                break

            sum+=candidates[i]
            path.append(candidates[i])
            self.backtracking(candidates,target,i,sum,path,ans)
#不用i+1了,表示可以重复读取当前的数

            sum-=candidates[i]
            path.pop()

40.组合总和II

力扣题目链接

image.png

class Solution:
    def combinationSum2(self, candidates: List[int], target: int) -> List[List[int]]:
        sum=0
        path=[]
        ans=[]
        #排序
        candidates.sort()

        self.backtracking(candidates,target,0,sum,path,ans)

        return ans


    def backtracking(self,candidates,target,startIndex,sum,path,ans):
        if sum==target:
            ans.append(path[:])

        for i in range(startIndex,len(candidates)):
            if i>startIndex and candidates[i]==candidates[i-1]:
                continue

            if sum+candidates[i]>target:
                break

            sum+=candidates[i]

            path.append(candidates[i])

            self.backtracking(candidates,target,i+1,sum,path,ans)

            sum-=candidates[i]

            path.pop()


#在回溯算法中,我们逐个选择数字,并递归地生成组合。为了避免重复组合,我们需要确保:
#对于同一个层级的递归,同一个数字只被选择一次。
#例如,在递归的某一层,我们选择了一个数字 1,那么在这一层中,
#后续的 1 都不应该被选择,否则会导致重复的组合。

分割

131.分割回文串

力扣题目链接(opens new window)

给定一个字符串 s,将 s 分割成一些子串,使每个子串都是回文串。

返回 s 所有可能的分割方案。

示例: 输入: "aab" 输出: [ ["aa","b"], ["a","a","b"] ]

image.png

#abcdef
#切割问题:切割一个a之后,在bcdef中再去切割第二段,切割b之后在cdef中再切割第三段.....。
class Solution:
    def partition(self, s: str) -> List[List[str]]:
        path=[]
        ans=[]
        self.backtracking(s,0,path,ans)
        return ans

    def backtracking(self,s,startIndex,path,ans):
        if startIndex==len(s):
            ans.append(path[:])
            return

        for i in range(startIndex,len(s)):
            if self.is_palindrome(s,startIndex,i):
                path.append(s[startIndex:i+1])
                self.backtracking(s,i+1,path,ans)

                path.pop()
    

    def is_palindrome(self,s,start,end):#双指针
        i=start
        j=end

        while i<=j:
            if s[i]!=s[j]:
                return False

            i+=1
            j-=1

        return True

93.复原IP地址

力扣题目链接(opens new window)

给定一个只包含数字的字符串,复原它并返回所有可能的 IP 地址格式。

有效的 IP 地址 正好由四个整数(每个整数位于 0 到 255 之间组成,且不能含有前导 0),整数之间用 '.' 分隔。

例如:"0.1.2.201" 和 "192.168.1.1" 是 有效的 IP 地址,但是 "0.011.255.245"、"192.168.1.312" 和 "192.168@1.1" 是 无效的 IP 地址。

示例 1:

  • 输入:s = "25525511135"
  • 输出:["255.255.11.135","255.255.111.35"]

示例 2:

  • 输入:s = "0000"
  • 输出:["0.0.0.0"]

示例 3:

  • 输入:s = "1111"
  • 输出:["1.1.1.1"]

示例 4:

  • 输入:s = "010010"
  • 输出:["0.10.0.10","0.100.1.0"]

示例 5:

  • 输入:s = "101023"
  • 输出:["1.0.10.23","1.0.102.3","10.1.0.23","10.10.2.3","101.0.2.3"]

提示:

  • 0 <= s.length <= 3000
  • s 仅由数字组成

image.png

class Solution:
    def restoreIpAddresses(self, s: str) -> List[str]:
        path=[]
        ans=[]
        self.backtracking(s,0,path,ans)
        return ans


    def backtracking(self,s,startIndex,path,ans):
        if len(path)==4 and startIndex==len(s):
            ans.append('.'.join(path[:]))
            return

        if len(path)>4:
            return

        for i in range(startIndex,min(startIndex+3,len(s))):
            if self.isValid(s,startIndex,i):
                subString=s[startIndex:i+1]
                path.append(subString)
                self.backtracking(s,i+1,path,ans)
                path.pop()
    

    def isValid(self,s,start,end):
        if start>end:
            return False

        if s[start]=='0' and start!=end:
            return False

        num=int(s[start:end+1])

        return 0 <= num <= 255

子集

78.子集

力扣题目链接(opens new window)

给定一组不含重复元素的整数数组 nums,返回该数组所有可能的子集(幂集)。

说明:解集不能包含重复的子集。

示例: 输入: nums = [1,2,3] 输出: [ [3],   [1],   [2],   [1,2,3],   [1,3],   [2,3],   [1,2],   [] ]

image.png

class Solution:
    def subsets(self, nums: List[int]) -> List[List[int]]:
        path=[]
        ans=[]
        self.backtracking(nums,0,path,ans)
        return ans

    def backtracking(self,nums,startIndex,path,ans):
        ans.append(path[:])

        if startIndex>=len(nums):
            return

        for i in range(startIndex,len(nums)):

            path.append(nums[i])
            self.backtracking(nums,i+1,path,ans)

            path.pop()

90.子集II

力扣题目链接(opens new window)

给定一个可能包含重复元素的整数数组 nums,返回该数组所有可能的子集(幂集)。

说明:解集不能包含重复的子集。

示例:

  • 输入: [1,2,2]
  • 输出: [ [2], [1], [1,2,2], [2,2], [1,2], [] ]
class Solution:
    def subsetsWithDup(self, nums: List[int]) -> List[List[int]]:

        path=[]
        ans=[]
        nums.sort()#排序
        self.backtracking(nums,0,path,ans)
        return ans

    def backtracking(self,nums,startIndex,path,ans):
        ans.append(path[:])

        if startIndex>=len(nums):
            return

        for i in range(startIndex,len(nums)):
            if i>startIndex and nums[i]==nums[i-1]:
                continue

            path.append(nums[i])
            self.backtracking(nums,i+1,path,ans)

            path.pop()

491.递增子序列

力扣题目链接(opens new window)

给定一个整型数组, 你的任务是找到所有该数组的递增子序列,递增子序列的长度至少是2。

示例:

  • 输入: [4, 6, 7, 7]
  • 输出: [[4, 6], [4, 7], [4, 6, 7], [4, 6, 7, 7], [6, 7], [6, 7, 7], [7,7], [4,7,7]]

说明:

  • 给定数组的长度不会超过15。
  • 数组中的整数范围是 [-100,100]。
  • 给定数组中可能包含重复数字,相等的数字应该被视为递增的一种情况

image.png

class Solution:
    def findSubsequences(self, nums: List[int]) -> List[List[int]]:
        path=[]
        ans=[]
        self.backtracking(nums,0,path,ans)
        return ans

    def backtracking(self,nums,startIndex,path,ans):
        if len(path)>1:
            ans.append(path[:])
            # 注意这里不要加return,要取树上的节点

        uset=set()

        for i in range(startIndex,len(nums)):

            if (len(path)>0 and nums[i]<path[-1]) or nums[i] in uset:
                continue
            uset.add(nums[i])# 记录这个元素在本层用过了,本层后面不能再用了
            path.append(nums[i])

            self.backtracking(nums,i+1,path,ans)

            path.pop()


在递归回溯的过程中,`path` 会不断添加和移除数字,以生成所有可能的递增子序列。
然而,如果数组 `nums` 中存在重复的数字,可能会导致生成重复的子序列。
例如,对于数组 `[4, 6, 7, 7]`,如果不使用 `uset`,可能会生成两个相同的子序列 `[4, 6, 7]`,因为数组中有两个 `7`。
通过使用 `uset`,可以记录在当前递归层级中已经使用过的数字,
避免在当前层级中重复选择相同的数字,从而避免生成重复的子序列。

排列

46.全排列

力扣题目链接(opens new window)

给定一个 没有重复 数字的序列,返回其所有可能的全排列。

示例:

  • 输入: [1,2,3]
  • 输出: [ [1,2,3], [1,3,2], [2,1,3], [2,3,1], [3,1,2], [3,2,1]]
class Solution:
    def permute(self, nums: List[int]) -> List[List[int]]:
        path=[]
        ans=[]
        self.backtracking(nums,0,path,ans)
        return ans

    def backtracking(self,nums,startIndex,path,ans):
        if len(nums)==len(path):
            ans.append(path[:])
            return

        for i in range(0,len(nums)):
            if nums[i] in path:
                continue

            path.append(nums[i])
            self.backtracking(nums,i+1,path,ans)
            path.pop()

大家此时可以感受出排列问题的不同:

  • 每层都是从0开始搜索而不是startIndex
  • 需要used数组记录path里都放了哪些元素了

47.全排列 II

力扣题目链接(opens new window)

给定一个可包含重复数字的序列 nums ,按任意顺序 返回所有不重复的全排列。

示例 1:

  • 输入:nums = [1,1,2]
  • 输出: [[1,1,2], [1,2,1], [2,1,1]]

示例 2:

  • 输入:nums = [1,2,3]
  • 输出:[[1,2,3],[1,3,2],[2,1,3],[2,3,1],[3,1,2],[3,2,1]]

提示:

  • 1 <= nums.length <= 8
  • -10 <= nums[i] <= 10
class Solution:
    def permuteUnique(self, nums: List[int]) -> List[List[int]]:

        path=[]
        ans=[]
        nums.sort()
        
        used_list=[0]*len(nums)
        
        self.backtracking(nums,0,path,ans,used_list)
        return ans

    def backtracking(self,nums,startIndex,path,ans,used_list):
        if len(nums)==len(path):
            ans.append(path[:])
            return

        for i in range(0,len(nums)):
            if (i>0 and nums[i]==nums[i-1] and used_list[i-1]==1) or used_list[i]==1:
                continue

            used_list[i]=1
            path.append(nums[i])
            self.backtracking(nums,i+1,path,ans,used_list)
            path.pop()
            used_list[i]=0        

棋盘问题

51. N皇后

力扣题目链接(opens new window)

n 皇后问题 研究的是如何将 n 个皇后放置在 n×n 的棋盘上,并且使皇后彼此之间不能相互攻击。

给你一个整数 n ,返回所有不同的 n 皇后问题 的解决方案。

每一种解法包含一个不同的 n 皇后问题 的棋子放置方案,该方案中 'Q' 和 '.' 分别代表了皇后和空位。

示例 1:

  • 输入:n = 4
  • 输出:[[".Q..","...Q","Q...","..Q."],["..Q.","Q...","...Q",".Q.."]]
  • 解释:如上图所示,4 皇后问题存在两个不同的解法。

示例 2:

  • 输入:n = 1
  • 输出:[["Q"]]

首先来看一下皇后们的约束条件:

  1. 不能同行
  2. 不能同列
  3. 不能同斜线
class Solution:
    def solveNQueens(self, n: int) -> List[List[str]]:
        '''n=4
        chessBoard=[
            '....',
            '....',
            '....',
            '....'
        ]      
        ''' 
        ans=[]
        chessBoard=['.' * n for _ in range(n)]
        self.backtracking(n,0,chessBoard,ans)
        return ans

    def backtracking(self,n,row,chessBoard,ans):
        if row==n:
            ans.append(chessBoard[:])
            return

        for col in range(n):
            if self.isValid(row,col,chessBoard):
                chessBoard[row]=chessBoard[row][:col]+'Q'+chessBoard[row][col+1:]
                self.backtracking(n,row+1,chessBoard,ans)
                chessBoard[row]=chessBoard[row][:col]+'.'+chessBoard[row][col+1:]


    def isValid(self,row,col,chessBoard):
        for i in range(row):#列
            if chessBoard[i][col]=='Q':
                return False

        i=row-1
        j=col-1
        while i>=0 and j>=0:#左上45
            if chessBoard[i][j]=='Q':
                return False

            i-=1
            j-=1

        i=row-1
        j=col+1
        while i>=0 and j<len(chessBoard):#右上45
            if chessBoard[i][j]=='Q':
                return False
            i-=1
            j+=1

        return True

37. 解数独

力扣题目链接(opens new window)

编写一个程序,通过填充空格来解决数独问题。

一个数独的解法需遵循如下规则: 数字 1-9 在每一行只能出现一次。 数字 1-9 在每一列只能出现一次。 数字 1-9 在每一个以粗实线分隔的 3x3 宫内只能出现一次。 空白格用 '.' 表示。

解数独

一个数独。

解数独

答案被标成红色。

提示:

  • 给定的数独序列只包含数字 1-9 和字符 '.' 。
  • 你可以假设给定的数独只有唯一解。
  • 给定数独永远是 9x9 形式的。
class Solution:
    def solveSudoku(self, board: List[List[str]]) -> None:
        """
        Do not return anything, modify board in-place instead.
        """
        row_used = [set() for _ in range(9)]
        col_used = [set() for _ in range(9)]
        box_used = [set() for _ in range(9)]
        for row in range(9):
            for col in range(9):
                num = board[row][col]
                if num == ".":
                    continue
                row_used[row].add(num)
                col_used[col].add(num)
                box_used[(row // 3) * 3 + col // 3].add(num)
        self.backtracking(0, 0, board, row_used, col_used, box_used)

    def backtracking(
        self,
        row: int,
        col: int,
        board: List[List[str]],
        row_used: List[List[int]],
        col_used: List[List[int]],
        box_used: List[List[int]],
    ) -> bool:
        if row == 9:
            return True

        next_row, next_col = (row, col + 1) if col < 8 else (row + 1, 0)
        if board[row][col] != ".":
            return self.backtracking(
                next_row, next_col, board, row_used, col_used, box_used
            )

        for num in map(str, range(1, 10)):
            if (
                num not in row_used[row]
                and num not in col_used[col]
                and num not in box_used[(row // 3) * 3 + col // 3]
            ):
                board[row][col] = num
                row_used[row].add(num)
                col_used[col].add(num)
                box_used[(row // 3) * 3 + col // 3].add(num)
                if self.backtracking(
                    next_row, next_col, board, row_used, col_used, box_used
                ):
                    return True
                board[row][col] = "."
                row_used[row].remove(num)
                col_used[col].remove(num)
                box_used[(row // 3) * 3 + col // 3].remove(num)
        return False

1. 函数 solveSudoku

  • 输入参数:

    • board: 一个9x9的二维列表,表示数独棋盘,其中.表示空格,需要填充的数字。
  • 功能:

    • 初始化行、列和3x3宫格中已使用的数字集合。
    • 调用 backtracking 函数进行回溯搜索,尝试填充数独棋盘。

2. 初始化部分

  • row_used: 一个包含9个集合的列表,每个集合存储对应行中已使用的数字。
  • col_used: 一个包含9个集合的列表,每个集合存储对应列中已使用的数字。
  • box_used: 一个包含9个集合的列表,每个集合存储对应3x3宫格中已使用的数字。
  • 遍历棋盘,将已有的数字加入到对应的行、列和宫格集合中。

3. 函数 backtracking

  • 输入参数:

    • row: 当前行索引。
    • col: 当前列索引。
    • board: 数独棋盘。
    • row_used: 行中已使用的数字集合。
    • col_used: 列中已使用的数字集合。
    • box_used: 3x3宫格中已使用的数字集合。
  • 功能:

    • 使用回溯算法尝试填充数独棋盘。
    • 如果当前行索引达到9,表示数独已成功填充,返回 True
    • 如果当前格子已有数字(不是.),跳过当前格子,递归调用下一个格子。
    • 如果当前格子为空(.),尝试填入1到9的数字,检查是否满足数独规则(不在当前行、列和宫格中出现)。
    • 如果填入的数字满足规则,递归调用下一个格子,如果成功返回 True
    • 如果递归调用失败,回溯:将当前格子恢复为.,并从集合中移除该数字,尝试下一个数字。

在数独问题中,棋盘被划分为9个3x3的宫格,每个宫格内的数字1到9不能重复。box_used是一个列表,用于存储每个3x3宫格中已经使用的数字。box_used[(row // 3) * 3 + col // 3]的目的是确定当前格子所在的3x3宫格的索引。

解释 box_used[(row // 3) * 3 + col // 3]

  • row // 3: 计算当前行所在的3x3宫格的行索引。例如,行0、1、2都属于第0行的宫格,行3、4、5属于第1行的宫格,以此类推。
  • col // 3: 计算当前列所在的3x3宫格的列索引。例如,列0、1、2都属于第0列的宫格,列3、4、5属于第1列的宫格,以此类推。
  • (row // 3) * 3 + col // 3: 将行索引和列索引转换为一个唯一的索引,表示当前格子所在的3x3宫格。这个索引的范围是0到8,对应9个3x3宫格。

举例说明

假设当前格子的行索引是4,列索引是5:

  • row // 3 = 4 // 3 = 1
  • col // 3 = 5 // 3 = 1
  • (row // 3) * 3 + col // 3 = 1 * 3 + 1 = 4

所以,当前格子所在的3x3宫格的索引是4,即第5个宫格(从0开始计数)。

宫格索引的可视化

数独棋盘的9个3x3宫格可以这样划分:

0 1 2
3 4 5
6 7 8

每个宫格的索引计算公式为 (row // 3) * 3 + col // 3,这样可以确保每个宫格都有一个唯一的索引。

代码中的作用

在代码中,box_used[(row // 3) * 3 + col // 3].add(num)的作用是将当前数字num加入到当前格子所在的3x3宫格的已使用数字集合中。这样可以确保在后续的回溯过程中,不会在同一个宫格中重复使用相同的数字。