探索寻路算法

87 阅读22分钟

我关于寻路算法的探索路径是:

  • 朴素的A star算法
  • 基于A star的第一次优化 —— 使用优先队列
  • 基于A star的第二次优化 —— 方向选择抗锯齿
  • 基于A star的第三次优化(伪) —— 增大行动步长
  • 基于A star的第四次优化 —— JPS

以下报告分为两部分,首先是关于以上探索流程的代码分析和结果呈现,其次是过程中解决问题的记录

过程报告

  1. V1 基础版本

    • 代码:

      -- NOTE(review): this global handle is not read anywhere in this listing
      -- (useExitedMap opens its own file); kept only for compatibility.
      -- Long-bracket string so the Windows backslashes are not parsed as escape
      -- sequences ("\U" is an invalid escape and a load error in Lua 5.2+).
      file = io.open([[C:\Users\songlin\Desktop\Lua_expolrer\myLua\lua\map.bytes]])
      local math = require("math")
      local os = require("os")

      --- Grid A* planner class table (methods are attached below).
      AStarPlanner = {}
      AStarPlanner.__index = AStarPlanner

      --- Create a planner and load the obstacle map (see useExitedMap).
      -- @param resolution grid resolution: world units per grid cell
      -- @param rr robot radius; stored but not used by the search in this listing
      -- @return a new planner instance
      function AStarPlanner:new(resolution, rr)
          -- Build a fresh instance. The original discarded setmetatable's result and
          -- wrote every field onto `self` (the AStarPlanner class table when called
          -- as AStarPlanner:new), so every "instance" was the same shared table.
          local obj = setmetatable({}, AStarPlanner)
          obj.resolution = resolution
          obj.rr = rr
          obj.min_x, obj.min_y = 0, 0
          obj.max_x, obj.max_y = 0, 0
          obj.obstacle_map = nil
          obj.x_width, obj.y_width = 0, 0
          obj.motion = obj:get_motion_model()
          obj:useExitedMap()
          return obj
      end

      --- Build a search node.
      -- @param x,y grid indices
      -- @param cost accumulated move cost from the start (g)
      -- @param parent_index grid index of the predecessor, -1 for the start node
      function AStarPlanner:Node(x, y, cost, parent_index)
          return {x = x, y = y, cost = cost, parent_index = parent_index}
      end

      --- Naive A*: the open set is a hash table scanned linearly for the minimum f.
      -- Prints "find goal!" and the trace on success, then the elapsed CPU time.
      -- @param sx,sy start position (world units)
      -- @param gx,gy goal position (world units)
      -- @return rx, ry world coordinates from goal back to start, or nothing when
      --         the open set runs empty
      function AStarPlanner:planning(sx,sy,gx,gy)
          local start_node = self:Node(self:calc_xy_index(sx, self.min_x),
                                       self:calc_xy_index(sy, self.min_y), 0.0, -1)
          local goal_node = self:Node(self:calc_xy_index(gx, self.min_x),
                                      self:calc_xy_index(gy, self.min_y), 0.0, -1)
          local open_set, closed_set = {}, {}
          open_set[self:calc_grid_index(start_node)] = start_node
          local start_time = os.clock()
          local rx, ry

          while true do
              if next(open_set) == nil then
                  print("Open set is empty..")
                  break
              end

              -- Lua tables have no built-in "min by key", so scan for the smallest
              -- f = g + h. This O(n) scan per iteration is what V2 replaces.
              local c_id, current
              local min_cost = math.huge
              for id, node in pairs(open_set) do
                  local cost = node.cost + self:calc_heuristic(goal_node, node)
                  if cost < min_cost then
                      min_cost = cost
                      c_id, current = id, node
                  end
              end

              if current.x == goal_node.x and current.y == goal_node.y then
                  print("find goal!")
                  goal_node.parent_index = current.parent_index
                  goal_node.cost = current.cost
                  self:printTrace(current, closed_set)
                  rx, ry = self:calc_final_path(goal_node, closed_set)
                  break
              end

              open_set[c_id] = nil
              closed_set[c_id] = current

              for _, move in ipairs(self.motion) do
                  local node = self:Node(current.x + move[1], current.y + move[2],
                                         current.cost + move[3], c_id)
                  local n_id = self:calc_grid_index(node)
                  -- Skip blocked / off-map cells and cells already expanded;
                  -- otherwise keep whichever path to this cell is cheaper.
                  if self:verify_node(node) and closed_set[n_id] == nil then
                      local known = open_set[n_id]
                      if known == nil or known.cost > node.cost then
                          open_set[n_id] = node
                      end
                  end
              end
          end
          local end_time = os.clock()
          print("time cost = ", end_time - start_time)
          return rx, ry
      end

      --- Walk parent links from the goal back to the start.
      -- @param goal_node node whose parent_index points into closed_set
      -- @param closed_set grid index -> node map of expanded nodes
      -- @return rx, ry lists of world coordinates, goal first
      function AStarPlanner.calc_final_path(self, goal_node, closed_set)
          -- Method-call syntax: the original used `self.calc_grid_position(...)`,
          -- which dropped `self` and shifted every argument by one.
          local rx = {self:calc_grid_position(goal_node.x, self.min_x)}
          local ry = {self:calc_grid_position(goal_node.y, self.min_y)}
          local parent_index = goal_node.parent_index
          while parent_index ~= -1 do
              local n = closed_set[parent_index]
              rx[#rx + 1] = self:calc_grid_position(n.x, self.min_x)
              ry[#ry + 1] = self:calc_grid_position(n.y, self.min_y)
              parent_index = n.parent_index
          end
          return rx, ry
      end

      --- Euclidean distance between two nodes in grid cells (weight w = 1.0).
      function AStarPlanner:calc_heuristic(n1, n2)
          local w = 1.0
          return w * math.sqrt((n1.x - n2.x)^2 + (n1.y - n2.y)^2)
      end

      --- Convert a grid index to a world coordinate along one axis.
      function AStarPlanner.calc_grid_position(self, index, min_position)
          return index * self.resolution + min_position
      end

      --- Convert a world coordinate to the nearest grid index along one axis.
      function AStarPlanner.calc_xy_index(self, position, min_pos)
          return math.floor((position - min_pos) / self.resolution + 0.5)
      end

      --- Flatten a node's (x, y) grid indices into one key: y * x_width + x.
      function AStarPlanner.calc_grid_index(self, node)
          return (node.y - self.min_y) * self.x_width + (node.x - self.min_x)
      end

      --- True when the node's cell exists in the loaded map and is free.
      function AStarPlanner.verify_node(self, node)
          local column = self.obstacle_map[node.x]
          -- Neighbours of edge cells fall outside the map: treat them as blocked
          -- instead of raising "attempt to index a nil value".
          if column == nil then
              return false
          end
          return column[node.y] == true
      end

      --- Eight-connected motion model.
      -- @return list of {dx, dy, cost}: straight steps cost 1, diagonal steps sqrt(2)
      function AStarPlanner:get_motion_model()
          -- Local table: the original leaked a global `motion`.
          local diagonal = math.sqrt(2)
          local motion = {
              {1, 0, 1},
              {0, 1, 1},
              {-1, 0, 1},
              {0, -1, 1},
              {-1, -1, diagonal},
              {-1, 1, diagonal},
              {1, -1, diagonal},
              {1, 1, diagonal},
          }
          return motion
      end

      --- Print the path from `current` back to the start, one node per line.
      -- @param current node reached by planning (the goal)
      -- @param closed_set grid index -> node map of expanded nodes
      -- @return list of nodes ordered start -> current (the original built this
      --         list but discarded it)
      function AStarPlanner:printTrace(current, closed_set)
          local path = {current}
          while true do
              print(current.x, current.y, current.cost, current.parent_index)
              local fatherIndex = current.parent_index
              if fatherIndex == -1 then
                  break
              end
              local fatherNode = closed_set[fatherIndex]
              table.insert(path, 1, fatherNode)
              current = fatherNode
          end
          return path
      end

      --- Load the obstacle map from the binary map file.
      -- Layout as read here: a 4-byte header (height, then width, each as
      -- byte_hi * 256 + byte_lo), then cell bytes where 0 means free.
      -- Fills self.obstacle_map[i][j] with true (free) or false (blocked).
      function AStarPlanner:useExitedMap()
          -- Long-bracket string: in a quoted string "\i" is an invalid escape and
          -- "\f" would become a form-feed character.
          local file = io.open([[D:\ideaProject\fileTest\map.bytes]], "rb")
          if not file then
              error("Could not open map file")
          end
          local header = file:read(4)
          if header == nil or #header < 4 then
              file:close()
              error("Map file header is truncated")
          end
          local high = header:byte(1) * 256 + header:byte(2)
          local width = header:byte(3) * 256 + header:byte(4)

          self.min_x = 0
          self.max_x = width
          self.min_y = 0
          self.max_y = high
          self.x_width = width
          self.y_width = high

          self.obstacle_map = {}
          -- NOTE(review): each of the `width` reads takes `width` bytes but only the
          -- first `high` are inspected; this is only consistent for square maps.
          -- Confirm the file layout before relying on non-square input.
          for i = 1, width do
              local column = {}
              self.obstacle_map[i] = column
              local temp = file:read(width)
              if temp == nil then
                  file:close()
                  error("Map file is truncated")
              end
              for j = 1, high do
                  column[j] = temp:byte(j) == 0
              end
          end
          file:close()
      end

      --- Demo: plan from (400, 300) to (600, 600) on the loaded map.
      function main()
          local sx, sy, gx, gy = 400.0, 300.0, 600.0, 600.0
          local grid_size, robot_radius = 1.0, 1.0

          local a_star = AStarPlanner:new(grid_size, robot_radius)
          a_star:planning(sx, sy, gx, gy)
      end

      main()
      
    • A Star 算法简述

      相比于BFS,A Star算法通过启发函数估计当前点到目标点的剩余代价,引导搜索优先朝目标方向推进,以此减少需要遍历的节点数量。

      具体过程为:对于当前出发点周围的可达点,计算其f值(f = g + h:g是从起点到该点的实际代价,h是启发函数对该点到终点剩余代价的估计,f即经过该点从起点到终点的预估总成本),将其加入待探索集合(open_set);再从open_set中选择f值最小的点作为下一步的出发点,并将其移入已探索集合(closed_set),防止重复探索和绕圈;如此迭代,直到取出的点就是终点。

      方法精妙的点在于:当发现某个点已经在open_set中,但经由当前点新计算出的代价更低时(同一个点的h值不变,所以等价于g值更低),本质上意味着新的这条路比之前的路更短,此时用新的节点(新的parent和代价)替换它,即完成了路径的选择。

      最后从goal_node出发,沿着每个节点的parent反向回溯,即可得到代价最低的整条路径。

    • 实验分析:

      • 目的:在下图中,从(400,300)到(600,600)找到避开障碍物的路径(起点和终点都非障碍物)

        image-20240709114538943.png

      • 结果:

        time cost = 10.153

        image-20240709150341103.png

        图片由python绘制,红线表示最终路径,蓝色区域为探测范围;绘图由另一个程序完成,10秒的计时不包含绘图IO。

  2. V2 引入优先队列

    • 部分代码

      --- A* with a binary-heap priority queue alongside the open-set hash table.
      -- open_set (grid index -> node) stays the authority on the live node of every
      -- open cell. When a cheaper path to an open cell is found, the new entry is
      -- simply pushed, and the superseded heap entry is skipped when it is popped
      -- ("lazy deletion"). This replaces the O(n) PriorityQueue:heapRemove scan.
      -- @param sx,sy start position (world units)
      -- @param gx,gy goal position (world units)
      function AStarPlanner:planning(sx,sy,gx,gy)
          local start_node = self:Node(self:calc_xy_index(sx, self.min_x),
                                       self:calc_xy_index(sy, self.min_y), 0.0, -1)
          local goal_node = self:Node(self:calc_xy_index(gx, self.min_x),
                                      self:calc_xy_index(gy, self.min_y), 0.0, -1)
          local open_set, closed_set = {}, {}
          local pq = PriorityQueue:new(goal_node)
          local start_id = self:calc_grid_index(start_node)
          open_set[start_id] = start_node
          pq:heappush(AStarPlanner:Outernode(start_id, start_node))
          local start_time = os.clock()

          while true do
              if next(open_set) == nil then
                  print("Open set is empty..")
                  break
              end

              -- Heap entries pair a grid index (c_id) with its node (Node); the heap
              -- replaces V1's linear scan for the smallest f.
              local outerCurr = pq:heappop()
              if outerCurr == nil then
                  -- Defensive: every open cell has a heap entry, so this should not happen.
                  print("Open set is empty..")
                  break
              end
              local c_id, current = outerCurr.c_id, outerCurr.Node

              -- Only the live entry for a cell is processed; anything else is stale.
              if open_set[c_id] == current then
                  if current.x == goal_node.x and current.y == goal_node.y then
                      print("find goal!")
                      goal_node.parent_index = current.parent_index
                      goal_node.cost = current.cost
                      self:printTrace(current, closed_set)
                      break
                  end

                  open_set[c_id] = nil
                  closed_set[c_id] = current

                  for _, move in ipairs(self.motion) do
                      local node = self:Node(current.x + move[1], current.y + move[2],
                                             current.cost + move[3], c_id)
                      local n_id = self:calc_grid_index(node)
                      if self:verify_node(node) and closed_set[n_id] == nil then
                          local known = open_set[n_id]
                          if known == nil or known.cost > node.cost then
                              open_set[n_id] = node
                              pq:heappush(AStarPlanner:Outernode(n_id, node))
                          end
                      end
                  end
              end
          end
          local end_time = os.clock()
          print("time cost = ", end_time - start_time)
      end
      
    • 简述:

      第二版中引入优先队列优化open_set。由于open_set在程序中需要同时支持根据key快速查找和迅速取出最小值,所以我选择用空间换取时间——保留原始k-v的table,并平行维护一个优先队列;

      由于Lua没有原生的键值对(entry)类型(类似Java中的Map.Entry),于是手动封装Outernode类型,作为优先队列的元素。

      以下是优先队列的实现代码

      -- Standard libraries used by the priority queue
      local math = require('math')
      local table = require('table')

      --- Binary min-heap of open-set entries, ordered by Node.cost + calcHeuristic(goal, Node).
      local PriorityQueue = {}

      --- Create an empty priority queue.
      -- @param Node goal node; entries are ordered by entry.Node.cost plus the
      --        heuristic distance from this goal
      -- @return a new queue instance
      function PriorityQueue:new(Node)
          self.__index = self
          local instance = setmetatable({}, self)
          instance.queue = {}      -- 1-based array holding the heap
          instance.goal_node = Node
          return instance
      end
      --- Sift the entry at `index` down until the min-heap property holds.
      -- @param index 1-based slot to start from
      -- @param endIndex last 1-based slot that belongs to the heap
      function PriorityQueue:heapAdjust(index, endIndex)
          local heap, goal = self.queue, self.goal_node
          local function priority(slot)
              local node = heap[slot].Node
              return node.cost + self:calcHeuristic(goal, node)
          end

          while index * 2 <= endIndex do
              local left = index * 2
              local right = left + 1
              local smallest = index
              if priority(left) < priority(smallest) then
                  smallest = left
              end
              if right <= endIndex and priority(right) < priority(smallest) then
                  smallest = right
              end
              if smallest == index then
                  break
              end
              heap[index], heap[smallest] = heap[smallest], heap[index]
              index = smallest
          end
      end
      --- Re-establish the heap property over the whole queue array (bottom-up).
      function PriorityQueue:heapify()
          local size = #self.queue
          -- The heap is 1-based (children of i are 2i and 2i+1), so the last internal
          -- node is size // 2. The original used 0-based bounds and looped down to
          -- index -1, touching queue[0] (nil) and raising an error.
          for i = size // 2, 1, -1 do
              self:heapAdjust(i, size)
          end
      end
      --- Insert an entry and sift it up to its heap position.
      -- @param value entry table whose `.Node` carries cost, x and y
      function PriorityQueue:heappush(value)
          local heap = self.queue
          local pos = #heap + 1
          heap[pos] = value
          local goal = self.goal_node
          local node = value.Node

          -- Move parents down while they are not strictly smaller than the new entry.
          while pos > 1 do
              local parent = pos // 2
              local parent_node = heap[parent].Node
              local parent_f = parent_node.cost + self:calcHeuristic(goal, parent_node)
              if parent_f < node.cost + self:calcHeuristic(goal, node) then
                  break
              end
              heap[pos] = heap[parent]
              pos = parent
          end
          heap[pos] = value
      end
      --- Remove and return the entry with the smallest f, or nil when empty.
      function PriorityQueue:heappop()
          local heap = self.queue
          local size = #heap
          if size == 0 then
              return nil
          end
          local top = heap[1]
          heap[1] = heap[size]
          heap[size] = nil
          if size > 1 then
              self:heapAdjust(1, size - 1)
          end
          return top
      end

      --- Remove the entry whose `.Node` is `value` (linear search) and repair the heap.
      -- @param value the node table stored in an entry's `Node` field
      -- @return the removed entry, or false if no entry holds `value`
      function PriorityQueue:heapRemove(value)
          local heap = self.queue
          local size = #heap
          if size == 0 then
              return false
          end
          local index
          for i = 1, size do
              if heap[i].Node == value then
                  index = i
                  break
              end
          end
          if index == nil then
              return false  -- not in the heap
          end

          local removed = heap[index]
          heap[index] = heap[size]
          heap[size] = nil
          local new_size = size - 1
          if index > new_size then
              return removed  -- the last slot itself was removed; nothing to repair
          end

          -- The tail entry now at `index` came from another subtree, so it may be
          -- smaller than its new parent or larger than its children. The original
          -- only sifted down, and over 1..size-2, which excluded the real last slot.
          local moved = heap[index]
          local goal = self.goal_node
          local moved_f = moved.Node.cost + self:calcHeuristic(goal, moved.Node)
          local pos = index
          while pos > 1 do
              local parent = pos // 2
              local parent_node = heap[parent].Node
              if parent_node.cost + self:calcHeuristic(goal, parent_node) <= moved_f then
                  break
              end
              heap[pos] = heap[parent]
              pos = parent
          end
          heap[pos] = moved
          if pos == index then
              self:heapAdjust(index, new_size)
          end
          return removed
      end
      --- Euclidean distance between two nodes (x, y fields), scaled by weight w.
      function PriorityQueue:calcHeuristic(node1, node2)
          local w = 1.0 -- weight
          local dx = node1.x - node2.x
          local dy = node1.y - node2.y
          return w * math.sqrt(dx * dx + dy * dy)
      end

      return PriorityQueue
      
    • 实验分析:

      1. 结果:

        time cost = 1.713

        可见在open_set中线性查找f值最小的节点是V1中最耗时的部分

        image-20240709142402973.png

  3. V3 抗锯齿

    • 部分代码(n = 3)

      --- Pop an entry, preferring one that continues the previous direction.
      -- Inspects heap slots 2..n (default n = 3: the root's two children). Among
      -- those slots, picks the first whose node is exactly one `last_direction` step
      -- from `current_node`; if none matches, the root is popped as in heappop.
      -- NOTE(review): slots 2..n are near the top of the heap but are not guaranteed
      -- to be the n smallest entries.
      -- @param current_node node with x, y fields
      -- @param last_direction table with x, y fields (the last step taken)
      -- @param n optional number of top slots to inspect (default 3)
      -- @return the removed entry
      function PriorityQueue:heapPopSmooth(current_node, last_direction, n)
          n = n or 3
          local heap = self.queue
          local size = #heap
          if size < 3 then
              return self:heappop()
          end

          local target = 1
          for i = 2, math.min(n, size) do
              local node = heap[i].Node
              if node.x - current_node.x == last_direction.x and
                 node.y - current_node.y == last_direction.y then
                  target = i
                  break
              end
          end

          local top = heap[target]
          heap[target] = heap[size]
          heap[size] = nil
          size = size - 1
          if target <= size then
              -- Repair from the vacated slot. The original re-sifted from the root,
              -- which left the entry moved into slot `target` unchecked against its
              -- own children.
              local moved = heap[target]
              local goal = self.goal_node
              local moved_f = moved.Node.cost + self:calcHeuristic(goal, moved.Node)
              local pos = target
              while pos > 1 do
                  local parent = pos // 2
                  local parent_node = heap[parent].Node
                  if parent_node.cost + self:calcHeuristic(goal, parent_node) <= moved_f then
                      break
                  end
                  heap[pos] = heap[parent]
                  pos = parent
              end
              heap[pos] = moved
              if pos == target then
                  self:heapAdjust(target, size)
              end
          end
          return top
      end
      
    • 简述

      原先基于优先队列选择f值最小的方案没有考虑方向,如果想要方向上尽量和之前的一致,下面有两种思路:

      1. 优先队列中 并不优先选f值最小的项,而是在几个f值够小的项中,选择一个方向一致的
      2. 增大步长( V4 中说明)

      在第一种思路下,考虑优先队列前n个元素,经过测试,n越高越好,并且n = 3 和 n = 15耗时大致相同,但效果有点差距

      n == 3

      image-20240709151027309.png

      n == 15

      image-20240709151941592.png

  4. V4 增大步长

    • 部分代码

      -- Moves of `step` cells. The cost is the Euclidean length of the move, so g
      -- stays in the same units as the heuristic. The original charged 1 / sqrt(2)
      -- for a 2-cell move, which halved g relative to h.
      local step = 2
      local diagonal = step * math.sqrt(2)
      motion = {{step, 0, step},
          {0, step, step},
          {-step, 0, step},
          {0, -step, step},
          {-step, -step, diagonal},
          {-step, step, diagonal},
          {step, -step, diagonal},
          {step, step, diagonal}}
      
    • 简述:

      增加步长有逆天的优化效果,但这种效果其实是假象

      image-20240709152352703.png

      甚至当步长为4个单位时,几乎要沿直线前进了

      image-20240709152231876.png

      但其实很明显可以看到,随着步长越来越长,原本是障碍物的点被直接跨了过去,如履平地;所以绕的路少了,加入open_set的点少了,时间自然缩短很多,但这种做法并不适合这张地图

      我们的地图里,小障碍物很多,于是增大步长并不适用

      image-20240709153202227.png

      但假如某些地图中,障碍物全是大山大河,使用增加步长想必是最快的优化

  5. V5 JPS

    • 完整代码

      local sqrt = math.sqrt

      --- Search-node class: grid position, g (cost from start), h (heuristic), f = g + h, parent.
      Node = {}
      Node.__index = Node

      --- Construct a search node.
      -- @param grid_pos {x, y} grid coordinates
      -- @param g cost from the start
      -- @param h heuristic estimate to the goal
      -- @param parent predecessor node, or nil for the start node
      -- @return node with f = g + h
      function Node:new(grid_pos, g, h, parent)
          return setmetatable({
              grid_pos = grid_pos,
              g = g,
              h = h,
              parent = parent,
              f = g + h,
          }, Node)
      end
      --- Jump Point Search planner class.
      JPS = {}
      JPS.__index = JPS

      --- Sign of v as an integer (-1, 0 or 1). Used to turn a position delta into a
      -- unit step without producing float coordinates.
      local function sign(v)
          if v > 0 then
              return 1
          elseif v < 0 then
              return -1
          end
          return 0
      end

      --- Create a JPS planner for a `width` x `height` grid (0-based coordinates).
      -- Passability comes from the global `obstacle_map` (see JPS:isPass).
      function JPS:new(width, height)
          local obj = setmetatable({}, JPS)
          obj.start_grid_pos = {0, 0}
          obj.goal_grid_pos = {0, 0}
          obj.width = width
          obj.height = height
          obj.open = {}   -- positionSerialize(pos) -> Node waiting to be expanded
          obj.close = {}  -- positionSerialize(pos) -> Node already expanded
          obj.motion_directions = {{1, 0}, {0, 1}, {0, -1}, {-1, 0}, {1, 1}, {1, -1}, {-1, 1}, {-1, -1}}
          return obj
      end

      --- Run the search from start to goal.
      -- @param start_grid_pos {x, y}
      -- @param goal_grid_pos {x, y}
      -- @return {path_x, path_y} listed from goal back to start, or nil if no path
      function JPS:run(start_grid_pos, goal_grid_pos)
          self.start_grid_pos = start_grid_pos
          self.goal_grid_pos = goal_grid_pos
          local start_node = Node:new(start_grid_pos, 0, self:getH(start_grid_pos, goal_grid_pos))
          self.open[self:positionSerialize(start_node.grid_pos)] = start_node
          -- os.clock (CPU seconds, sub-second resolution), as in the A* versions;
          -- os.time only has 1-second resolution.
          local start_time = os.clock()

          while true do
              -- Search failed: nothing left to expand
              if next(self.open) == nil then
                  print("未找到路径")
                  break
              end

              -- Linear scan for the smallest f (the open set is a hash table).
              local current_node = nil
              local min_f = math.huge
              for _, node in pairs(self.open) do
                  if node.f < min_f then
                      min_f = node.f
                      current_node = node
                  end
              end

              -- Goal reached: build and return the path
              if current_node.grid_pos[1] == self.goal_grid_pos[1] and current_node.grid_pos[2] == self.goal_grid_pos[2] then
                  local path = self:findPath(current_node)
                  print("time cost = ", os.clock() - start_time)
                  return path
              end

              -- Key the closed set by serialized position. The original used the
              -- grid_pos table itself as the key, and extendNode looks jump points up
              -- with freshly built tables, which never compare equal to it. As a
              -- result, expanded nodes were re-added and expanded again.
              local key = self:positionSerialize(current_node.grid_pos)
              self.open[key] = nil
              self.close[key] = current_node
              self:extendNode(current_node)
          end
      end

      --- Debug helper: print each key and the first two entries of each value.
      function PRINTTABLE(tbl)
          for k, v in pairs(tbl) do
              print("k", k)
              print("v", v[1], v[2])
          end
      end

      --- Debug helper: number of entries in a (hash) table.
      -- The original started counting at 1 and returned count + 1.
      function ACCUMULATELENGTH(tbl)
          local count = 0
          for _ in pairs(tbl) do
              count = count + 1
          end
          return count
      end

      --- Map an {x, y} grid position to a unique number usable as a table key.
      -- Map dimensions are read from a two-byte header, so 0 <= y < 65536 and
      -- x * 65536 + y cannot collide. The original multiplier 1000 collides once
      -- y >= 1000, e.g. {0, 1000} and {1, 0}.
      function JPS:positionSerialize(jp)
          return jp[1] * 65536 + jp[2]
      end

      --- Expand a node: find its jump points and add or update them in the open set.
      function JPS:extendNode(current_node)
          local here = current_node.grid_pos
          for _, neighbour in ipairs(self:getPruneNeighbours(current_node)) do
              local jp = self:getJumpNode(neighbour, here)
              if jp then
                  local key = self:positionSerialize(jp)
                  if not self.close[key] then
                      local g = current_node.g + self:getG(jp, here)
                      local known = self.open[key]
                      if known == nil then
                          self.open[key] = Node:new(jp, g, self:getH(jp, self.goal_grid_pos), current_node)
                      elseif g + known.h < known.f then
                          -- Update g together with parent and f. The original left g
                          -- stale, and later successors computed their g from it.
                          known.parent = current_node
                          known.g = g
                          known.f = g + known.h
                      end
                  end
              end
          end
      end

      --- Follow the ray from `pre` through `now` until a jump point is found.
      -- @param now {x, y} cell being examined
      -- @param pre {x, y} adjacent cell the ray came from
      -- @return the jump point position, or nil if the ray hits an obstacle or edge
      function JPS:getJumpNode(now, pre)
          -- Blocked cells are never jump points. The diagonal branch below probes the
          -- two straight rays without checking them first, so without this guard a
          -- probe could start inside an obstacle.
          if not self:isPass(now[1], now[2]) then
              return nil
          end
          local dx = sign(now[1] - pre[1])
          local dy = sign(now[2] - pre[2])

          if now[1] == self.goal_grid_pos[1] and now[2] == self.goal_grid_pos[2] then
              return now
          end

          if self:hasForceNeighbours(now, pre) then
              return now
          end

          -- Diagonal move: `now` is a jump point if either straight probe finds one.
          if dx ~= 0 and dy ~= 0 then
              if self:getJumpNode({now[1] + dx, now[2]}, now) or self:getJumpNode({now[1], now[2] + dy}, now) then
                  return now
              end
          end

          -- Keep going in the same direction. The guard above rejects blocked cells,
          -- and as a tail call this does not grow the stack on long rays.
          return self:getJumpNode({now[1] + dx, now[2] + dy}, now)
      end

      --- True if `now`, entered from the adjacent cell `pre`, has a forced neighbour.
      function JPS:hasForceNeighbours(now, pre)
          local x, y = now[1], now[2]
          local dx, dy = x - pre[1], y - pre[2]
          local steps = math.abs(dx) + math.abs(dy)

          if steps == 1 then
              if math.abs(dx) == 1 then
                  return (self:isPass(x + dx, y + 1) and not self:isPass(x, y + 1)) or
                         (self:isPass(x + dx, y - 1) and not self:isPass(x, y - 1))
              elseif math.abs(dy) == 1 then
                  return (self:isPass(x + 1, y + dy) and not self:isPass(x + 1, y)) or
                         (self:isPass(x - 1, y + dy) and not self:isPass(x - 1, y))
              else
                  error("错误,直线移动中只能水平或垂直移动!")
              end
          elseif steps == 2 then
              return (self:isPass(x + dx, y - dy) and not self:isPass(x, y - dy)) or
                     (self:isPass(x - dx, y + dy) and not self:isPass(x - dx, y))
          else
              error("错误,只能直线移动或斜线移动!")
          end
      end

      --- Heuristic distance from `current` to `goal`.
      -- @param func "Manhattan", "Euclidean" (default) or "Chebyshev"
      function JPS:getH(current, goal, func)
          func = func or "Euclidean"
          local current_x, current_y = current[1], current[2]
          local goal_x, goal_y = goal[1], goal[2]

          if func == "Manhattan" then
              return math.abs(current_x - goal_x) + math.abs(current_y - goal_y)
          elseif func == "Euclidean" then
              return sqrt((current_x - goal_x)^2 + (current_y - goal_y)^2)
          elseif func == "Chebyshev" then
              return math.max(math.abs(current_x - goal_x), math.abs(current_y - goal_y))
          else
              error("错误,不支持该启发函数。目前支持:Manhattan、Euclidean(默认)、Chebyshev。")
          end
      end

      --- Euclidean distance between two grid positions (the cost of a jump).
      function JPS:getG(pos1, pos2)
          return sqrt((pos1[1] - pos2[1])^2 + (pos1[2] - pos2[2])^2)
      end

      --- Pruned neighbour set: natural plus forced neighbours for the direction of
      -- arrival, or all passable neighbours for the start node.
      function JPS:getPruneNeighbours(current_node)
          local neighbours = {}
          local x, y = current_node.grid_pos[1], current_node.grid_pos[2]
          local function add_if_passable(nx, ny)
              if self:isPass(nx, ny) then
                  table.insert(neighbours, {nx, ny})
              end
          end

          local parent = current_node.parent
          if parent == nil then
              for _, dir in ipairs(self.motion_directions) do
                  add_if_passable(x + dir[1], y + dir[2])
              end
              return neighbours
          end

          -- Direction of arrival as unit steps (the parent may be several cells away).
          local mx = sign(x - parent.grid_pos[1])
          local my = sign(y - parent.grid_pos[2])
          local steps = math.abs(mx) + math.abs(my)

          if steps == 1 then
              -- natural neighbour
              add_if_passable(x + mx, y + my)
              -- forced neighbours: a blocked side cell exposes the cell diagonally ahead
              if mx == 0 then
                  if not self:isPass(x + 1, y) then add_if_passable(x + 1, y + my) end
                  if not self:isPass(x - 1, y) then add_if_passable(x - 1, y + my) end
              else
                  if not self:isPass(x, y + 1) then add_if_passable(x + mx, y + 1) end
                  if not self:isPass(x, y - 1) then add_if_passable(x + mx, y - 1) end
              end
          elseif steps == 2 then
              -- natural neighbours
              add_if_passable(x + mx, y + my)
              add_if_passable(x + mx, y)
              add_if_passable(x, y + my)
              -- forced neighbours
              if not self:isPass(x - mx, y) then add_if_passable(x - mx, y + my) end
              if not self:isPass(x, y - my) then add_if_passable(x + mx, y - my) end
          else
              error("错误,只能对角线和直线行走!")
          end

          return neighbours
      end

      --- True when (grid_x, grid_y) is inside the grid and either free or the goal.
      -- Reads the global `obstacle_map` (1-based; a value of 1 marks an obstacle).
      function JPS:isPass(grid_x, grid_y)
          if grid_x < 0 or grid_x >= self.width or grid_y < 0 or grid_y >= self.height then
              return false
          end
          -- The goal is always enterable. Compare coordinates: the original compared a
          -- freshly built table with `==` (always false) and also offset it by +1.
          if grid_x == self.goal_grid_pos[1] and grid_y == self.goal_grid_pos[2] then
              return true
          end
          return obstacle_map[grid_x + 1][grid_y + 1] ~= 1
      end

      --- Collect grid coordinates by following parent links from `node`.
      -- @return {path_x, path_y}, both ordered from `node` back to the start
      function JPS:findPath(node)
          local path_x, path_y = {}, {}
          while node do
              path_x[#path_x + 1] = node.grid_pos[1]
              path_y[#path_y + 1] = node.grid_pos[2]
              node = node.parent
          end
          return {path_x, path_y}
      end
      -- Main program: load the map into the global obstacle_map, then run JPS.
      -- Long-bracket string: inside a quoted string "\i" is an invalid escape and
      -- "\f" would become a form-feed character.
      local file = assert(io.open([[D:\ideaProject\fileTest\map.bytes]], 'rb'))
      local header = file:read(4)
      assert(header and #header == 4, "map file header is truncated")
      local high = header:byte(1) * 256 + header:byte(2)
      local width = header:byte(3) * 256 + header:byte(4)

      print("地图宽度", width)
      print("地图高度", high)

      -- Global on purpose: JPS:isPass reads `obstacle_map`.
      -- obstacle_map[i][j] is 0 for a free cell (byte 0) and 1 otherwise.
      obstacle_map = {}

      -- NOTE(review): each of the `width` reads takes `width` bytes but only the
      -- first `high` are inspected, which is only consistent for square maps.
      -- Confirm the file layout before relying on non-square input.
      for i = 1, width do
          obstacle_map[i] = {}
          local temp = assert(file:read(width), "map file is truncated")
          for j = 1, high do
              if temp:byte(j) == 0 then
                  obstacle_map[i][j] = 0
              else
                  obstacle_map[i][j] = 1
              end
          end
      end
      file:close()

      local start = {400, 300}
      local goal = {600, 600}
      local jps = JPS:new(width, high)
      local path = jps:run(start, goal)
      
    • 简述

      JPS可以说是A* + 剪枝,就是将一些显然的中间节点跳过,不添加到open_set中

      感性的理解JPS:在起点到终点之间,不能直线通过是因为要避开障碍物,而从起点到某个障碍物可以直线到达。jps通过在障碍物附近寻找forced Neighbor,继而确定探索过程能否直接跳到跳点。

    • 结果堪称银弹

      image-20240709160852425.png

      time cost = 14.0

      但是很奇怪的一点在于,python中运行耗时两秒,其中包括部分绘图IO,但翻译成lua耗时14秒,具体原因再探明

问题记录

  1. 相同内容的table并不是同一块内存,而用table作为另一个table(hashmap)的key时,key是否相同是按对象引用(内存地址)判断的,而不是按内容判断

    在open_set中,原本尝试key直接用点的坐标 { x = 500, y = 1000 } 形式,但是出现程序无法终止的情况

    具体原因如下

    • 程序无法终止,是因为open_set中选择的下一个点在路径上徘徊
    • 徘徊的原因是:出现了比目标点f值更小点,且这个点在宏观路径上是往回走的
    • 理论上f值由g和h组成,h值越接近终点越小,g一直差不多,所以f应该越走越小;出现前面的某个点f值比后面点f值小的原因是:前面的没删掉
    • 通过观测每轮中,open_set的大小,可以推测出确实是因为删除失败
    • 而删除失败的原因是:通过新构造的 { x = 500, y = 1000 } ,无法对应到open_set中原本的key。

    解决方案是:将坐标的table序列化成数字

    --- Serialize an {x, y} grid position into one number usable as a table key.
    -- The multiplier must exceed every possible y. Map dimensions are stored in two
    -- bytes of the header, so 65536 is always large enough; 10000 would collide once
    -- y >= 10000.
    function JPS:positionSerialize(jp)
        local x = jp[1]
        local y = jp[2]

        local serialization = x * 65536 + y
        return serialization
    end
    

附Python代码

  1. A*算法
"""

A* grid planning

author: Atsushi Sakai(@Atsushi_twi)
        Nikos Kanargias (nkana@tee.gr)

See Wikipedia article (https://en.wikipedia.org/wiki/A*_search_algorithm)

"""

import math
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np

from org.example.findWays.visionControl.v3_smooth import PriorityQueue

# Global switch: when True, planning() plots every expanded node and main() plots the start, goal and final path.
show_animation = True


class Node:
    """Grid search node: integer cell (x, y), accumulated cost and the parent's grid index."""

    def __init__(self, x, y, cost, parent_index):
        self.x, self.y = x, y  # grid indices, not metric positions
        self.cost = cost
        self.parent_index = parent_index  # -1 marks the start node

    def __str__(self):
        fields = (self.x, self.y, self.cost, self.parent_index)
        return ",".join(str(value) for value in fields)


class OuterNode:
    """Heap entry that pairs a node with its grid index, so a popped entry maps back to open_set."""

    def __init__(self, c_id, node):
        self.c_id, self.node = c_id, node


class AStarPlanner:
    """Grid A* planner over a pre-built passability map.

    ``self.obstacle_map[i][j]`` is truthy when the cell is traversable (see verify_node).
    The open set is mirrored in a ``PriorityQueue.HeapqM`` heap so the next node can be
    popped without scanning the whole open set.
    """

    def __init__(self, gx, gy, resolution, rr):
        """
        Initialize the planner and load the grid map from disk.

        gx: goal x position [m], only used to warn early when the goal is not traversable
        gy: goal y position [m]
        resolution: grid resolution [m] (size of one map pixel)
        rr: robot radius [m]
        """

        self.resolution = resolution
        self.rr = rr
        self.min_x, self.min_y = 0, 0
        self.max_x, self.max_y = 0, 0
        self.obstacle_map = None
        self.x_width, self.y_width = 0, 0
        self.motion = self.get_motion_model()
        # self.calc_obstacle_map(ox, oy)
        self.useExitedMap()
        # Check the goal with the same index convention and bounds check that planning()
        # uses (verify_node indexes obstacle_map[x][y]); the old check used [gy][gx].
        goal = Node(self.calc_xy_index(gx, self.min_x),
                    self.calc_xy_index(gy, self.min_y), 0.0, -1)
        if not self.verify_node(goal):
            print("终点不可达")

    def planning(self, sx, sy, gx, gy):
        """
        A* path search from (sx, sy) to (gx, gy).

        The next node is taken with HeapqM.heapPopSmooth, which prefers a near-minimal
        candidate that continues the last movement direction (fewer zig-zags).

        input:
            sx, sy: start position [m]
            gx, gy: goal position [m]

        output:
            rx: x position list of the final path, ordered goal -> start
            ry: y position list of the final path, ordered goal -> start
        """

        start_node = Node(self.calc_xy_index(sx, self.min_x),
                          self.calc_xy_index(sy, self.min_y), 0.0, -1)
        goal_node = Node(self.calc_xy_index(gx, self.min_x),
                         self.calc_xy_index(gy, self.min_y), 0.0, -1)

        open_set, closed_set = dict(), dict()
        open_bak = []  # heap mirror of open_set holding OuterNode(grid_index, node)
        pq = PriorityQueue.HeapqM()
        start_id = self.calc_grid_index(start_node)
        open_set[start_id] = start_node
        pq.heappush(open_bak, OuterNode(start_id, start_node), goal_node)
        start_time = datetime.now()
        current = start_node
        last_direction = {"x": 0, "y": 0}

        while True:
            if not open_set:
                print("Open set is empty..")
                break

            c_id = pq.heapPopSmooth(open_bak, goal_node, current, last_direction).c_id
            next_node = open_set[c_id]
            last_direction["x"] = next_node.x - current.x
            last_direction["y"] = next_node.y - current.y
            current = next_node

            # show graph
            if show_animation:  # pragma: no cover
                plt.plot(self.calc_grid_position(current.x, self.min_x),
                         self.calc_grid_position(current.y, self.min_y), "xc")
                # for stopping simulation with the esc key.
                plt.gcf().canvas.mpl_connect('key_release_event',
                                             lambda event: [exit(0) if event.key == 'escape' else None])

            if current.x == goal_node.x and current.y == goal_node.y:
                print("Find goal")
                goal_node.parent_index = current.parent_index
                goal_node.cost = current.cost
                self.printTrace(current, closed_set)
                break

            # Move the node from the open set to the closed set
            del open_set[c_id]
            closed_set[c_id] = current

            # Expand neighbours according to the motion model
            for dx, dy, step_cost in self.motion:
                node = Node(current.x + dx, current.y + dy, current.cost + step_cost, c_id)

                # Reject unsafe / off-map cells before their index is used anywhere
                if not self.verify_node(node):
                    continue

                n_id = self.calc_grid_index(node)
                if n_id in closed_set:
                    continue

                if n_id not in open_set:
                    open_set[n_id] = node  # discovered a new node
                    pq.heappush(open_bak, OuterNode(n_id, node), goal_node)
                elif open_set[n_id].cost > node.cost:
                    # Cheaper route found: replace the node in both the dict and the heap
                    old_node = open_set[n_id]
                    open_set[n_id] = node
                    pq.heapRemove(open_bak, old_node, goal_node)
                    pq.heappush(open_bak, OuterNode(n_id, node), goal_node)

        rx, ry = self.calc_final_path(goal_node, closed_set)
        end_time = datetime.now()
        print("time cost = ", end_time - start_time)
        return rx, ry

    def calc_final_path(self, goal_node, closed_set):
        """Follow parent indices from goal_node back to the start; returns (rx, ry) goal -> start."""
        rx, ry = [self.calc_grid_position(goal_node.x, self.min_x)], [
            self.calc_grid_position(goal_node.y, self.min_y)]
        parent_index = goal_node.parent_index
        while parent_index != -1:
            n = closed_set[parent_index]
            rx.append(self.calc_grid_position(n.x, self.min_x))
            ry.append(self.calc_grid_position(n.y, self.min_y))
            parent_index = n.parent_index

        return rx, ry

    @staticmethod
    def calc_heuristic(n1, n2):
        """Weighted Euclidean distance between two nodes (weight 1.0)."""
        w = 1.0  # weight of heuristic
        d = w * math.hypot(n1.x - n2.x, n1.y - n2.y)
        return d

    def calc_grid_position(self, index, min_position):
        """Convert a grid index into a metric position."""
        pos = index * self.resolution + min_position
        return pos

    def calc_xy_index(self, position, min_pos):
        """Convert a metric position into the nearest grid index."""
        return round((position - min_pos) / self.resolution)

    def calc_grid_index(self, node):
        """Flatten a node's grid coordinates into a unique integer key.

        The row stride must exceed every possible x index. The previous hard-coded 1000
        made keys collide on maps wider than 1000 cells, e.g. (1000, 0) and (0, 1).
        Using the larger map dimension keeps keys unique whichever axis x indexes.
        """
        stride = max(self.x_width, self.y_width, 1)
        return (node.y - self.min_y) * stride + (node.x - self.min_x)

    def verify_node(self, node):
        """Return True when the node lies on the map and its cell is traversable."""
        # Explicit bounds check: negative indices would silently wrap around in Python and
        # indices past the end would raise IndexError.
        if node.x < 0 or node.y < 0:
            return False
        if node.x >= len(self.obstacle_map) or node.y >= len(self.obstacle_map[node.x]):
            return False
        return bool(self.obstacle_map[node.x][node.y])

    def calc_obstacle_map(self, ox, oy):
        """Build a passability map from obstacle point lists ox, oy (currently unused)."""

        self.min_x = round(min(ox))
        self.min_y = round(min(oy))
        self.max_x = round(max(ox))
        self.max_y = round(max(oy))
        print("min_x:", self.min_x)
        print("min_y:", self.min_y)
        print("max_x:", self.max_x)
        print("max_y:", self.max_y)

        self.x_width = round((self.max_x - self.min_x) / self.resolution)
        self.y_width = round((self.max_y - self.min_y) / self.resolution)
        print("x_width:", self.x_width)
        print("y_width:", self.y_width)

        # True = traversable, matching verify_node and useExitedMap; cells within rr of
        # any obstacle point are marked blocked (the old code had this inverted).
        self.obstacle_map = [[True for _ in range(self.y_width)]
                             for _ in range(self.x_width)]
        for ix in range(self.x_width):
            x = self.calc_grid_position(ix, self.min_x)
            for iy in range(self.y_width):
                y = self.calc_grid_position(iy, self.min_y)
                for iox, ioy in zip(ox, oy):
                    d = math.hypot(iox - x, ioy - y)
                    if d <= self.rr:
                        self.obstacle_map[ix][iy] = False
                        break

    @staticmethod
    def get_motion_model():
        """8-connected moves as [dx, dy, cost]; diagonals cost sqrt(2)."""
        motion = [[1, 0, 1],
                  [0, 1, 1],
                  [-1, 0, 1],
                  [0, -1, 1],
                  [-1, -1, math.sqrt(2)],
                  [-1, 1, math.sqrt(2)],
                  [1, -1, math.sqrt(2)],
                  [1, 1, math.sqrt(2)]]

        return motion

    def useExitedMap(self):
        """Load the map from '../map.bytes' into obstacle_map[row][col] (True = traversable).

        Layout read here: 2-byte big-endian width, 2-byte big-endian height, then
        `height` rows of `width` bytes; a 0 byte marks a free cell.
        """
        with open('../map.bytes', 'rb') as file:

            width = int.from_bytes(file.read(2), byteorder='big')
            height = int.from_bytes(file.read(2), byteorder='big')
            print("地图宽度", width)
            print("地图高度", height)
            self.min_x = 0
            self.max_x = width
            self.min_y = 0
            self.max_y = height
            self.x_width = width
            self.y_width = height
            # One bulk read instead of one read(1) call per byte
            content = file.read()

        self.obstacle_map = [[content[row * width + col] == 0 for col in range(width)]
                             for row in range(height)]

        # self.draw_image(self.obstacle_map)

    @staticmethod
    def printTrace(current: "Node", closed_sset):
        """Print every node from `current` back to the start node."""
        path = []
        path.insert(0, current)
        while 1:
            print(current)
            fatherIndex = current.parent_index
            if fatherIndex == -1:
                break
            fatherNode = closed_sset[fatherIndex]
            path.insert(0, fatherNode)
            current = fatherNode
        # return path

    @staticmethod
    def draw_image(obstacle_map):
        """Show the map as an image with a reversed red colour scale."""
        fig, ax = plt.subplots()

        image_array = np.array(obstacle_map).astype(np.uint8)
        ax.imshow(image_array, cmap='Reds_r')

        plt.show()


def main():
    """Plan from (400, 300) to (600, 600) on the map loaded from disk and plot the result."""
    print(__file__ + " start!!")

    # Start / goal positions and grid parameters [m]
    sx, sy = 400, 300
    gx, gy = 600, 600
    grid_size = 1.0
    robot_radius = 1.0

    if show_animation:  # pragma: no cover
        plt.plot(sx, sy, "og")
        plt.plot(gx, gy, "xb")
        plt.grid(True)
        plt.axis("equal")

    planner = AStarPlanner(gx, gy, grid_size, robot_radius)
    rx, ry = planner.planning(sx, sy, gx, gy)

    if not show_animation:
        return
    plt.plot(rx, ry, "-r")  # pragma: no cover
    plt.pause(0.001)
    plt.show()


if __name__ == '__main__':
    main()
import math

from org.example.findWays.visionControl.v3_smooth.v3_smooth import OuterNode, Node


class HeapqM:
    """Binary MIN-heap of OuterNode-like entries ordered by f = node.cost + h(goal, node).

    Entries must expose ``.node`` with ``cost``, ``x`` and ``y``. The heuristic is
    recomputed on every comparison, so every operation takes ``goal_node``.
    """

    def _f(self, item, goal_node):
        """Priority of a heap entry: accumulated cost plus heuristic distance to the goal."""
        return item.node.cost + self.calc_heuristic(goal_node, item.node)

    # Sift-down: restore the min-heap property for the subtree rooted at `index`,
    # looking only at positions 0..end (inclusive).
    def heapAdjust(self, nums: list, index: int, end: int, goal_node: "Node"):
        left = index * 2 + 1
        while left <= end:
            smallest = index
            if self._f(nums[left], goal_node) < self._f(nums[smallest], goal_node):
                smallest = left
            right = left + 1
            if right <= end and self._f(nums[right], goal_node) < self._f(nums[smallest], goal_node):
                smallest = right
            if smallest == index:
                break  # both children are >= the parent: subtree already valid
            nums[index], nums[smallest] = nums[smallest], nums[index]
            index = smallest
            left = index * 2 + 1

    def _sift_up(self, nums: list, index: int, goal_node):
        """Move nums[index] towards the root until its parent is strictly smaller.

        Ties keep moving up, which matches the original heappush behaviour.
        """
        item = nums[index]
        item_f = self._f(item, goal_node)
        while index > 0:
            parent = (index - 1) // 2
            if self._f(nums[parent], goal_node) < item_f:
                break
            nums[index] = nums[parent]
            index = parent
        nums[index] = item

    def _pop_at(self, nums: list, index: int, goal_node):
        """Remove and return nums[index], then repair the heap around that slot.

        Raises IndexError when nums is empty.
        """
        nums[index], nums[-1] = nums[-1], nums[index]
        removed = nums.pop()
        if index < len(nums):
            # The former last entry now sits at `index`. It may be too large for its
            # children (sift down) or too small for its new parent (sift up).
            self.heapAdjust(nums, index, len(nums) - 1, goal_node)
            self._sift_up(nums, index, goal_node)
        return removed

    # Build a heap in place from an arbitrary list.
    def heapify(self, nums: list, goal_node):
        size = len(nums)
        # (size - 2) // 2 is the last non-leaf; leaves need no adjustment
        for i in range((size - 2) // 2, -1, -1):
            self.heapAdjust(nums, i, size - 1, goal_node)

    # Push `value` and sift it up to its place.
    def heappush(self, nums: list, value: "OuterNode", goal_node):
        nums.append(value)
        self._sift_up(nums, len(nums) - 1, goal_node)

    # Pop the entry with the smallest f. Raises IndexError on an empty heap.
    def heappop(self, nums: list, goal_node):
        return self._pop_at(nums, 0, goal_node)

    def heapPopSmooth(self, nums: list, goal_node, current_node, last_direction):
        """Pop a near-minimal entry, preferring one that continues the last move direction.

        The first (up to) 15 heap slots are scanned for a node at
        current_node + last_direction. These slots are near the top of the heap but are
        not guaranteed to be the 15 smallest entries. Without a match the true minimum
        (slot 0) is popped. Raises IndexError on an empty heap.
        """
        target = 0
        for i in range(min(15, len(nums))):
            node = nums[i].node
            if (node.x - current_node.x == last_direction["x"] and
                    node.y - current_node.y == last_direction["y"]):
                target = i
                break
        # Previously the matched entry was swapped to the end but pop(target) removed the
        # swapped-in entry instead, and pop(0) shifted the list and broke the heap layout.
        return self._pop_at(nums, target, goal_node)

    # In-place heap sort. With this min-heap the minimum is swapped to the end on each
    # pass, so the result is in DESCENDING order of f.
    def heapSort(self, nums: list, goal_node):
        self.heapify(nums, goal_node)
        size = len(nums)
        for i in range(size):
            last = size - i - 1
            nums[0], nums[last] = nums[last], nums[0]
            self.heapAdjust(nums, 0, last - 1, goal_node)
        return nums

    def heapRemove(self, nums: list, value: "Node", goal_node):
        """Remove the entry whose ``.node`` equals ``value``.

        Returns the removed entry, or False when the heap is empty or value is absent.
        The heap is repaired in both directions around the vacated slot.
        """
        for index, item in enumerate(nums):
            if item.node == value:
                return self._pop_at(nums, index, goal_node)
        return False

    @staticmethod
    def calc_heuristic(n1, n2):
        """Weighted Euclidean distance between two nodes (weight 1.0)."""
        w = 1.0  # weight of heuristic
        d = w * math.hypot(n1.x - n2.x, n1.y - n2.y)
        # d = n1.x - n2.x + n1.y - n2.y
        return d

优先队列

import math

from org.example.findWays.visionControl.v3_smooth.v3_smooth import OuterNode, Node


class HeapqM:
    """Binary MIN-heap of OuterNode-like entries ordered by f = node.cost + h(goal, node).

    Entries must expose ``.node`` with ``cost``, ``x`` and ``y``. The heuristic is
    recomputed on every comparison, so every operation takes ``goal_node``.
    """

    def _f(self, item, goal_node):
        """Priority of a heap entry: accumulated cost plus heuristic distance to the goal."""
        return item.node.cost + self.calc_heuristic(goal_node, item.node)

    # Sift-down: restore the min-heap property for the subtree rooted at `index`,
    # looking only at positions 0..end (inclusive).
    def heapAdjust(self, nums: list, index: int, end: int, goal_node: "Node"):
        left = index * 2 + 1
        while left <= end:
            smallest = index
            if self._f(nums[left], goal_node) < self._f(nums[smallest], goal_node):
                smallest = left
            right = left + 1
            if right <= end and self._f(nums[right], goal_node) < self._f(nums[smallest], goal_node):
                smallest = right
            if smallest == index:
                break  # both children are >= the parent: subtree already valid
            nums[index], nums[smallest] = nums[smallest], nums[index]
            index = smallest
            left = index * 2 + 1

    def _sift_up(self, nums: list, index: int, goal_node):
        """Move nums[index] towards the root until its parent is strictly smaller.

        Ties keep moving up, which matches the original heappush behaviour.
        """
        item = nums[index]
        item_f = self._f(item, goal_node)
        while index > 0:
            parent = (index - 1) // 2
            if self._f(nums[parent], goal_node) < item_f:
                break
            nums[index] = nums[parent]
            index = parent
        nums[index] = item

    def _pop_at(self, nums: list, index: int, goal_node):
        """Remove and return nums[index], then repair the heap around that slot.

        Raises IndexError when nums is empty.
        """
        nums[index], nums[-1] = nums[-1], nums[index]
        removed = nums.pop()
        if index < len(nums):
            # The former last entry now sits at `index`. It may be too large for its
            # children (sift down) or too small for its new parent (sift up).
            self.heapAdjust(nums, index, len(nums) - 1, goal_node)
            self._sift_up(nums, index, goal_node)
        return removed

    # Build a heap in place from an arbitrary list.
    def heapify(self, nums: list, goal_node):
        size = len(nums)
        # (size - 2) // 2 is the last non-leaf; leaves need no adjustment
        for i in range((size - 2) // 2, -1, -1):
            self.heapAdjust(nums, i, size - 1, goal_node)

    # Push `value` and sift it up to its place.
    def heappush(self, nums: list, value: "OuterNode", goal_node):
        nums.append(value)
        self._sift_up(nums, len(nums) - 1, goal_node)

    # Pop the entry with the smallest f. Raises IndexError on an empty heap.
    def heappop(self, nums: list, goal_node):
        return self._pop_at(nums, 0, goal_node)

    def heapPopSmooth(self, nums: list, goal_node, current_node, last_direction):
        """Pop a near-minimal entry, preferring one that continues the last move direction.

        The first (up to) 15 heap slots are scanned for a node at
        current_node + last_direction. These slots are near the top of the heap but are
        not guaranteed to be the 15 smallest entries. Without a match the true minimum
        (slot 0) is popped. Raises IndexError on an empty heap.
        """
        target = 0
        for i in range(min(15, len(nums))):
            node = nums[i].node
            if (node.x - current_node.x == last_direction["x"] and
                    node.y - current_node.y == last_direction["y"]):
                target = i
                break
        # Previously the matched entry was swapped to the end but pop(target) removed the
        # swapped-in entry instead, and pop(0) shifted the list and broke the heap layout.
        return self._pop_at(nums, target, goal_node)

    # In-place heap sort. With this min-heap the minimum is swapped to the end on each
    # pass, so the result is in DESCENDING order of f.
    def heapSort(self, nums: list, goal_node):
        self.heapify(nums, goal_node)
        size = len(nums)
        for i in range(size):
            last = size - i - 1
            nums[0], nums[last] = nums[last], nums[0]
            self.heapAdjust(nums, 0, last - 1, goal_node)
        return nums

    def heapRemove(self, nums: list, value: "Node", goal_node):
        """Remove the entry whose ``.node`` equals ``value``.

        Returns the removed entry, or False when the heap is empty or value is absent.
        The heap is repaired in both directions around the vacated slot.
        """
        for index, item in enumerate(nums):
            if item.node == value:
                return self._pop_at(nums, index, goal_node)
        return False

    @staticmethod
    def calc_heuristic(n1, n2):
        """Weighted Euclidean distance between two nodes (weight 1.0)."""
        w = 1.0  # weight of heuristic
        d = w * math.hypot(n1.x - n2.x, n1.y - n2.y)
        # d = n1.x - n2.x + n1.y - n2.y
        return d

JPS

from __future__ import annotations  # 延迟评估注解(类型声明)

from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
from typing import List, Union
import math


class Node:
    """JPS search node: grid position, cost-so-far g, heuristic h and a parent link."""

    def __init__(self, grid_pos: tuple[int, int], g: float, h: float, parent: Node = None):
        self.grid_pos = grid_pos
        self.g, self.h = g, h
        self.parent = parent  # None for the start node
        # f is computed once here; later changes to g or h do not update it
        self.f = g + h


class JPS:
    def __init__(self, width: int, height: int):
        self.start_grid_pos = (0, 0)
        self.goal_grid_pos = (0, 0)
        self.width = width
        self.height = height

        self.open = {}  # pos:node
        self.close = {}

        self.motion_directions = [[1, 0], [0, 1], [0, -1], [-1, 0], [1, 1], [1, -1], [-1, 1], [-1, -1]]

    def run(self, start_grid_pos: tuple[int, int], goal_grid_pos: tuple[int, int]):
        self.start_grid_pos = start_grid_pos
        self.goal_grid_pos = goal_grid_pos
        start_node = Node(start_grid_pos, 0, self.getH(self.start_grid_pos, self.goal_grid_pos))
        self.open[start_node.grid_pos] = start_node
        start_time = datetime.now()

        while True:

            # 寻路失败,结束循环
            if not self.open:
                print("未找到路径")
                break

            current_node = min(self.open.values(), key=lambda x: x.f)  # f值最小的节点

            plt.plot(current_node.grid_pos[0],
                     current_node.grid_pos[1], "xc")

            # 找到路径, 返回结果
            if current_node.grid_pos == self.goal_grid_pos:
                path = self.findPath(current_node)
                end_time = datetime.now()
                print("time cost = ", end_time - start_time)
                return path

            # 扩展节点
            # print(current_node.grid_pos)
            self.extendNode(current_node)
            # print(len(self.open))

            # 更新节点
            self.close[current_node.grid_pos] = current_node
            del self.open[current_node.grid_pos]

    def extendNode(self, current_node: Node):
        """
        根据当前节点,扩展节点(只有跳点才可以扩展)
        input
        ----------
        current_node: 当前节点对象
        """
        neighbours = self.getPruneNeighbours(current_node)
        # print(neighbours)
        for n in neighbours:
            # print("n", n)
            # print("curr", current_node.grid_pos)
            jp = self.getJumpNode(n, current_node.grid_pos)  # 跳点
            if jp:
                # print("jp", jp)
                if jp in self.close:
                    continue

                new_node = Node(jp, current_node.g + self.getG(jp, current_node.grid_pos),
                                self.getH(jp, self.goal_grid_pos), current_node)
                if jp in self.open:
                    if new_node.f < self.open[jp].f:
                        self.open[jp].parent = current_node
                        self.open[jp].f = new_node.f
                        # print("update", new_node.grid_pos)
                else:
                    self.open[jp] = new_node
                    # print("add", new_node.grid_pos)
    def getJumpNode(self, now: tuple[int, int], pre: tuple[int, int]) -> Union[tuple[int, int], None]:
        """
        计算跳点
        input
        ----------
        now: 当前节点坐标
        pre: 上一节点坐标
        output
        ----------
        若有跳点,返回跳点坐标;若无跳点,则返回None
        """
        x_direction = int((now[0] - pre[0]) / abs(now[0] - pre[0])) if now[0] - pre[0] != 0 else 0
        y_direction = int((now[1] - pre[1]) / abs(now[1] - pre[1])) if now[1] - pre[1] != 0 else 0

        if now == self.goal_grid_pos:  # 如果当前节点是终点,则为跳点(条件1)
            return now

        if self.hasForceNeighbours(now, pre):  # 如果当前节点包含强迫邻居,则为跳点(条件2)
            return now

        if abs(x_direction) + abs(y_direction) == 2:  # 若为斜线移动,则判断水平和垂直方向是否有满足上述条件1和2的点,若有,则为跳点(条件3)

            if (self.getJumpNode((now[0] + x_direction, now[1]), now) or
                    self.getJumpNode((now[0], now[1] + y_direction), now)):
                return now

        if self.isPass(now[0] + x_direction, now[1] + y_direction):  # 若当前节点未找到跳点,朝当前方向前进一步,继续寻找跳点,直至不可达
            jp = self.getJumpNode((now[0] + x_direction, now[1] + y_direction), now)
            if jp:
                return jp

        return None

    def hasForceNeighbours(self, now: tuple[int, int], pre: tuple[int, int]) -> bool:
        """
        根据当前节点坐标和上一节点坐标判断当前节点是否拥有强迫邻居
        input
        ----------
        now: 当前坐标
        pre: 上一坐标
        output
        ----------
        若拥有强迫邻居,则返回True;否则返回False
        """
        x_direction = now[0] - pre[0]
        y_direction = now[1] - pre[1]

        # 若为直线移动
        if abs(x_direction) + abs(y_direction) == 1:

            if abs(x_direction) == 1:  # 水平移动
                if (self.isPass(now[0] + x_direction, now[1] + 1) and not self.isPass(now[0], now[1] + 1)) or \
                        (self.isPass(now[0] + x_direction, now[1] - 1) and not self.isPass(now[0], now[1] - 1)):
                    return True
                else:
                    return False

            elif abs(y_direction) == 1:  # 垂直移动

                if (self.isPass(now[0] + 1, now[1] + y_direction) and not self.isPass(now[0] + 1, now[1])) or \
                        (self.isPass(now[0] - 1, now[1] + y_direction) and not self.isPass(now[0] - 1, now[1])):
                    return True
                else:
                    return False

            else:
                raise Exception("错误,直线移动中只能水平或垂直移动!")

        # 若为斜线移动
        elif abs(x_direction) + abs(y_direction) == 2:
            if (self.isPass(now[0] + x_direction, now[1] - y_direction) and not self.isPass(now[0],
                                                                                            now[1] - y_direction)) or \
                    (self.isPass(now[0] - x_direction, now[1] + y_direction) and not self.isPass(now[0] - x_direction,
                                                                                                 now[1])):
                return True
            else:
                return False

        else:
            raise Exception("错误,只能直线移动或斜线移动!")

    def getH(self, current: tuple[int, int], goal: tuple[int, int], func: str = "Euclidean") -> float:
        """
        根据当前坐标和终点坐标计算启发值,包含:
            曼哈顿距离(Manhattan Distance)
            欧几里得距离(Euclidean Distance)  默认
            切比雪夫距离(Chebyshev Distance)
        input
        ----------
        current: 当前节点坐标
        goal: 目标节点坐标
        func: 启发函数,默认为欧几里得距离
        """
        current_x = current[0]
        current_y = current[1]
        goal_x = goal[0]
        goal_y = goal[1]

        if func == "Manhattan":  # 适用于格点地图上,移动限制为上下左右四个方向的情况
            h = abs(current_x - goal_x) + abs(current_y - goal_y)

        elif func == "Euclidean":  # 适用于可以自由移动至任何方向的情况,例如在平面上自由移动
            h = math.hypot(current_x - goal_x, current_y - goal_y)

        elif func == "Chebyshev":  # 适用于八方向移动的格点地图(上下左右及对角线)
            h = max(abs(current_x - goal_x), abs(current_y - goal_y))

        else:
            raise Exception("错误,不支持该启发函数。目前支持:Manhattan、Euclidean(默认)、Chebyshev。")

        return h

    def getG(self, pos1: tuple[int, int], pos2: tuple[int, int]) -> float:
        """
        根据两坐标计算代价值(直走或斜走,直接计算欧几里得距离就可)
        input
        ----------
        pos1: 坐标1
        pos2: 坐标2
        output
        ----------
        返回代价值
        """
        return math.hypot(pos1[0] - pos2[0], pos1[1] - pos2[1])

    def getPruneNeighbours(self, current_node: Node) -> list[tuple[int, int]]:
        """
        获得裁剪之后的邻居节点坐标列表(包含自然邻居和强迫邻居,其相关概念参考:https://fallingxun.github.io/post/algorithm/algorithm_jps/)
        input
        ----------
        current_node: 当前节点对象
        output
        ----------
        prune_neighbours: 裁剪邻居坐标列表
        """
        prune_neighbours = []

        if current_node.parent:
            motion_x = int((current_node.grid_pos[0] - current_node.parent.grid_pos[0]) / abs(
                current_node.grid_pos[0] - current_node.parent.grid_pos[0])) if current_node.grid_pos[0] - \
                                                                                current_node.parent.grid_pos[
                                                                                    0] != 0 else 0  # 方向不分大小,所以要除以长度
            motion_y = int((current_node.grid_pos[1] - current_node.parent.grid_pos[1]) / abs(
                current_node.grid_pos[1] - current_node.parent.grid_pos[1])) if current_node.grid_pos[1] - \
                                                                                current_node.parent.grid_pos[
                                                                                    1] != 0 else 0  # 方向不分大小,所以要除以长度

            if abs(motion_x) + abs(motion_y) == 1:  # 直线

                # 自然邻居
                if self.isPass(current_node.grid_pos[0] + motion_x, current_node.grid_pos[1] + motion_y):
                    prune_neighbours.append((current_node.grid_pos[0] + motion_x, current_node.grid_pos[1] + motion_y))

                # 强迫邻居
                if abs(motion_x) == 0:  # 垂直走
                    if not self.isPass(current_node.grid_pos[0] + 1, current_node.grid_pos[1]) and self.isPass(current_node.grid_pos[0] + 1, current_node.grid_pos[1] + motion_y):
                        prune_neighbours.append((current_node.grid_pos[0] + 1, current_node.grid_pos[1] + motion_y))
                    if not self.isPass(current_node.grid_pos[0] - 1, current_node.grid_pos[1]) and self.isPass(current_node.grid_pos[0] - 1, current_node.grid_pos[1] + motion_y):
                        prune_neighbours.append((current_node.grid_pos[0] - 1, current_node.grid_pos[1] + motion_y))
                else:  # 水平走
                    if not self.isPass(current_node.grid_pos[0], current_node.grid_pos[1] + 1) and self.isPass(current_node.grid_pos[0] + motion_x, current_node.grid_pos[1] + 1):
                        prune_neighbours.append((current_node.grid_pos[0] + motion_x, current_node.grid_pos[1] + 1))
                    if not self.isPass(current_node.grid_pos[0], current_node.grid_pos[1] - 1) and (current_node.grid_pos[0] + motion_x, current_node.grid_pos[1] - 1):
                        prune_neighbours.append((current_node.grid_pos[0] + motion_x, current_node.grid_pos[1] - 1))

            elif abs(motion_x) + abs(motion_y) == 2:  # 对角线

                # 自然邻居
                if self.isPass(current_node.grid_pos[0] + motion_x, current_node.grid_pos[1] + motion_y):
                    prune_neighbours.append((current_node.grid_pos[0] + motion_x, current_node.grid_pos[1] + motion_y))
                if self.isPass(current_node.grid_pos[0] + motion_x, current_node.grid_pos[1]):
                    prune_neighbours.append((current_node.grid_pos[0] + motion_x, current_node.grid_pos[1]))
                if self.isPass(current_node.grid_pos[0], current_node.grid_pos[1] + motion_y):
                    prune_neighbours.append((current_node.grid_pos[0], current_node.grid_pos[1] + motion_y))

                # 强迫邻居
                if not self.isPass(current_node.grid_pos[0] - motion_x, current_node.grid_pos[1]) and self.isPass(
                        current_node.grid_pos[0] - motion_x, current_node.grid_pos[1] + motion_y):
                    prune_neighbours.append((current_node.grid_pos[0] - motion_x, current_node.grid_pos[1] + motion_y))
                if not self.isPass(current_node.grid_pos[0], current_node.grid_pos[1] - motion_y) and self.isPass(
                        current_node.grid_pos[0] + motion_x, current_node.grid_pos[1] - motion_y):
                    prune_neighbours.append((current_node.grid_pos[0] + motion_x, current_node.grid_pos[1] - motion_y))

            else:
                raise Exception("错误,只能对角线和直线行走!")

        else:
            for dir in self.motion_directions:
                if self.isPass(current_node.grid_pos[0] + dir[0], current_node.grid_pos[1] + dir[1]):
                    prune_neighbours.append((current_node.grid_pos[0] + dir[0], current_node.grid_pos[1] + dir[1]))

        return prune_neighbours

    def isPass(self, grid_x: int, grid_y: int) -> bool:
        """
        Check whether the grid cell (grid_x, grid_y) may be entered.

        Parameters
        ----------
        grid_x: grid x coordinate (used as the first index into ``obstacle_map``)
        grid_y: grid y coordinate (used as the second index into ``obstacle_map``)

        Returns
        -------
        False when the cell is outside ``0 <= grid_x < self.width`` and
        ``0 <= grid_y < self.height``. Otherwise True when the module-level
        ``obstacle_map`` does not mark the cell as 1, or when the cell equals
        ``self.goal_grid_pos``; False in all remaining cases.
        """
        inside_grid = 0 <= grid_x < self.width and 0 <= grid_y < self.height
        if not inside_grid:
            return False
        # The goal cell counts as passable even when it is marked as an obstacle.
        # NOTE(review): the goal test compares against a list; if goal_grid_pos
        # is stored as a tuple this equality is never true -- confirm how it is set.
        return bool(
            obstacle_map[grid_x][grid_y] != 1
            or [grid_x, grid_y] == self.goal_grid_pos
        )

    def findPath(self, node: Node) -> list[list[int]]:
        """
        Trace parent links from ``node`` back to the root and return the path.

        Parameters
        ----------
        node: the final node of the search; each node exposes ``grid_pos``
              (indexable as [0], [1]) and ``parent`` (falsy at the root).

        Returns
        -------
        ``[xs, ys]``: two parallel lists of grid coordinates ordered from the
        root node (the one whose ``parent`` is falsy) to ``node``.
        """
        chain = []
        cursor = node
        while True:
            chain.append(cursor)
            if not cursor.parent:
                break
            cursor = cursor.parent
        # The chain was collected goal-first; flip it so the path starts at the root.
        chain.reverse()
        xs = [step.grid_pos[0] for step in chain]
        ys = [step.grid_pos[1] for step in chain]
        return [xs, ys]

if __name__ == '__main__':
    # map.bytes layout: 2-byte big-endian width, 2-byte big-endian height,
    # then one byte per cell in row-major order (0 = free, non-zero = obstacle).
    with open('map.bytes', 'rb') as file:

        width = int.from_bytes(file.read(2), byteorder='big')
        height = int.from_bytes(file.read(2), byteorder='big')
        print("地图宽度", width)
        print("地图高度", height)
        # Read the whole cell payload in one call instead of one byte per read().
        content = file.read()

    # Fail with a clear message instead of an IndexError deep in the loop.
    if len(content) < width * height:
        raise ValueError(
            f"map.bytes holds {len(content)} cells, expected {width * height}"
        )

    # obstacle_map[row][col] is 0 for a free cell and 1 for any non-zero byte.
    # It must stay a module-level name: isPass reads the global obstacle_map.
    obstacle_map = [
        [0 if content[row * width + col] == 0 else 1 for col in range(width)]
        for row in range(height)
    ]

    start = (400, 300)
    goal = (600, 600)

    # NOTE(review): 1500x1500 is hard-coded rather than taken from the map
    # header (width, height); confirm what JPS's constructor arguments mean.
    jps = JPS(1500, 1500)
    path = jps.run(start, goal)

    fig = plt.figure()

    ax = fig.add_subplot(111)
    ax.imshow(np.array(obstacle_map), cmap='binary', vmin=0, vmax=1)

    # imshow puts the first array index (row) on the vertical axis, so
    # (row, col) points are drawn as (col, row).
    ax.scatter(start[1], start[0])
    ax.scatter(goal[1], goal[0])

    px, py = path
    ax.plot(py, px)
    plt.show()