【强化学习】求解迷宫寻路问题(基于Q-Learning)+Java代码实现

831 阅读7分钟

前言

相信大多数小伙伴应该和我一样,之前在学习强化学习的时候,一直用的是Python,但奈何只会用java写后端,对Python的一些后端框架还不太熟悉,(以后要集成到网站上就惨了),于是就想用Java实现一下强化学习中的Q-Learning算法,来搜索求解人工智能领域较热门的问题---迷宫寻路问题。(避免以后要用的时候来不及写)。

一、强化学习简介

强化学习是机器学习中的一个领域,强调如何基于环境而行动,以取得最大化的预期利益。其灵感来源于心理学中的行为主义理论,即有机体如何在环境给予的奖励或惩罚的刺激下,逐步形成对刺激的预期,产生能获得最大利益的习惯性行为。

  • 1956年Bellman提出了动态规划方法。
  • 1977年Werbos提出只适应动态规划算法。
  • 1988年sutton提出时间差分算法。
  • 1992年Watkins 提出Q-learning 算法。
  • 1994年rummery 提出Saras算法。
  • 1996年Bersekas提出解决随机过程中优化控制的神经动态规划方法。
  • 2006年Kocsis提出了置信上限树算法。
  • 2009年kewis提出反馈控制只适应动态规划算法。
  • 2014年silver提出确定性策略梯度(Policy Gradients)算法。
  • 2015年Google-deepmind 提出Deep-Q-Network算法。

可见,强化学习已经发展了几十年,并不是一门新的技术。在2016年,AlphaGo击败李世石之后,融合了深度学习的强化学习技术大放异彩,成为这两年最火的技术之一。总结来说,强化学习就是一个古老而又时尚的技术。


二、强化学习方法汇总

1.Modelfree 和 Modelbased

在这里插入图片描述

2.基于概率 和 基于价值

在这里插入图片描述

3.回合更新 和 单步更新

在这里插入图片描述

4.在线学习 和 离线学习

在这里插入图片描述


三、迷宫寻路问题简介

迷宫寻路问题是人工智能中的有趣问题,给定一个M行N列的迷宫图,其中 "0"表示可通路,"1"表示障碍物,无法通行,"2"表示起点,"3"表示终点。在迷宫中只允许在水平或上下四个方向的通路上行走,走过的位置不能重复走,需要搜索出从起点到终点尽量短的路径。 在这里插入图片描述 地图可视化如下图所示:绿色代表道路,黑色代表墙壁,粉色代表起点,蓝色代表终点 在这里插入图片描述


四、强化学习求解思路

迷宫寻路问题中常用的搜索算法有A*算法、递归法、启发式算法等。但是,不管选用何种算法,如何表示状态空间和搜索路径都是寻路问题的重点。 本文基于Q-Learning算法(什么是Q-Learning?),采用Q-Table进行状态空间和动作收获值的存储。 Q-Learning又分为两大类,一种是蒙特卡罗学习,一种是时序差分学习。依旧套用围棋的例子,简单来讲蒙特卡罗学习是下完一盘棋之后,批处理更新一次分值。时序差分学习是每下一步棋,就更新一次分值。理论上讲,时序差分学习更好一些,时间性能比较好,学习的比较快。故小编采用的就是时序差分学习模式。


五、java代码

1.辅助类-Instance

用于存储地图信息,并提供初始化Q表和查询起点和终点索引的方法

package RL_Path_Planning;

import lombok.Data;

/*
**  Create by: WSKH0929
    Date:2021-11-10
    Time:21:24
*/
@Data
public class Instance {
    // 0是路 1是墙 2是起点 3是终点
    private int[][] map = new int[][]{
            {1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0},
            {0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0},
            {0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0},
            {0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0},
            {0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0},
            {0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0},
            {0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0},
            {0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0},
            {0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0},
            {0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0},
            {0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0},
            {0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0},
            {0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0},
            {0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0},
            {0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0},
            {0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1},
            {0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1},
            {0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1},
            {0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1},
            {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 3, 0, 0, 0, 0, 0, 0, 0},
    };
    private double[][] Q;
    private int startIndex, endIndex;

    // 获取起点和终点的序号
    public void initStart_And_End_Index() {
        for (int i = 0; i < map.length; i++) {
            for (int j = 0; j < map[i].length; j++) {
                if (map[i][j] == 2) {
                    startIndex = i * map[0].length + j;
                } else if (map[i][j] == 3) {
                    endIndex = i * map[0].length + j;
                }
            }
        }
    }

    // 初始化Q表
    public void initQ() {
        // 对应四种运动可能 上下左右
        Q = new double[map.length * map[0].length][4];
        for (int i = 0; i < Q.length; i++) {
            for (int j = 0; j < Q[i].length; j++) {
                Q[i][j] = 0.001;
            }
        }
    }
}

2.主类-Application

这里用到了第三方库JavaFX,对地图和最终搜索到的路线进行了可视化展示。

package RL_Path_Planning;


import javafx.event.ActionEvent;
import javafx.event.EventHandler;
import javafx.geometry.VPos;
import javafx.scene.Scene;
import javafx.scene.canvas.Canvas;
import javafx.scene.canvas.GraphicsContext;
import javafx.scene.control.Alert;
import javafx.scene.control.Button;
import javafx.scene.layout.AnchorPane;
import javafx.scene.paint.Color;
import javafx.scene.text.Font;
import javafx.scene.text.TextAlignment;
import javafx.stage.Stage;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

/*
**  Create by: WSKH0929
    Date:2021-11-10
    Time:21:09
*/
public class Application extends javafx.application.Application {
    // 超出边界
    public double beyondBoundaryReward = -1000000;
    // 撞墙
    public double hitWallReward = -1000000;
    // 普通走路
    public double commonReward = -0.00001;
    // 到达终点
    public double reachEndReward = 10000;
    // 远离终点
    public double awayEndReward = -0.0001;
    // 奖励获取率
    public double alpha = 0.6;
    // reward衰减率 怕死系数 越接近1越怕死
    public double gamma_pasi = 0.00001;
    // reward衰减率 贪心系数 越接近1越贪心
    public double gamma_tanxin = 0.00001;
    // 设置邻域探索次数
    public int max_N = 400;
    // 学习迭代次数 999999 999999
    public int T = 50000;
    //
    public int bestPathLen = Integer.MAX_VALUE;
    // 记录最佳路线
    List<Integer[]> bestPathList = new ArrayList<>();

    // 保存每次迭代的路线
    List<Integer[]> pathList = new ArrayList<>();

    @Override
    public void start(Stage primaryStage) throws Exception {
        AnchorPane pane = new AnchorPane();
        Instance instance = new Instance();
        instance.initQ(); // 初始化Q表
        instance.initStart_And_End_Index(); // 初始化起点终点索引
        Canvas canvas = initCanvas(instance.getMap());
        pane.getChildren().add(canvas);
        // 开始学习
        learn(instance, canvas);
        primaryStage.setTitle("强化学习路径规划");
        primaryStage.setScene(new Scene(pane, 600, 600, Color.YELLOW));
        primaryStage.show();
    }

    // 绘制初始地图
    public Canvas initCanvas(int[][] map) {
        Canvas canvas = new Canvas(400, 400);
        canvas.relocate(100, 100);
        for (int i = 0; i < map.length; i++) {
            for (int j = 0; j < map[i].length; j++) {
                int m = map[i][j];
                GraphicsContext gc = canvas.getGraphicsContext2D();
                if (m == 0) {
                    gc.setFill(Color.GREEN);
                } else if (m == 1) {
                    gc.setFill(Color.BLACK);
                } else if (m == 2) {
                    gc.setFill(Color.PINK);
                } else if (m == 3) {
                    gc.setFill(Color.AQUA);
                }
                gc.fillRect(j * 20, i * 20, 20, 20);
            }
        }
        return canvas;
    }

    // 学习过程
    public void learn(Instance instance, Canvas canvas) {

        // 获取map
        int[][] map = copyMap(instance.getMap());
        // 获取Q表
        double[][] Q = instance.getQ();
        for (int i = 0; i < T; i++) {
//            System.out.println(i);
            // 设置当前点所在位置 即起点
            int pos = instance.getStartIndex();
//            System.out.println("起点是: "+pos);
            // 循环探索移动
            int n = 0;
            int[][] cloneMap = copyMap(map);
            pathList = new ArrayList<>();

            while (pos >= 0 && pos != Integer.MAX_VALUE && n < max_N) {

                int colLen = cloneMap[0].length;
                int x = pos % colLen;
                int y = (pos - x) / colLen;
                double[] Qs = Q[pos];

                int next = next(Qs[0], Qs[1], Qs[2], Qs[3]);

                double tempPos = move(pos, next, cloneMap);
                // 找到Q最大的值
                double[] clone = Qs.clone();
                Arrays.sort(clone);
                double max_q = clone[clone.length - 1];
                // 更改Q表
                if (tempPos == Integer.MAX_VALUE) {
                    if (true) {
                        System.out.println("迭代数:" + i + ",PathLen = " + (n + 1));
                    }
//                    Q[pos][next-1] += ((1-alpha)*Q[pos][next-1]+alpha*(reachEndReward+gamma*max_q));
                    Q[pos][next - 1] += reachEndReward;
                    pathList.add(new Integer[]{pos, next - 1});
                    // 按照路线更新路线Q值 衰减率为gamma
                    if (n + 1 < bestPathLen || false) {
                        bestPathList = new ArrayList<>(pathList);
                        bestPathLen = n + 1;
//                        System.out.println("迭代数:"+i+",bestPathLen"+bestPathLen);
                        double reward = reachEndReward * gamma_tanxin / (double) (n + 1);
                        for (int k = pathList.size() - 1; k >= 0; k--) {
                            Integer[] integers = pathList.get(k);
                            Q[integers[0]][integers[1]] += reward;
                            reward = reachEndReward * gamma_tanxin;
                        }
                    }

                } else if (tempPos >= 0) {
                    // next!=4&&next!=1
                    if (next != 1 && next != 3) {
//                        Q[pos][next-1] += ((1-alpha)*Q[pos][next-1]+alpha*(commonReward+gamma*max_q));
                        Q[pos][next - 1] += commonReward;
                    } else {
//                        Q[pos][next-1] += ((1-alpha)*Q[pos][next-1]+alpha*(awayEndReward+gamma*max_q));
                        Q[pos][next - 1] += awayEndReward;
                    }
                    pathList.add(new Integer[]{pos, next - 1});
                } else {
//                    Q[pos][next-1] += ((1-alpha)*Q[pos][next-1]+alpha*(tempPos+gamma*max_q));
                    Q[pos][next - 1] += tempPos;
                    // 按照路线更新路线Q值 衰减率为gamma
                    double reward = tempPos * gamma_pasi;
                    for (int k = pathList.size() - 1; k >= 0; k--) {
                        Integer[] integers = pathList.get(k);
                        Q[integers[0]][integers[1]] += reward;
                        reward = tempPos * gamma_pasi;
                    }
                }

                pos = (int) tempPos;

//                System.out.println(pos);
                n++;

            }
        }
        // 学习完毕 绘图 验证
//        verification_Q(canvas,Q,instance);
        // 绘制最佳路线
        plotBestPath(canvas, instance);
    }


    // 传入四个数,按照概率返回 上下左右--->1234
    public int next(double n1, double n2, double n3, double n4) {
        // 累积概率数组
        double[] accumulateRateArr = new double[4];
        // 概率数组
        double[] rateArr = new double[4];
        // 获取最小值
        double[] arr = {n1, n2, n3, n4};
        Arrays.sort(arr);
        double min = arr[0];
        // 调整四个数 +1 是防止数值为0的情况 尽可能让所有方向都有一定概率试探
        n1 = n1 - min + Math.abs(commonReward);
        n2 = n2 - min + Math.abs(commonReward);
        n3 = n3 - min + Math.abs(commonReward);
        n4 = n4 - min + Math.abs(commonReward);
        // 根据四个数计算概率
        double sum = n1 + n2 + n3 + n4;
        rateArr[0] = (double) n1 / (double) sum;
        rateArr[1] = (double) n2 / (double) sum;
        rateArr[2] = (double) n3 / (double) sum;
        rateArr[3] = (double) n4 / (double) sum;
        // 计算累积概率
        accumulateRateArr[0] = rateArr[0];
        accumulateRateArr[1] = rateArr[0] + rateArr[1];
        accumulateRateArr[2] = rateArr[0] + rateArr[1] + rateArr[2];
        accumulateRateArr[3] = rateArr[0] + rateArr[1] + rateArr[2] + rateArr[3];
        // 轮盘赌
        double r = new Random().nextDouble();
//        System.out.println("r = "+r+","+Arrays.toString(accumulateRateArr));

        // 根据轮盘赌随机数选择方向
        if (r <= accumulateRateArr[0]) {
//            System.out.println(Arrays.toString(new int[]{n1, n2, n3, n4})+" -> "+1);
            return 1;
        } else if (r <= accumulateRateArr[1]) {
//            System.out.println(Arrays.toString(new int[]{n1, n2, n3, n4})+" -> "+2);
            return 2;
        } else if (r <= accumulateRateArr[2]) {
//            System.out.println(Arrays.toString(new int[]{n1, n2, n3, n4})+" -> "+3);
            return 3;
        } else {
//            System.out.println(Arrays.toString(new int[]{n1, n2, n3, n4})+" -> "+4);
            return 4;
        }
    }

    // 移动指令 获取移动后的坐标 如果移动返回为负数 则为惩罚 否则则说明成功移动
    // 到终点则返回 Integer.MAX_VALUE
    public double move(int pos, int direction, int[][] map) {
        int colLen = map[0].length;
        // 首先将pos转化为二维坐标
        int j = pos % colLen;
        int i = (pos - j) / colLen;
        int tempI = 0, tempJ = 0;
        switch (direction) {
            case 1:
                // 上移
                tempI = i - 1;
                tempJ = j;
//                return -100000;
                break;
            case 2:
                // 下移
                tempI = i + 1;
                tempJ = j;
                break;
            case 3:
                // 左移
                tempI = i;
                tempJ = j - 1;
//                return -100000;
                break;
            case 4:
                // 右移
                tempI = i;
                tempJ = j + 1;
                break;
            default:
        }
        if (tempI >= map.length || tempJ >= map[0].length || tempI < 0 || tempJ < 0) {
            // 超出边界
            return beyondBoundaryReward;
        } else if (map[tempI][tempJ] == 1) {
            // 撞墙
            return hitWallReward;
        } else if (map[tempI][tempJ] == 3) {
            // 到终点了
            return Integer.MAX_VALUE;
        }
        // 走了之后 会把上一个位置变成墙 防止重复探索
        map[tempI][tempJ] = 1;
        return (tempI * colLen + tempJ);
    }

    // 绘制最佳路线
    public void plotBestPath(Canvas canvas, Instance instance) {

        int[][] map = instance.getMap();
        System.out.println("起点为:" + instance.getStartIndex());
        for (Integer[] integers : bestPathList) {
            System.out.print(Arrays.toString(integers));
        }
        System.out.println();
        for (int i = 0; i < bestPathList.size(); i++) {
            int pos = bestPathList.get(i)[0];
            int colLen = map[0].length;
            int y = pos % colLen;
            int x = (pos - y) / colLen;

            GraphicsContext gc = canvas.getGraphicsContext2D();
            gc.setFill(Color.GRAY);
            gc.fillRect(y * 20, x * 20, 20, 20);

            // 绘制文字
            gc.setFill(Color.BLACK);
            gc.setFont(new Font("微软雅黑", 15));
            gc.setTextAlign(TextAlignment.CENTER);
            gc.setTextBaseline(VPos.TOP);
            gc.fillText("" + (i + 1), y * 20 + 10, x * 20);
        }

        System.out.println("最少步数为:" + bestPathLen);
    }

    // 按照Q表进行移动 每次选择Q值最大的移动方向 并且绘图
    public void verification_Q(Canvas canvas, double[][] Q, Instance instance) {
        System.out.println("最终Q表为:");
        for (double[] q : Q) {
            System.out.println(Arrays.toString(q));
        }
        System.out.println("地图为:");
        int pos = instance.getStartIndex();
        int[][] map = instance.getMap();
        for (int i = 0; i < map.length; i++) {
            System.out.println(Arrays.toString(map[i]));
        }
        int n = 0;
        while (pos >= 0 && pos != Integer.MAX_VALUE && n < 500) {

            int colLen = map[0].length;
            int j = pos % colLen;
            int i = (pos - j) / colLen;

            GraphicsContext gc = canvas.getGraphicsContext2D();
            gc.setFill(Color.GRAY);
            gc.fillRect(j * 20, i * 20, 20, 20);

            // 绘制文字
            gc.setFill(Color.BLACK);
            gc.setFont(new Font("微软雅黑", 15));
            gc.setTextAlign(TextAlignment.CENTER);
            gc.setTextBaseline(VPos.TOP);
            gc.fillText("" + (n + 1), j * 20 + 10, i * 20);

            pos = moveToMax(pos, Q[pos].clone(), map);

            n++;

        }
        System.out.println("步数为:" + n);
    }

    // 向最小化惩罚的方向移动
    public int moveToMax(int pos, double[] Qs, int[][] map) {
        int colLen = map[0].length;
        // 首先将pos转化为二维坐标
        int j = pos % colLen;
        int i = (pos - j) / colLen;
        int tempI = 0, tempJ = 0;
        // 获取最大化奖励的方向
        int maxIndex = -1;
        double maxValue = -1 * Double.MAX_VALUE;
        for (int k = 0; k < Qs.length; k++) {
            // &&map[i][j]!=1
            if (Qs[k] > maxValue && map[i][j] != 1) {
                maxValue = Qs[k];
                maxIndex = k;
            }
        }
        // 移动坐标
        switch (maxIndex + 1) {
            case 1:
                // 上移
                tempI = i - 1;
                tempJ = j;
                break;
            case 2:
                // 下移
                tempI = i + 1;
                tempJ = j;
                break;
            case 3:
                // 左移
                tempI = i;
                tempJ = j - 1;
                break;
            case 4:
                // 右移
                tempI = i;
                tempJ = j + 1;
                break;
            default:
        }
        System.out.println(pos + "  " + Arrays.toString(Qs) + "  maxIndex+1 = " + (maxIndex + 1));
        if (tempI >= map.length || tempJ >= map[0].length || tempI < 0 || tempJ < 0) {
            // 超出边界
            return -100;
        } else if (map[tempI][tempJ] == 1) {
            // 撞墙
            return -100;
        } else if (map[tempI][tempJ] == 3) {
            // 到终点了
            return Integer.MAX_VALUE;
        }
        //
        map[i][j] = 1;
        return (tempI * colLen + tempJ);
    }

    // 复制map
    public int[][] copyMap(int[][] map) {
        int[][] tempMap = new int[map.length][map[0].length];
        for (int i = 0; i < tempMap.length; i++) {
            tempMap[i] = map[i].clone();
        }
        return tempMap;
    }

    public static void main(String[] args) {
        launch(args);
    }

}

3.运行结果展示

搜索到最优路径为:32 在这里插入图片描述 尝试不同位置的起点和终点: 在这里插入图片描述 尝试加入多个终点: 在这里插入图片描述 在这里插入图片描述


以上就是完整代码啦!如果觉得感兴趣,欢迎点赞+关注,以后会继续更新相关方面的文章!