step1:environment构建。对我们的原始图像划分网格,设置图中的禁行区域
step2:A*算法迭代,计算当前节点的邻节点,并计算每个节点和终点之间的曼哈顿距离
import sys
import time
from typing import List
import random
from PIL import Image, ImageDraw
import numpy as np
"""
Point类是数学坐标系的一个抽象的点,和Node类不是一回事
"""
class Point:
def __init__(self, x, y) -> None:
self.x = x
self.y = y
# 重载“==”运算符,(x1,y1)==(x2,y2),当且仅当x1=x2,y1=y2
def __eq__(self, other) -> bool:
return self.x == other.x and self.y == other.y
class Map2D:
def __init__(self, height, width, image) -> None:
self.height = height
self.width = width
self.image = Image.open(image)
self.row = Image.open(image).width / width
self.column = Image.open(image).height / height
# width可以看成二维地图的行,height可以看成二维地图的列
# self.data = [["⬜" for _ in range(width)] for _ in range(height)]
self.data = [["⬜" for _ in range(int(self.row))] for _ in range(int(self.column))]
# 将地图数据用文本导出
def show(self, file_name="output.txt") -> None:
with open(file_name, 'w', encoding='utf-8') as file:
for row in self.data:
file.write(" ".join(row) + '\n')
# 将地图数据用图片导出
def export_image(self, file_name="map.png") -> None:
cell_size = 10
image = Image.new("RGB", (int(self.row) * cell_size, int(self.column) * cell_size), "white")
draw = ImageDraw.Draw(image)
for x in range(int(self.column)):
for y in range(int(self.row)):
color = "white"
if self.data[x][y] == "⬛":
color = "black"
elif self.data[x][y] == "🟥":
color = "red"
elif self.data[x][y] == "🟩":
color = "green"
draw.rectangle([(y * cell_size, x * cell_size), ((y + 1) * cell_size, (x + 1) * cell_size)], fill=color)
image.save(file_name)
# 当地图点为⬛,则为障碍物
def set_obstacle(self, x, y):
self.data[x][y] = "⬛"
# 设置起点和终点
def set_start_end(self, start: Point, end: Point) -> None:
self.data[start.x][start.y] = "🟥"
self.data[end.x][end.y] = "🟥"
# def obstacle_generate(self, ratio: int) -> None:
# # 随机放置障碍物
# obstacle_cells = int((self.height * self.width) * ratio) # 障碍物占据40%的格子
# for _ in range(obstacle_cells):
# x = random.randint(0, map2d.height - 1)
# y = random.randint(0, map2d.width - 1)
# while (x == start_point.x and y == start_point.y) or (x == end_point.x and y == end_point.y) or \
# map2d.data[x][y] == "⬛":
# x = random.randint(0, map2d.height - 1)
# y = random.randint(0, map2d.width - 1)
# map2d.set_obstacle(x, y)
def obstacle_generate(self, ratio: int) -> None:
# 按照像素值设置障碍物
# obstacle_cells = int((self.height * self.width) * ratio) # 障碍物占据40%的格子
# for _ in range(obstacle_cells):
# x = random.randint(0, map2d.height - 1)
# y = random.randint(0, map2d.width - 1)
# while (x == start_point.x and y == start_point.y) or (x == end_point.x and y == end_point.y) or \
# map2d.data[x][y] == "⬛":
# x = random.randint(0, map2d.height - 1)
# y = random.randint(0, map2d.width - 1)
# map2d.set_obstacle(x, y)
# split_image_and_compute_mean()
for i in range(int(self.row)):
row = [] # 初始化一行的列表
for j in range(int(self.column)):
# 计算每个小图片的左上角坐标
left = i * self.width
top = j * self.height
right = min(left + self.width, self.image.width) # 处理边界情况
bottom = min(top + self.height, self.image.height)
# 裁剪图片
sub_image = self.image.crop((left, top, right, bottom))
if sub_image.mode == 'RGBA':
# 32 convert 8
sub_image = sub_image.convert('RGB')
# 将图片转换为数组
array = np.array(sub_image)
# 计算均值
mean_value = np.mean(array)
if mean_value == 0.0:
map2d.data[j][i] = "⬛"
"""
1.ud指的是up and down
2.rl指的是right and left
"""
class Node:
def __init__(self, point: Point, endpoint: Point, g: float): # 初始化中间节点的参数
self.point = point
self.endpoint = endpoint
self.father = None
self.g = g
# h取曼哈顿距离,c=|x2-x1|+|y2-y1|
self.h = (abs(endpoint.x - point.x) + abs(endpoint.y - point.y)) * 10
self.f = self.g + self.h
def get_near(self, ud, rl): # 获取相邻节点
near_point = Point(self.point.x + rl, self.point.y + ud)
near_node = Node(near_point, self.endpoint, self.g + (10 if ud == 0 or rl == 0 else 14))
return near_node
class AStar:
def __init__(self, start: Point, end: Point, map2d: Map2D): # 初始化A*算法的参数
self.path = []
self.closed_list = []
self.open_list = []
self.start = start
self.end = end
self.map2d = map2d
# 从open_list里面找到一个代价最小的节点
def select_current(self) -> Node:
min_f = sys.maxsize
node_temp = None
for node in self.open_list:
if node.f < min_f:
min_f = node.f
node_temp = node
return node_temp
def is_in_open_list(self, node: Node) -> bool: # 判断节点是否在待检测队列中
return any([open_node.point == node.point for open_node in self.open_list])
def is_in_closed_list(self, node: Node) -> bool: # 判断节点是否在已检测队列中
return any([closed_node.point == node.point for closed_node in self.closed_list])
def is_obstacle(self, node: Node) -> bool: # 判断节点是否是障碍物
# 检查节点坐标是否在地图数据的边界内
if node.point.x < 0 or node.point.x >= len(self.map2d.data):
return False
if node.point.y < 0 or node.point.y >= len(self.map2d.data[0]):
return False
# 检查节点是否是障碍物
return self.map2d.data[node.point.x][node.point.y] == "⬛"
"""
这个函数是A*算法的核心函数,找到当前节点代价最小的邻点
用list来当作是队列的数据结构,存放探测过或者未被探测的节点,以此来进行路径探索
在路径探索中节点有三种状态
状态1.加入了队列并且已经检测了,这个单独用一个Close_list队列存放
状态2.加入了队列但是还没有检测,这个用Open_list队列存放
状态3.还没有被加入队列
"""
def explore_neighbors(self, current_node: Node) -> bool:
up = (0, 1) # 上
down = (0, -1) # 下
right = (1, 0) # 右
left = (-1, 0) # 左
top_right = (1, 1) # 右上
top_left = (-1, 1) # 左上
Bottom_right = (1, -1) # 右下
Bottom_left = (-1, -1) # 左下
directions = [up, down, right, left, top_right, top_left, Bottom_right, Bottom_left]
for direction in directions:
ud, rl = direction
# current_neighbor是当前节点的邻点
current_neighbor = current_node.get_near(ud, rl)
# 如果检测到的节点是终点,就没必要接着往下探索了,直接退出循环,结束这个函数
if current_neighbor.point == self.end:
return True
# 判断一下邻点是不是已经检测或者是障碍物,如果是,就跳过这个邻点
if self.is_in_closed_list(current_neighbor) or self.is_obstacle(current_neighbor):
continue
if self.is_in_open_list(current_neighbor):
"""
作用:在open_list中找到第一个与current_neighbor相同(坐标相同)的节点
这里有两个值得注意的点
1.在open_list中,可能有多个与current_neighbor相同(坐标相同)的节点,
出现这种情况是因为同一个节点,是可以通过多条不同的路径抵达的(意思就是g值不同)
比如说节点C是当前节点,点A与节点B都能抵达节点C且g值都相同,那么节点C此时在open_list就会被添加两次
2.previous_current_neighbor是取的在open_list中与current_neighbor相同(坐标相同)的节点中
他们唯一的区别就是g值不同但因为有多个匹配,因此这里用next函数只取一次即可
"""
previous_current_neighbor = next(
open_node for open_node in self.open_list if open_node.point == current_neighbor.point)
"""
这时就要比较current_neighbor与previous_current_neighbor的代价了,
假如我在本次的路径探索到的current_neighbor要比我之前的路径探索到的previous_current_neighbor的代价要小
(这里时刻注意,current_neighbor与previous_current_neighbor是坐标相同的),那么我就要更新previous_current_neighbor的代价
"""
if current_neighbor.f < previous_current_neighbor.f:
# 更新父节点
previous_current_neighbor.father = current_node
# 更新g值
previous_current_neighbor.g = current_neighbor.g
else:
# 对应状态3,直接入队
current_neighbor.father = current_node
self.open_list.append(current_neighbor)
return False
def find_path(self):
start_node = Node(point=self.start, endpoint=self.end, g=0)
self.open_list.append(start_node)
while True:
# 从open_list里面取出一个代价值最小节点
current_node = self.select_current()
if current_node is None:
return None
# 取出来后,从open_list里面删除,添加到closed_list里面
self.open_list.remove(current_node)
self.closed_list.append(current_node)
# 当current_node是终点时,explore_neighbors函数会返回一个True
if current_node.point == self.end or self.explore_neighbors(current_node):
while current_node.father is not None:
self.path.insert(0, current_node.point)
# 这里其实就是相当于遍历一个链表
current_node = current_node.father
return self.path
def split_image_and_compute_mean(image_path, row, column):
# 打开图片
image = Image.open(image_path)
# 将图片转换为灰度图,以便计算均值
image = image.convert('L')
# 获取图片的宽度和高度
width, height = image.size
# 计算每个小图片的宽度和高度
sub_width = width // row
sub_height = height // column
# 初始化一个二维列表来存储每个小图片的均值
mean_values = []
# 拆分图片并计算均值
for i in range(n):
row = [] # 初始化一行的列表
for j in range(n):
# 计算每个小图片的左上角坐
left = j * sub_width
top = i * sub_height
right = min(left + sub_width, width) # 处理边界情况
bottom = min(top + sub_height, height)
# 裁剪图片
sub_image = image.crop((left, top, right, bottom))
# 将图片转换为数组
array = np.array(sub_image)
# 计算均值
mean_value = np.mean(array)
# 将均值添加到行列表中
row.append(mean_value)
mean_values.append(row) # 将行列表添加到二维列表中
return mean_values
def array_to_image(array):
# 确定图片的宽度和高度
width = len(array[0])
height = len(array)
# 创建一个新的白色背景图片
image = Image.new('RGB', (width, height), 'white')
draw = ImageDraw.Draw(image)
# 定义颜色映射
color_map = {'⬛': (0, 0, 0), '⬜': (255, 255, 255), '🟥': (255, 0, 0)}
# 遍历数组并绘制每个方块
for y in range(height):
for x in range(width):
color = color_map.get(array[y][x], (0, 0, 0)) # 默认黑色
draw.rectangle([x, y, x + 1, y + 1], fill=color)
# 保存图片
image.save('org_ceil_image.png')
def alternat_point_color(ceil_w, ceil_h, img_data, point, image):
# alternate end and start point color
left = point.x * ceil_w
top = point.y * ceil_h
right = min(left + ceil_w, image.width) # 处理边界情况
bottom = min(top + ceil_h, image.height)
for r in range(top, bottom):
for c in range(left, right):
img_data[r, c] = (255, 0, 0)
return img_data
def path_plot_in_org_img(ceil_h, ceil_w, img_path, target_point_obj, new_color, start_point, end_point):
image = Image.open(img_path)
split_row = image.width / ceil_w
split_colum = image.height / ceil_h
target_point_list = []
for target_point in target_point_obj:
target_point_list.append((target_point.x, target_point.y))
img_data = image.load() # 加载图片数据
for i in range(int(split_row)):
for j in range(int(split_colum)):
# 计算每个小图片的左上角坐
left = j * ceil_w
top = i * ceil_h
right = min(left + ceil_w, image.width) # 处理边界情况
bottom = min(top + ceil_h, image.height)
if (i, j) in target_point_list:
for y in range(top, bottom):
for x in range(left, right):
img_data[x, y] = new_color
# alternate point end and start point color
alternat_point_color(ceil_w, ceil_h, img_data, start_point, image)
alternat_point_color(ceil_w, ceil_h, img_data, end_point, image)
image.save('./path_in_org_img.png')
print('--------------path saved in org image!!!----------------')
if __name__ == "__main__":
# 创建地图
ceil_h = 15
ceil_w = 15
map2d = Map2D(ceil_h, ceil_w, r'D:\YJJY_Proj\path_select\road_select.png')
# map2d = split_image_and_compute_mean(r'D:\PycharmProjects\YJJY_Project\AStar_algorithm\test1.png', 17, 19)
# 设置起点和终点
map2d.obstacle_generate(0.1)
start_point = Point(136, 97)
end_point = Point(311, 221)
map2d.set_start_end(start_point, end_point)
# save org ciel image
array_to_image(map2d.data)
# 运行A*算法
start_time = time.time()
a_star = AStar(start_point, end_point, map2d)
path = a_star.find_path()
end_time = time.time()
# 打印结果
if path:
print("找到最佳路径:")
for point in path:
map2d.data[point.x][point.y] = "🟩"
map2d.export_image("result.png")
# save path in org image
path_plot_in_org_img(ceil_h, ceil_w, r'D:\YJJY_Proj\path_select\org_complete.jpg', path, (0, 255, 0), start_point,
end_point)
else:
print("未找到路径!")
# 打印运行时间
print("程序运行时间:", end_time - start_time, "秒")