【回溯】算法思想,附两道道面试手撕题

6 阅读6分钟

回溯算法

回溯算法是一种通过深度优先搜索(DFS)来解决问题的算法策略。它的核心思想是在搜索过程中,逐步构建问题的解,当发现当前路径不可能产生正确的完整解时,就回溯到上一步,尝试其他可能的路径。

这种算法特别适用于解决组合问题、排列问题、划分问题等。

回溯算法的核心

回溯算法的核心在于递归和回溯。递归用于在解空间树中深入探索,而回溯则是在发现当前路径不可行时撤销上一步或多步的决策,回到之前的节点继续探索。

回溯算法的模板

回溯算法的递归函数通常遵循一个固定的模板,如下所示:

void backtracking(参数) {
    if (终止条件) {
        存放结果;
        return;
    }
    for (选择本层集合中元素) {
        处理节点;
        backtracking();
        回溯,撤销处理结果;
    }
}

这个模板强调了递归函数的几个关键点:

  1. 递归函数通常没有返回值。
  2. 先写终止条件,这是收集结果的关键时刻。
  3. 单层搜索使用for循环,处理集合中的元素。
  4. 递归调用自身,深入探索解空间。
  5. 回溯操作,手动撤销之前的处理。

从树的遍历理解回溯

理解回溯算法的一个好方法是从树的遍历开始。在二叉树中,前序遍历的代码如下:

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def treeDFS(root):
    if root is None:
        return
    print(root.val)
    treeDFS(root.left)
    treeDFS(root.right)

# 示例
# 构建一个简单的二叉树
#       1
#      / \
#     2   3
#    / \
#   4   5
root = TreeNode(1)
root.left = TreeNode(2, TreeNode(4), TreeNode(5))
root.right = TreeNode(3)

# 执行前序遍历
treeDFS(root)

对于n叉树,遍历代码变为:

class TreeNode:
    def __init__(self, val=0, children=None):
        self.val = val
        self.children = children if children is not None else []

def treeDFS(root):
    if root is None:
        return
    print(root.val)
    for node in root.children:
        treeDFS(node)

# 示例使用
# 构建一个简单的 n 叉树
#       1
#      / | \
#     2  3  4
#    / \
#   5   6
root = TreeNode(1, [TreeNode(2, [TreeNode(5), TreeNode(6)]), TreeNode(3), TreeNode(4)])

# 执行前序遍历
treeDFS(root)

因为是n叉树,所以没办法再用leftright表示分支了,这里用了一个List

注意:可以对比下回溯模版,是否有相似之处。

算法题

第 K 排列

描述

给定参数n,从1到n会有n个整数:1,2,3,…,n,这n个数字共有n!种排列。

按大小顺序升序列出所有排列的情况,并一一标记,

当n=3时,所有排列如下:

“123” “132” “213” “231” “312” “321”

给定n和k,返回第k个排列。

输入描述

  • 输入两行,第一行为n,第二行为k,
  • 给定n的范围是[1,9],给定k的范围是[1,n!]。

输出描述

输出排在第k位置的数字。

题解

这题求序列的可能排列,最后排序返回第 K 号序列值。

我们在理解回溯思想后,套模版进行求解就 OK。

def generate_permutations(nums, current, result, k):
    """
    递归生成所有排列,并存储在 result 列表中。

    :param nums: 剩余的数字列表。
    :param current: 当前生成的排列。
    :param result: 存储所有排列的结果列表。
    :param k: 需要找到的第 k 个排列。
    """
    if len(nums) == 0:
        # 如果没有剩余数字,说明找到了一个完整的排列。
        result.append(current)
        if len(result) == k:
            # 如果已经找到 k 个排列,提前结束。
            return True
        return False

    for i in range(len(nums)):
        # 选择当前数字。
        num = nums[i]
        # 生成不包含当前数字的新列表。
        new_nums = nums[:i] + nums[i+1:]
        # 递归生成排列。
        if generate_permutations(new_nums, current + str(num), result, k):
            # 如果已经找到 k 个排列,提前结束。
            return True

# 读取用户输入。
n = int(input())
k = int(input())

# 特殊情况处理。
if n == 1:
    print("1")
    exit()

# 初始化数字列表和结果列表。
nums = [i + 1 for i in range(n)]
result = []

# 生成排列。
generate_permutations(nums, "", result, k)

# 对结果列表进行排序(虽然在这个特定问题中不需要,因为生成的顺序已经是有序的)。
result.sort()

# 输出第 k 个排列。
print(result[k-1])

字符串拼接

描述

给定 M(0 < M ≤ 30)个字符(a-z),从中取出任意字符(每个字符只能用一次)拼接成长度为 N(0 < N ≤ 5)的字符串,

要求相同的字符不能相邻,计算出给定的字符列表能拼接出多少种满足条件的字符串,

输入非法或者无法拼接出满足条件的字符串则返回0。

输入描述

给定的字符列表和结果字符串长度,中间使用空格(" ")拼接

输出描述

满足条件的字符串个数

题解

我们定义一个函数generateDistinctStrings,它接受以下参数:

  • s:一个包含可用字符的集合。
  • length:我们想要生成的字符串的目标长度。
  • current:当前正在构建的字符串。
  • result:一个集合,用来存储所有生成的唯一字符串。
  • used:一个布尔数组,用来跟踪s中的每个字符是否已经被用于构建current

函数的逻辑如下:

  1. 检查current的长度是否已经达到length。如果是,将current添加到result中,并返回。
  2. 如果current的长度还未达到length,遍历s中的每个字符c
  3. 对于每个字符c,检查:
    • c是否已经被使用(即used数组中对应的值为true)。
    • c是否与current的最后一个字符相同。 如果任一条件为真,则跳过当前字符,继续检查下一个字符。
  4. 如果c未被使用且与current的最后一个字符不同,则:
    • c添加到current的末尾。
    • 标记c为已使用。
    • 递归调用generateDistinctStrings,继续构建下一个字符。
  5. 递归调用返回后,取消对c的使用标记,以便它可以在后续的递归中被再次使用。

伪代码实现

函数 generateDistinctStrings(s, length, current, result, used)
    如果 current的长度 等于 length
        将 current 添加到 result
        返回
    对于 s中的每一个字符 c
        如果 used[c] 或者 c与current的最后一个字符相同
            继续下一次循环
        标记 used[c] 为 true
        generateDistinctStrings(s, length, current + c, result, used)
        标记 used[c] 为 false
from collections import Counter

# 定义深度优先搜索函数
def dfs(cur_s):
    global ans, n, cnts
    # 如果当前字符串长度等于目标长度n,增加答案计数
    if len(cur_s) == n:
        ans += 1
        return

    # 遍历字符计数器中的每个字符及其计数
    for k, v in cnts.items():
        # 如果字符计数为0或者当前字符串不为空且最后一个字符与当前字符相同,则跳过
        if v == 0 or (cur_s and k == cur_s[-1]):
            continue
        # 选择一个字符并减少其计数,然后递归调用dfs
        cnts[k] -= 1
        dfs(cur_s + k)
        # 回溯,恢复字符计数
        cnts[k] += 1

# 读取输入的字符串和目标长度
s, n = input().strip().split()
n = int(n)

# 使用Counter统计输入字符串中每个字符的出现次数
cnts = Counter(s)

# 初始化答案计数器
ans = 0

# 检查所有字符是否都是小写字母
if all('a' <= k <= 'z' for k in cnts):
    # 如果是,从空字符串开始调用dfs
    dfs('')
else:
    # 如果不是,直接输出0
    print(0)
    exit()

# 输出不同字符串的数量
print(ans)