问题描述
小U决定在一个 m×nm×n 的地图上行走。地图中的每个位置都有一个高度,表示地形的高低。小U只能在满足以下条件的情况下移动:
- 只能上坡或者下坡,不能走到高度相同的点。
- 移动时必须交替进行:上坡后必须下坡,下坡后必须上坡,不能连续上坡或下坡。
- 每个位置只能经过一次,不能重复行走。
任务是帮助小U找到他在地图上可以移动的最大次数,即在符合所有条件的前提下,小U能走过的最大连续位置数量。
分析
这是一个典型的深度优先搜索(DFS)题目,要求我们在一个二维地图上找到一条符合特定规则的最长路径。为了求解这个问题,可以使用递归的方式(即深度优先搜索),结合回溯(Backtracking)来遍历所有符合条件的路径。
思路简述
-
状态表示:
- 每个位置
(x, y)有一个高度a[x][y],并且小U的行走顺序需要满足交替上坡和下坡的规则。 - 使用一个
vis数组来标记每个位置是否已经访问过,避免重复走过同一个位置。
- 每个位置
-
状态转移:
- 每次从当前位置
(x, y)出发,检查其四个相邻的方向(上下左右),选择符合条件的方向继续行走。 - 上坡:如果当前状态是上坡(k = 0),则下一步必须是下坡。
- 下坡:如果当前状态是下坡(k = 1),则下一步必须是上坡。
- 每次从当前位置
-
DFS与回溯:
- 深度优先搜索的关键是递归调用,每次访问一个新位置后,进行标记(
vis[x][y] = 1),然后继续搜索。 - 在搜索完成后,需要进行回溯,即恢复当前位置的状态(
vis[x][y] = 0),以便探索其他路径。
- 深度优先搜索的关键是递归调用,每次访问一个新位置后,进行标记(
-
优化:
- 对于每个位置,我们可以从两种状态(上坡和下坡)分别开始搜索,最终得到最大路径长度。
DFS模板框架
def dfs(当前状态, 一系列其他的状态量):
if (当前状态 == 目的状态):
# 达到目标状态,进行处理
...
for (搜索下一状态):
if (当前状态合法):
vis[当前位置] = 1 # 标记当前位置已访问
dfs(新的状态) # 递归调用
vis[当前位置] = 0 # 回溯,恢复访问状态
代码
def solution(m: int, n: int, a: list):
# 确保输入矩阵的行数为 m 且每一行的列数为 n
assert m == len(a) and all(len(v) == n for v in a)
# 初始化访问标记矩阵,记录某个位置是否已经被访问过
vis = [[0 for _ in range(n)] for _ in range(m)]
# 深度优先搜索 (DFS) 函数,参数:
# x, y: 当前坐标
# k: 当前状态,k=0 表示要求递增路径,k=1 表示要求递减路径
def dfs(x, y, k):
ans = 0 # 当前路径的最大长度
# 遍历四个方向 (上下左右)
for nx, ny in [(x+1, y), (x-1, y), (x, y+1), (x, y-1)]:
# 如果新坐标超出边界,或者已被访问,或者不满足递增/递减的条件,则跳过
if nx < 0 or ny < 0 or nx >= m or ny >= n or vis[nx][ny] or \
(k == 0 and a[nx][ny] <= a[x][y]) or (k == 1 and a[nx][ny] >= a[x][y]):
continue
# 标记当前坐标为已访问
vis[x][y] = 1
# 递归调用,切换状态 k (k^1 表示状态反转),计算路径长度并取最大值
ans = max(ans, dfs(nx, ny, k ^ 1) + 1)
# 回溯,取消标记,以便尝试其他路径
vis[x][y] = 0
return ans # 返回从当前点出发的最长路径长度
res = 0 # 全局结果,记录最大路径长度
# 遍历矩阵的每一个位置,分别计算从该点出发递增路径和递减路径的最大长度
for i in range(m):
for j in range(n):
res = max(res, dfs(i, j, 1), dfs(i, j, 0)) # 比较递增 (k=0) 和递减 (k=1) 的两种情况
return res # 返回整个矩阵中的最长路径长度
# 测试用例
if __name__ == '__main__':
# 示例 1: 输出 3,最长路径为 1 -> 2 -> 3 或 1 -> 4 -> 3
print(solution(m = 2, n = 2, a = [[1, 2], [4, 3]]) == 3)
# 示例 2: 输出 8,最长路径为 1 -> 6 -> 3 -> 4 -> 7 -> 9 -> 5 -> 10
print(solution(m = 3, n = 3, a = [[10, 1, 6], [5, 9, 3], [7, 2, 4]]) == 8)
# 示例 3: 输出 11,最长路径为 1 -> 2 -> ... -> 16
print(solution(m = 4, n = 4, a = [[8, 3, 2, 1], [4, 7, 6, 5], [12, 11, 10, 9], [16, 15, 14, 13]]) == 11)