基础算法3 - DFS

532 阅读7分钟

DFS的一般使用场景

  • 模版DFS
  • mask举例DFS
  • (一般在tree上)外部空间DFS:用stack将recursion改写成interative way(略)
  • DFS+memo(DP减枝)
  • 模拟流程,寻找所有的解 OR 一个可行的解

PS:

  • backtracking is a more general purpose algorithm.
  • DFS is a specific form of backtracking related to searching graph, tree structures.


DFS vs. BFS

  • BFS:
    • pros:适合解决最短 or 最少问题(即:“最优解”)
    • cons:需要开大量的数组单元用来存储状态
  • DFS:
    • pros:解决是否存在解 or 枚举所有解

    • cons:不适合找“最优解”:如果要找“最优解”,需要遍历所有路径



子集问题

78. 子集(Medium)

image.png

Solu 1:模版DFS

  • 在遍历「选择列表」时,每append一个元素,就把当前path加入res(即:决策树上的每一个node,都是valid solution

image.png

Code 1:

class Solution:
    def subsets(self, nums: List[int]) -> List[List[int]]:
        def dfs(idx, path):
            # 不需要终止条件:当idx > len(nums)时会自动停止
            for i in range(idx, len(nums)):
                path.append(nums[i])
                res.append(path[:])
                dfs(i + 1, path)
                path.pop()
        
        res = [[]]
        dfs(0, [])
        return res

Solu2:mask举例DFS

  • 生成mask去产生所有的解
  • mask二进制表示的每一位上的数字:1表示取,0表示不取

image.png

Code 2:

class Solution:
    def subsets(self, nums: List[int]) -> List[List[int]]:
        total_number = 1 << len(nums)
        res = []
        for mask in range(total_number):
            path = []
            for i, num in enumerate(nums):
                if mask & (1 << i):  # 按照mask把每一位取出来
                    path.append(num)
            res.append(path)
        return res


90. 子集 II(Medium)

image.png

Solu:

常见去重方法:

  1. sort
  2. 每个重复数字只处理它的第一次出现
if i > start_idx and nums[i] == nums[i - 1]:
    continue

Code:

class Solution:
    def subsetsWithDup(self, nums: List[int]) -> List[List[int]]:
        def dfs(idx, path):
            for i in range(idx, len(nums)):
                if i > idx and nums[i] == nums[i - 1]:
                    continue
                path.append(nums[i])
                res.append(path[:])
                dfs(i + 1, path)
                path.pop()
        
        nums.sort()
        res = [[]]
        dfs(0, [])
        return res


排列问题

46. 全排列(Medium)

image.png

Solu:

  • 每次都要从头开始取
  • 已经visited过了的就不要了

image.png

Code:

class Solution:
    def permute(self, nums: List[int]) -> List[List[int]]:
        def dfs(path, res):
            if len(path) == len(nums):
                res.append(path[:])
                return
            for i in range(len(nums)):
                if nums[i] in path: # 相当于visited
                    continue
                path.append(nums[i])
                dfs(path, res)
                path.pop()
        
        res = []
        dfs([], res)
        return res


47. 全排列 II(Medium)

image.png

Solu:

  • 去重检查在前:if前一个数字和当前数字一样 && 前一个数字没有被visited,then直接跳过这一轮(when a number has the same value with its previous, we can use this number only if his previous is used
  • 保证只有一次 前一个数字+当前数字的组合 被计算(if当前数字 == 前一个数字)

image.png

Code:

class Solution:
    def permuteUnique(self, nums: List[int]) -> List[List[int]]:
        def dfs(path, res, visited):
            if len(path) == len(nums):
                res.append(path[:])
                return
            for i in range(len(nums)):
                if visited[i] or (i > 0 and nums[i - 1] == nums[i] and not visited[i - 1]):
                    continue
                visited[i] = True
                path.append(nums[i])
                dfs(path, res, visited)
                visited[i] = False
                path.pop()
        
        res = []
        nums.sort()
        dfs([], res, [False] * len(nums))
        return res


组合问题

77. 组合(Medium)

image.png

Solu:

  • 每次从集合中选取元素,可选择的范围随着选择的进行而收缩,调整可选择的范围

image.png

Code:

class Solution:
    def combine(self, n: int, k: int) -> List[List[int]]:
        def dfs(start, path):
            if len(path) == k:
                res.append(path[:])
                return
            for i in range(start, n + 1):
                path.append(i)
                dfs(i + 1, path)
                path.pop()
        
        res = []
        dfs(1, [])
        return res

剪枝优化:

  • iffor循环选择的起始位置之后的元素个数 < 我们需要的元素个数了,then就没必要搜索了
    • 已经选择的元素个数:path.size();
    • 还需要的元素个数为: k - path.size();
    • 在集合n中至多要从该起始位置 : n - (k - path.size()) + 1,开始遍历

image.png

class Solution:
    def combine(self, n: int, k: int) -> List[List[int]]:
        def dfs(start, path):
            if len(path) == k:
                res.append(path[:])
                return
            for i in range(start, n - (k - len(path)) + 2):  # 剪枝
                path.append(i)
                dfs(i + 1, path)
                path.pop()
        
        res = []
        dfs(1, [])
        return res


数独游戏

37. 解数独(Hard)

image.png

Solu:

  • 只要找到一个可行解即可 -> def dfs() -> bool
  • 在每一个需要填数字的位置,暴力尝试所有数字

Code:

class Solution:
    def solveSudoku(self, board: List[List[str]]) -> None:
        """
        Do not return anything, modify board in-place instead.
        """
        
        def isValid(i, j, k):
            # row
            for x in range(9):
                if x != j and board[i][x] == k:
                    return False
            # col
            for y in range(9):
                if y != i and board[y][j] == k:
                    return False
            # block
            x0 = (i // 3) * 3
            y0 = (j // 3) * 3
            for x in range(x0, x0 + 3):
                for y in range(y0, y0 + 3):
                    if (x != i or y != j) and board[x][y] == k:
                        return False
            return True
        
        def dfs() -> bool:
            for i in range(9):
                for j in range(9):
                    if board[i][j] == '.':
                        for k in range(1, 10):
                            if isValid(i, j, str(k)):
                                board[i][j] = str(k)
                                if dfs():
                                    return True
                                else:
                                    board[i][j] = '.'  # backtrack
                        return False
            return True
        
        dfs()


N-皇后问题

51. N-皇后(Hard)

image.png

Solu:

  • backtracking尝试所有位置所有解

Code:

class Solution:
    def solveNQueens(self, n: int) -> List[List[str]]:
        def isValid(i, j):
            for row in range(i):
                for col in range(n):
                    # 同列 or 对角线(45度 + 135度)
                    if board[row][col] == 'Q' and (col == j or abs(row - i) == abs(col - j)):
                        return False
            return True
        
        def dfs(row) -> None:
            if row == n:
                res.append([''.join(row) for row in board])
                return
            for col in range(n):
                if isValid(row, col):
                    board[row][col] = 'Q'
                    dfs(row + 1)
                    board[row][col] = '.'  # backtracking
        
        res = []
        board = [['.'] * n for _ in range(n)]
        dfs(0)
        return res


DFS的剪枝 & 优化

常见的剪枝 & 优化方法(PS:不是 dp memoization):

  • sort倒序,task先做大的后做小的
    • stop early,compare early
  • global的reault
    • ie:if求最小值,那么如果计算过程中当前已经 > 当前res了,then 可以直接停止计算
  • 跳过重复的元素
    • 类似于permutation(sort + nums[i-1] == nums[i]
  • 改变搜索思路:backtracking时,for-loop数据规模更小的部分
    • ie:如果A和B有关联,那么根据A.size()和B.size(),size更小的放进for-loop里(即:backtracking时,for-loop数据规模更小的部分)

工作分配

  • 技巧:开一个list记录目前工作的分配情况
  • 模版流程:
    1. 剪枝1:倒序sort
    2. 剪枝2:和global比较
    3. 常规backtracking,for-loop规模较小的数据部分
    4. 剪枝3:去重,同样效果的只取第一次

1723. 完成所有工作的最短时间(Hard)

image.png

Solu: 回溯+剪枝

3重剪枝:

  • prune1:倒叙sort tasks
    • 先分配耗时长的task -> stop early, compare early
  • prune 2:对time去重
    • 如果 time[cur_worker] = time[cur_worker - 1],那么把cur_task分配给当前工人cur_worker 或是 分配给前一个工人cur_worker - 1,本质上没有区别,没必要再算一次
  • prune 3:global result
    • 如果分配给某一个worker的工作时间已经大于了global result,即:当前max(time) >= 目前计算出的最小值ans,那么这个解必然不可能是最优的,没必要算下去了

Code:

class Solution:
    def __init__(self):
        self.ans = sys.maxsize
    
    def minimumTimeRequired(self, jobs: List[int], k: int) -> int:
        def dfs(idx, time: List[int]):
            if idx == len(jobs):
                self.ans = min(self.ans, max(time))
                return
            if max(time) >= self.ans:  # 剪枝3:如果当前的max_time已经大于目前计算出的最小值,则没必要继续计算下去
                return
            for worker in range(k):
                # 剪枝2:如果当前worker和前一个worker的时间一样,那么把当前任务分配给当前worker或前一个worker的效果是一样的,没必要再算一次
                if worker > 0 and time[worker] == time[worker - 1]:
                    continue
                time[worker] += jobs[idx]
                dfs(idx + 1, time)
                time[worker] -= jobs[idx]  # backtracking
        
        jobs.sort(reverse=True)  # 剪枝1:先从耗时长的task开始
        dfs(0, [0] * k)
        return self.ans


1986. 完成任务的最少工作时间段(Medium)

image.png

Solu:回溯+剪枝

sessions用于记录:

  • 当前开了多少个sessions:#session = len(sessions)
  • 记录当前每个开出的session各自已经使用了多少时间

2重剪枝:

  • prune1:倒序sort tasks
  • prune2:当前所需#session超过global result(即:len(sessions) >= self.ans),必然不可能是最优解

Code:

class Solution:
    def __init__(self):
        self.ans = sys.maxsize
        self.sessions = []
    
    def minSessions(self, tasks: List[int], sessionTime: int) -> int:
        def dfs(idx):
            if len(self.sessions) >= self.ans:  # 剪枝2:当前所需#session超过global,必然不可能是最优解
                return
            if idx == len(tasks):
                self.ans = min(self.ans, len(self.sessions))
                return
            # 尝试当前已经开出来的每个session
            for i in range(len(self.sessions)):
                if self.sessions[i] + tasks[idx] <= sessionTime:
                    self.sessions[i] += tasks[idx]
                    dfs(idx + 1)
                    self.sessions[i] -= tasks[idx]  # backtrack
            # 尝试新开一个session
            self.sessions.append(tasks[idx])
            dfs(idx + 1)
            self.sessions.pop()  # backtrack
        
        tasks.sort(reverse=True)  # 剪枝1:先处理耗时长的task
        dfs(0)
        return self.ans


473. 火柴拼正方形(Medium)

image.png

Solu:DFS + 剪枝

  • 剪枝:
    • 倒序sort,易于early termination
    • 去重:如果side[i-1]side[i]的当前长度一样,那么没必要再算一次
  • 类似于「工作分配问题」,记录4边分别还需要多长的火柴摆进来

Code:

class Solution:
    def makesquare(self, matchsticks: List[int]) -> bool:
        if len(matchsticks) < 4 or sum(matchsticks) % 4 != 0:
            return False
        target = sum(matchsticks) // 4
        placement = [target] * 4
        matchsticks.sort(reverse=True)
        
        def dfs(idx, placement) -> bool:
            if any(i < 0 for i in placement):
                return False
            if idx == len(matchsticks):
                return all(i == 0 for i in placement)
            for i in range(4):
                if i > 0 and placement[i] == placement[i - 1]:
                    continue
                placement[i] -= matchsticks[idx]
                if dfs(idx + 1, placement):
                    return True
                placement[i] += matchsticks[idx]  # backtrack
            return False
        
        return dfs(0, placement)


并查集

模版(简化版)

class UnionFind:
    def __init__(self, n):
        self.parent = [i for i in range(n)]
    
    def find(self, x):
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]
    
    def union(self, x, y):
        self.parent[self.find(x)] = self.find(y)


Reference:

  1. 代码随想录

  2. 古城算法