A星算法的Python实现

148 阅读2分钟

考虑这样一个问题,从图中Start点开始出发,每次只能移动一步,向上下左右的某个相邻点移动,找到一条最短路径,移动到End点,如果遇到黑色点,表示这个位置是墙无法移动,需要绕开。

image.png

这类题目的典型做法可以是BFS或者DFS。

这里使用一种叫做A星算法的方法来解决。

核心思想是记录下当前可供查找的点,距离起点一共实际走了几步,以及当前到终点的直线距离,两个值的和作为一个参考指标,每次都找到这个指标最小的点作为探索点,向四周进行发散探索,直到找到了终点。

这个算法的效率会比BFS或者DFS快不少。

以下是python的简易实现:

import math

from pydantic import BaseModel
from typing import List


class Point(BaseModel):
    x: int
    y: int


class Node:
    def __init__(self, src_node, p: Point, start_real_distance, end: Point):
        self.src_node = src_node
        self.x = p.x
        self.y = p.y
        self.start_real_distance = start_real_distance
        self.end_line_distance = self._calc_end_line_distance_(end.x, end.y)
        self.total_distance = round(self.start_real_distance + self.end_line_distance, 4)

    def _calc_end_line_distance_(self, end_x, end_y):
        line_distance = round(math.sqrt((self.x - end_x) ** 2 + (self.y - end_y) ** 2), 4)
        return line_distance

    def __repr__(self):
        return f"## x={self.x}, y={self.y}, total={self.total_distance}, line={self.end_line_distance}, src_x={self.src_node.x if self.src_node else None}, src_y={self.src_node.y if self.src_node else None}"


class Solution:
    def find_path(self, start, end: Point, blocks: List[Point]):
        points = []
        nodes: List[Node] = []
        end_node = Node(None, end, math.inf, end)
        added_node_points = set()

        def astar(src_node: Node):
            nonlocal points
            nonlocal nodes
            nonlocal end_node
            nonlocal blocks
            nonlocal end
            p = Point(x=src_node.x, y=src_node.y)
            if p.x == end.x and p.y == end.y:
                end_node = src_node
                return
            points.append(str(p.x) + "|" + str(p.y))
            added_node_points.add(str(p.x) + "|" + str(p.y))
            new_search_points = [
                Point(x=p.x, y=p.y + 1),
                Point(x=p.x + 1, y=p.y),
                Point(x=p.x, y=p.y - 1),
                Point(x=p.x - 1, y=p.y),
            ]
            for next_point in new_search_points:
                if next_point.x < 0 or next_point.x > 3 or next_point.y < 0 or next_point.y > 6:
                    continue
                if next_point in blocks:
                    continue
                if src_node.x == next_point.x and src_node.y == next_point.y:
                    continue
                k = str(next_point.x) + "|" + str(next_point.y)
                if k in points or k in added_node_points:
                    continue
                added_node_points.add(k)
                x = Node(src_node, next_point, src_node.start_real_distance + 1, end)
                nodes.append(x)
            nearliest = math.inf
            next_select = None
            for n in nodes:
                nearliest = min(n.total_distance, nearliest)
            for index, n in enumerate(nodes):
                if n.total_distance == nearliest:
                    next_select = n
                    nodes.pop(index)
                    break
            astar(next_select)

        start_node = Node(None, start, 0, end)
        astar(start_node)

        return end_node


s = Solution()
path_end = s.find_path(Point(x=2, y=1), Point(x=0, y=6), blocks=[Point(x=1, y=3), Point(x=2, y=3)])
# 从终点倒序打印至起点
while path_end:
    print(path_end)
    path_end = path_end.src_node

代码还有挺多优化空间,比如如何在所有候选点里找到下一个探索点。

应用场景

可以应用在很多跟图、路径相关的场景中,比如导航地图应用,网游里面寻路场景等。

根据最后生成的路线,还可以统计路线中一共有多少个拐点,直线线段长度等,这都是后续可以扩展的应用。