A*算法实现最优路径选择

101 阅读8分钟

step1:environment构建。对我们的原始图像划分网格,设置图中的禁行区域

step2:A*算法迭代,计算当前节点的邻节点,并计算每个节点和终点之间的曼哈顿距离

import sys
import time
from typing import List
import random
from PIL import Image, ImageDraw
import numpy as np

"""
Point类是数学坐标系的一个抽象的点,和Node类不是一回事
"""


class Point:
    def __init__(self, x, y) -> None:
        self.x = x
        self.y = y

    # 重载“==”运算符,(x1,y1)==(x2,y2),当且仅当x1=x2,y1=y2
    def __eq__(self, other) -> bool:
        return self.x == other.x and self.y == other.y


class Map2D:
    def __init__(self, height, width, image) -> None:
        self.height = height
        self.width = width
        self.image = Image.open(image)
        self.row = Image.open(image).width / width
        self.column = Image.open(image).height / height

        # width可以看成二维地图的行,height可以看成二维地图的列
        # self.data = [["⬜" for _ in range(width)] for _ in range(height)]
        self.data = [["⬜" for _ in range(int(self.row))] for _ in range(int(self.column))]

    # 将地图数据用文本导出
    def show(self, file_name="output.txt") -> None:
        with open(file_name, 'w', encoding='utf-8') as file:
            for row in self.data:
                file.write(" ".join(row) + '\n')

    # 将地图数据用图片导出
    def export_image(self, file_name="map.png") -> None:
        cell_size = 10
        image = Image.new("RGB", (int(self.row) * cell_size, int(self.column) * cell_size), "white")
        draw = ImageDraw.Draw(image)
        for x in range(int(self.column)):
            for y in range(int(self.row)):
                color = "white"
                if self.data[x][y] == "⬛":
                    color = "black"
                elif self.data[x][y] == "🟥":
                    color = "red"
                elif self.data[x][y] == "🟩":
                    color = "green"
                draw.rectangle([(y * cell_size, x * cell_size), ((y + 1) * cell_size, (x + 1) * cell_size)], fill=color)
        image.save(file_name)

    # 当地图点为⬛,则为障碍物
    def set_obstacle(self, x, y):
        self.data[x][y] = "⬛"

    # 设置起点和终点
    def set_start_end(self, start: Point, end: Point) -> None:
        self.data[start.x][start.y] = "🟥"
        self.data[end.x][end.y] = "🟥"

    # def obstacle_generate(self, ratio: int) -> None:
    #     # 随机放置障碍物
    #     obstacle_cells = int((self.height * self.width) * ratio)  # 障碍物占据40%的格子
    #     for _ in range(obstacle_cells):
    #         x = random.randint(0, map2d.height - 1)
    #         y = random.randint(0, map2d.width - 1)
    #         while (x == start_point.x and y == start_point.y) or (x == end_point.x and y == end_point.y) or \
    #                 map2d.data[x][y] == "⬛":
    #             x = random.randint(0, map2d.height - 1)
    #             y = random.randint(0, map2d.width - 1)
    #         map2d.set_obstacle(x, y)
    def obstacle_generate(self, ratio: int) -> None:
        # 按照像素值设置障碍物
        # obstacle_cells = int((self.height * self.width) * ratio)  # 障碍物占据40%的格子
        # for _ in range(obstacle_cells):
        #     x = random.randint(0, map2d.height - 1)
        #     y = random.randint(0, map2d.width - 1)
        #     while (x == start_point.x and y == start_point.y) or (x == end_point.x and y == end_point.y) or \
        #             map2d.data[x][y] == "⬛":
        #         x = random.randint(0, map2d.height - 1)
        #         y = random.randint(0, map2d.width - 1)
        #     map2d.set_obstacle(x, y)
        # split_image_and_compute_mean()
        for i in range(int(self.row)):
            row = []  # 初始化一行的列表
            for j in range(int(self.column)):
                # 计算每个小图片的左上角坐标
                left = i * self.width
                top = j * self.height
                right = min(left + self.width, self.image.width)  # 处理边界情况
                bottom = min(top + self.height, self.image.height)

                # 裁剪图片
                sub_image = self.image.crop((left, top, right, bottom))
                if sub_image.mode == 'RGBA':
                    # 32 convert 8
                    sub_image = sub_image.convert('RGB')
                # 将图片转换为数组
                array = np.array(sub_image)

                # 计算均值
                mean_value = np.mean(array)
                if mean_value == 0.0:
                    map2d.data[j][i] = "⬛"


"""
    1.ud指的是up and down
    2.rl指的是right and left
"""


class Node:
    def __init__(self, point: Point, endpoint: Point, g: float):  # 初始化中间节点的参数
        self.point = point
        self.endpoint = endpoint
        self.father = None
        self.g = g
        # h取曼哈顿距离,c=|x2-x1|+|y2-y1|
        self.h = (abs(endpoint.x - point.x) + abs(endpoint.y - point.y)) * 10
        self.f = self.g + self.h

    def get_near(self, ud, rl):  # 获取相邻节点
        near_point = Point(self.point.x + rl, self.point.y + ud)
        near_node = Node(near_point, self.endpoint, self.g + (10 if ud == 0 or rl == 0 else 14))
        return near_node


class AStar:
    def __init__(self, start: Point, end: Point, map2d: Map2D):  # 初始化A*算法的参数
        self.path = []
        self.closed_list = []
        self.open_list = []
        self.start = start
        self.end = end
        self.map2d = map2d

    # 从open_list里面找到一个代价最小的节点
    def select_current(self) -> Node:
        min_f = sys.maxsize
        node_temp = None
        for node in self.open_list:
            if node.f < min_f:
                min_f = node.f
                node_temp = node
        return node_temp

    def is_in_open_list(self, node: Node) -> bool:  # 判断节点是否在待检测队列中
        return any([open_node.point == node.point for open_node in self.open_list])

    def is_in_closed_list(self, node: Node) -> bool:  # 判断节点是否在已检测队列中
        return any([closed_node.point == node.point for closed_node in self.closed_list])

    def is_obstacle(self, node: Node) -> bool:  # 判断节点是否是障碍物
        # 检查节点坐标是否在地图数据的边界内
        if node.point.x < 0 or node.point.x >= len(self.map2d.data):
            return False
        if node.point.y < 0 or node.point.y >= len(self.map2d.data[0]):
            return False
        # 检查节点是否是障碍物
        return self.map2d.data[node.point.x][node.point.y] == "⬛"

    """
    这个函数是A*算法的核心函数,找到当前节点代价最小的邻点
    用list来当作是队列的数据结构,存放探测过或者未被探测的节点,以此来进行路径探索
    在路径探索中节点有三种状态
    状态1.加入了队列并且已经检测了,这个单独用一个Close_list队列存放
    状态2.加入了队列但是还没有检测,这个用Open_list队列存放
    状态3.还没有被加入队列
    """


    def explore_neighbors(self, current_node: Node) -> bool:
        up = (0, 1)  # 上
        down = (0, -1)  # 下
        right = (1, 0)  # 右
        left = (-1, 0)  # 左
        top_right = (1, 1)  # 右上
        top_left = (-1, 1)  # 左上
        Bottom_right = (1, -1)  # 右下
        Bottom_left = (-1, -1)  # 左下
        directions = [up, down, right, left, top_right, top_left, Bottom_right, Bottom_left]
        for direction in directions:
            ud, rl = direction
            # current_neighbor是当前节点的邻点
            current_neighbor = current_node.get_near(ud, rl)
            # 如果检测到的节点是终点,就没必要接着往下探索了,直接退出循环,结束这个函数
            if current_neighbor.point == self.end:
                return True
            # 判断一下邻点是不是已经检测或者是障碍物,如果是,就跳过这个邻点
            if self.is_in_closed_list(current_neighbor) or self.is_obstacle(current_neighbor):
                continue
            if self.is_in_open_list(current_neighbor):
                """
                作用:在open_list中找到第一个与current_neighbor相同(坐标相同)的节点
                这里有两个值得注意的点
                1.在open_list中,可能有多个与current_neighbor相同(坐标相同)的节点,
                出现这种情况是因为同一个节点,是可以通过多条不同的路径抵达的(意思就是g值不同)
                比如说节点C是当前节点,点A与节点B都能抵达节点C且g值都相同,那么节点C此时在open_list就会被添加两次

                2.previous_current_neighbor是取的在open_list中与current_neighbor相同(坐标相同)的节点中
                他们唯一的区别就是g值不同但因为有多个匹配,因此这里用next函数只取一次即可
                """

                previous_current_neighbor = next(
                    open_node for open_node in self.open_list if open_node.point == current_neighbor.point)

                """
                这时就要比较current_neighbor与previous_current_neighbor的代价了,
                假如我在本次的路径探索到的current_neighbor要比我之前的路径探索到的previous_current_neighbor的代价要小
                (这里时刻注意,current_neighbor与previous_current_neighbor是坐标相同的),那么我就要更新previous_current_neighbor的代价
                """
                if current_neighbor.f < previous_current_neighbor.f:
                    # 更新父节点
                    previous_current_neighbor.father = current_node
                    # 更新g值
                    previous_current_neighbor.g = current_neighbor.g
            else:
                # 对应状态3,直接入队
                current_neighbor.father = current_node
                self.open_list.append(current_neighbor)
        return False

    def find_path(self):
        start_node = Node(point=self.start, endpoint=self.end, g=0)
        self.open_list.append(start_node)
        while True:
            # 从open_list里面取出一个代价值最小节点
            current_node = self.select_current()
            if current_node is None:
                return None
            # 取出来后,从open_list里面删除,添加到closed_list里面
            self.open_list.remove(current_node)
            self.closed_list.append(current_node)
            # 当current_node是终点时,explore_neighbors函数会返回一个True
            if current_node.point == self.end or self.explore_neighbors(current_node):
                while current_node.father is not None:
                    self.path.insert(0, current_node.point)
                    # 这里其实就是相当于遍历一个链表
                    current_node = current_node.father
                return self.path


def split_image_and_compute_mean(image_path, row, column):
    # 打开图片
    image = Image.open(image_path)

    # 将图片转换为灰度图,以便计算均值
    image = image.convert('L')

    # 获取图片的宽度和高度
    width, height = image.size

    # 计算每个小图片的宽度和高度
    sub_width = width // row
    sub_height = height // column

    # 初始化一个二维列表来存储每个小图片的均值
    mean_values = []

    # 拆分图片并计算均值
    for i in range(n):
        row = []  # 初始化一行的列表
        for j in range(n):
            # 计算每个小图片的左上角坐
            left = j * sub_width
            top = i * sub_height
            right = min(left + sub_width, width)  # 处理边界情况
            bottom = min(top + sub_height, height)

            # 裁剪图片
            sub_image = image.crop((left, top, right, bottom))

            # 将图片转换为数组
            array = np.array(sub_image)

            # 计算均值
            mean_value = np.mean(array)

            # 将均值添加到行列表中
            row.append(mean_value)
        mean_values.append(row)  # 将行列表添加到二维列表中

    return mean_values


def array_to_image(array):
    # 确定图片的宽度和高度
    width = len(array[0])
    height = len(array)

    # 创建一个新的白色背景图片
    image = Image.new('RGB', (width, height), 'white')
    draw = ImageDraw.Draw(image)

    # 定义颜色映射
    color_map = {'⬛': (0, 0, 0), '⬜': (255, 255, 255), '🟥': (255, 0, 0)}

    # 遍历数组并绘制每个方块
    for y in range(height):
        for x in range(width):
            color = color_map.get(array[y][x], (0, 0, 0))  # 默认黑色
            draw.rectangle([x, y, x + 1, y + 1], fill=color)

    # 保存图片
    image.save('org_ceil_image.png')


def alternat_point_color(ceil_w, ceil_h, img_data, point, image):
    # alternate end and start point color
    left = point.x * ceil_w
    top = point.y * ceil_h
    right = min(left + ceil_w, image.width)  # 处理边界情况
    bottom = min(top + ceil_h, image.height)
    for r in range(top, bottom):
        for c in range(left, right):
            img_data[r, c] = (255, 0, 0)
    return img_data


def path_plot_in_org_img(ceil_h, ceil_w, img_path, target_point_obj, new_color, start_point, end_point):
    image = Image.open(img_path)
    split_row = image.width / ceil_w
    split_colum = image.height / ceil_h

    target_point_list = []
    for target_point in target_point_obj:
        target_point_list.append((target_point.x, target_point.y))


    img_data = image.load()  # 加载图片数据
    for i in range(int(split_row)):
        for j in range(int(split_colum)):
            # 计算每个小图片的左上角坐
            left = j * ceil_w
            top = i * ceil_h
            right = min(left + ceil_w, image.width)  # 处理边界情况
            bottom = min(top + ceil_h, image.height)

            if (i, j) in target_point_list:
                for y in range(top, bottom):
                    for x in range(left, right):
                        img_data[x, y] = new_color

    # alternate point end and start point color
    alternat_point_color(ceil_w, ceil_h, img_data, start_point, image)
    alternat_point_color(ceil_w, ceil_h, img_data, end_point, image)



    image.save('./path_in_org_img.png')
    print('--------------path saved in org image!!!----------------')








if __name__ == "__main__":
    # 创建地图
    ceil_h = 15
    ceil_w = 15
    map2d = Map2D(ceil_h, ceil_w, r'D:\YJJY_Proj\path_select\road_select.png')
    # map2d = split_image_and_compute_mean(r'D:\PycharmProjects\YJJY_Project\AStar_algorithm\test1.png', 17, 19)
    # 设置起点和终点
    map2d.obstacle_generate(0.1)
    start_point = Point(136, 97)
    end_point = Point(311, 221)
    map2d.set_start_end(start_point, end_point)
    # save org ciel image
    array_to_image(map2d.data)



    # 运行A*算法
    start_time = time.time()
    a_star = AStar(start_point, end_point, map2d)
    path = a_star.find_path()
    end_time = time.time()
    # 打印结果
    if path:
        print("找到最佳路径:")
        for point in path:
            map2d.data[point.x][point.y] = "🟩"
        map2d.export_image("result.png")

        # save path in org image
        path_plot_in_org_img(ceil_h, ceil_w, r'D:\YJJY_Proj\path_select\org_complete.jpg', path, (0, 255, 0), start_point,
                             end_point)

    else:
        print("未找到路径!")

    # 打印运行时间
    print("程序运行时间:", end_time - start_time, "秒")