深入浅出 A* 算法 (三)-代码案例

67 阅读8分钟

 在深入浅出 A* 算法 (二)中,我讲到了由障碍物的情况下,A想法是怎么寻找到最优路径的,并把他的寻找逻辑梳理了一遍,那么接下来我们从代码层面上来看,A算法是怎么实现的

假设地图是这个样子,我们采用typescript作为开发语言

我把这个地图转为json结构

interface Node {
    x: number;
    y: number;
    connections: string[];
}

interface Graph {
    [key: string]: Node;
}

// 示例使用
const graph: Graph = {
  "A1": { "x": 0, "y": 0, "connections": ["A2", "B1", "B2"] },
  "A2": { "x": 0, "y": 1, "connections": ["A1", "A3", "B1", "B2"] },
  "A3": { "x": 0, "y": 2, "connections": ["A2", "A4", "B2", "B4"] },
  "A4": { "x": 0, "y": 3, "connections": ["A3", "A5", "B4", "B5"] },
  "A5": { "x": 0, "y": 4, "connections": ["A4", "B4", "B5"] },
  "B1": { "x": 1, "y": 0, "connections": ["A1", "A2", "B2", "C1", "C2"] },
  "B2": { "x": 1, "y": 1, "connections": ["A1", "A2", "A3", "B1", "C1", "C2"] },
  "B4": { "x": 1, "y": 3, "connections": ["A3", "A4", "A5", "B5", "C4", "C5"] },
  "B5": { "x": 1, "y": 4, "connections": ["A4", "A5", "B4", "C4", "C5"] },
  "C1": { "x": 2, "y": 0, "connections": ["B1", "B2", "C2", "D1", "D2"] },
  "C2": { "x": 2, "y": 1, "connections": ["B1", "B2", "C1", "D1", "D2"] },
  "C4": { "x": 2, "y": 3, "connections": ["B4", "B5", "C5", "D4", "D5"] },
  "C5": { "x": 2, "y": 4, "connections": ["B4", "B5", "C4", "D4", "D5"] },
  "D1": { "x": 3, "y": 0, "connections": ["C1", "C2", "D2", "E1", "E2"] },
  "D2": { "x": 3, "y": 1, "connections": ["C1", "C2", "D1", "E1", "E2", "E3"] },
  "D4": { "x": 3, "y": 3, "connections": ["C4", "C5", "D5", "E3", "E4", "E5"] },
  "D5": { "x": 3, "y": 4, "connections": ["C4", "C5", "D4", "E4", "E5"] },
  "E1": { "x": 4, "y": 0, "connections": ["D1", "D2", "E2"] },
  "E2": { "x": 4, "y": 1, "connections": ["D1", "D2", "E1", "E3"] },
  "E3": { "x": 4, "y": 2, "connections": ["D2", "D4", "E2", "E4"] },
  "E4": { "x": 4, "y": 3, "connections": ["D4", "D5", "E3", "E5"] },
  "E5": { "x": 4, "y": 4, "connections": ["D4", "D5", "E4"] }
};

然后还是采用曼哈顿距离作为启发式算法

使用欧几里得距离作为实际距离算法

function manhattanDistance(nodeA: Node, nodeB: Node): number {
    return Math.abs(nodeA.x - nodeB.x) + Math.abs(nodeA.y - nodeB.y);
}

function euclideanDistance(nodeA: Node, nodeB: Node): number {
    return Math.sqrt(Math.pow(nodeA.x - nodeB.x, 2) + Math.pow(nodeA.y - nodeB.y, 2));
}

接下来给我们的元素定义一个类型

interface AStarNode {
    id: string;
    parent: AStarNode | null;
    g: number; // 从起点到当前节点的实际成本(欧几里得距离)
    h: number; // 从当前节点到目标的估计成本(曼哈顿距离)
    f: number; // g + h
}

然后我们一步一步实现这个算法

function aStar(graph: Graph, start: string, end: string): string[] {
    // 初始化开放列表和关闭列表
    const openList: AStarNode[] = [];
    const closedList: Set<string> = new Set();

    // 获取起点和终点节点
    const startNode = graph[start];
    const endNode = graph[end];

    if (!startNode || !endNode) {
        throw new Error("Start or end node not found in graph");
    }
}

深入浅出 A* 算法 (一)里面讲过,我们需要定义两个列表,开放列表和关闭列表

然后获取我们的起点和终点

如果起点和终点不存在,则直接报错

以上就是前置条件

接下来我们将起点存入到开放列表中

function aStar(graph: Graph, start: string, end: string): string[] {
    // 初始化开放列表和关闭列表
    const openList: AStarNode[] = [];
    const closedList: Set<string> = new Set();

    // 获取起点和终点节点
    const startNode = graph[start];
    const endNode = graph[end];

    if (!startNode || !endNode) {
        throw new Error("Start or end node not found in graph");
    }

    // 将起始节点添加到开放列表
    openList.push({
        id: start,
        parent: null,
        g: 0,
        h: manhattanDistance(startNode, endNode),
        f: manhattanDistance(startNode, endNode)
    });
}

接下来遍历开放列表

while (openList.length > 0) {

}

第一步是寻找到开放列表中 f 值最小的点位

while (openList.length > 0) {
    // 找到开放列表中f值最小的节点
    let currentIndex = 0;
    for (let i = 1; i < openList.length; i++) {
        if (openList[i].f < openList[currentIndex].f) {
            currentIndex = i;
        }
    }

    const currentNode = openList[currentIndex];
}

如果currentNode是目标点位,那么我们就整合关闭列表里面的路径,并返回路径,在第二章节我们就讲到,返回路径需要当前节点,以及当前节点的父节点,然后反推到开始节点

while (openList.length > 0) {
        // 找到开放列表中f值最小的节点
        let currentIndex = 0;
        for (let i = 1; i < openList.length; i++) {
            if (openList[i].f < openList[currentIndex].f) {
                currentIndex = i;
            }
        }
        
        const currentNode = openList[currentIndex];

        // 如果当前节点是目标节点,重构路径并返回
        if (currentNode.id === end) {
            const path: string[] = [];
            let current: AStarNode | null = currentNode;
            while (current !== null) {
                path.unshift(current.id);
                current = current.parent;
            }
            return path;
        }
}

如果currentNode只是普通的节点,那么我们只需要将其从开放列表中删除,并放入到关闭列表中

// 将当前节点从开放列表移到关闭列表
openList.splice(currentIndex, 1);
closedList.add(currentNode.id);

再然后我们就需要遍历该节点的所有邻居

// 遍历当前节点的所有邻居
const currentGraphNode = graph[currentNode.id];
for (const neighborId of currentGraphNode.connections) {

}

然后我们做以下操作

  1. 若邻居已经在关闭列表中,我们就跳过
  2. 然后计算邻居与当前节点的 g 值
  3. 再检查邻居是否已经在开发列表中
// 遍历当前节点的所有邻居
const currentGraphNode = graph[currentNode.id];
for (const neighborId of currentGraphNode.connections) {
    // 跳过已在关闭列表中的邻居
    if (closedList.has(neighborId)) {
        continue;
    }

    const neighborNode = graph[neighborId];
    // 计算从起点经过当前节点到邻居的实际成本(欧几里得距离)
    const gScore = currentNode.g + euclideanDistance(currentGraphNode, neighborNode);

    // 检查邻居是否已在开放列表中
    let neighborInOpenList = openList.find(node => node.id === neighborId);
}

接下来会有两种情况

  • 如果邻居不在开发列表中,我们就把这个邻居进入到开发列表中,并且把邻居的parent赋值为当前节点
  • 如果邻居在开发列表中,我们就更新这个邻居的 g 值和 f 值,并把邻居的parent改为当前节点
if (!neighborInOpenList) {
    // 不在开放列表中,添加新节点
    neighborInOpenList = {
        id: neighborId,
        parent: currentNode,
        g: gScore,
        h: manhattanDistance(neighborNode, endNode),
        f: gScore + manhattanDistance(neighborNode, endNode)
    };
    openList.push(neighborInOpenList);
} else if (gScore < neighborInOpenList.g) {
    // 已在开放列表中但找到更优路径,更新节点
    neighborInOpenList.g = gScore;
    neighborInOpenList.f = gScore + neighborInOpenList.h;
    neighborInOpenList.parent = currentNode;
            }

接下来我们看完整代码

Typescript

interface Node {
    x: number;
    y: number;
    connections: string[];
}

interface Graph {
    [key: string]: Node;
}

interface AStarNode {
    id: string;
    parent: AStarNode | null;
    g: number; // 从起点到当前节点的实际成本(欧几里得距离)
    h: number; // 从当前节点到目标的估计成本(曼哈顿距离)
    f: number; // g + h
}

function euclideanDistance(nodeA: Node, nodeB: Node): number {
    return Math.sqrt(Math.pow(nodeA.x - nodeB.x, 2) + Math.pow(nodeA.y - nodeB.y, 2));
}

function manhattanDistance(nodeA: Node, nodeB: Node): number {
    return Math.abs(nodeA.x - nodeB.x) + Math.abs(nodeA.y - nodeB.y);
}

function aStar(graph: Graph, start: string, end: string): string[] {
    // 初始化开放列表和关闭列表
    const openList: AStarNode[] = [];
    const closedList: Set<string> = new Set();
    
    // 获取起点和终点节点
    const startNode = graph[start];
    const endNode = graph[end];
    
    if (!startNode || !endNode) {
        throw new Error("Start or end node not found in graph");
    }
    
    // 将起始节点添加到开放列表
    openList.push({
        id: start,
        parent: null,
        g: 0,
        h: manhattanDistance(startNode, endNode),
        f: manhattanDistance(startNode, endNode)
    });
    
    while (openList.length > 0) {
        // 找到开放列表中f值最小的节点
        let currentIndex = 0;
        for (let i = 1; i < openList.length; i++) {
            if (openList[i].f < openList[currentIndex].f) {
                currentIndex = i;
            }
        }
        
        const currentNode = openList[currentIndex];
        
        // 如果当前节点是目标节点,重构路径并返回
        if (currentNode.id === end) {
            const path: string[] = [];
            let current: AStarNode | null = currentNode;
            while (current !== null) {
                path.unshift(current.id);
                current = current.parent;
            }
            return path;
        }
        
        // 将当前节点从开放列表移到关闭列表
        openList.splice(currentIndex, 1);
        closedList.add(currentNode.id);
        
        // 遍历当前节点的所有邻居
        const currentGraphNode = graph[currentNode.id];
        for (const neighborId of currentGraphNode.connections) {
            // 跳过已在关闭列表中的邻居
            if (closedList.has(neighborId)) {
                continue;
            }
            
            const neighborNode = graph[neighborId];
            // 计算从起点经过当前节点到邻居的实际成本(欧几里得距离)
            const gScore = currentNode.g + euclideanDistance(currentGraphNode, neighborNode);
            
            // 检查邻居是否已在开放列表中
            let neighborInOpenList = openList.find(node => node.id === neighborId);
            
            if (!neighborInOpenList) {
                // 不在开放列表中,添加新节点
                neighborInOpenList = {
                    id: neighborId,
                    parent: currentNode,
                    g: gScore,
                    h: manhattanDistance(neighborNode, endNode),
                    f: gScore + manhattanDistance(neighborNode, endNode)
                };
                openList.push(neighborInOpenList);
            } else if (gScore < neighborInOpenList.g) {
                // 已在开放列表中但找到更优路径,更新节点
                neighborInOpenList.g = gScore;
                neighborInOpenList.f = gScore + neighborInOpenList.h;
                neighborInOpenList.parent = currentNode;
            }
        }
    }
    
    // 开放列表为空但未找到路径
    throw new Error("No path found from start to end");
}

另附一个python的A*算法

Python

import math
import heapq

class Node:
    def __init__(self, x, y, connections):
        self.x = x
        self.y = y
        self.connections = connections

class AStarNode:
    def __init__(self, node_id, parent=None, g=0, h=0):
        self.node_id = node_id
        self.parent = parent
        self.g = g  # 从起点到当前节点的实际成本(欧几里得距离)
        self.h = h  # 从当前节点到目标的估计成本(曼哈顿距离)
        self.f = g + h  # 总成本
    
    def __lt__(self, other):
        return self.f < other.f

def euclidean_distance(node_a, node_b):
    return math.sqrt((node_a.x - node_b.x)**2 + (node_a.y - node_b.y)**2)

def manhattan_distance(node_a, node_b):
    return abs(node_a.x - node_b.x) + abs(node_a.y - node_b.y)

def a_star(graph, start, end):
    if start not in graph or end not in graph:
        raise ValueError("Start or end node not found in graph")
    
    # 初始化开放列表和关闭集合
    open_list = []
    closed_set = set()
    
    # 创建起始节点
    start_node = graph[start]
    end_node = graph[end]
    
    start_astar_node = AStarNode(
        node_id=start,
        g=0,
        h=manhattan_distance(start_node, end_node)
    )
    
    heapq.heappush(open_list, start_astar_node)
    
    while open_list:
        # 获取f值最小的节点
        current = heapq.heappop(open_list)
        
        # 如果找到目标节点,重构路径
        if current.node_id == end:
            path = []
            while current:
                path.append(current.node_id)
                current = current.parent
            return path[::-1]  # 反转路径,从起点到终点
        
        # 将当前节点添加到关闭集合
        closed_set.add(current.node_id)
        
        # 遍历所有邻居
        current_node = graph[current.node_id]
        for neighbor_id in current_node.connections:
            if neighbor_id in closed_set:
                continue
                
            neighbor_node = graph[neighbor_id]
            # 计算从起点经过当前节点到邻居的成本
            g_score = current.g + euclidean_distance(current_node, neighbor_node)
            
            # 检查邻居是否在开放列表中
            in_open_list = False
            for node in open_list:
                if node.node_id == neighbor_id:
                    in_open_list = True
                    # 如果找到更优路径,更新节点
                    if g_score < node.g:
                        node.g = g_score
                        node.f = g_score + node.h
                        node.parent = current
                        # 重新堆化以维持堆属性
                        heapq.heapify(open_list)
                    break
            
            # 如果不在开放列表中,添加新节点
            if not in_open_list:
                new_node = AStarNode(
                    node_id=neighbor_id,
                    parent=current,
                    g=g_score,
                    h=manhattan_distance(neighbor_node, end_node)
                )
                heapq.heappush(open_list, new_node)
    
    # 如果没有找到路径
    raise ValueError("No path found from start to end")