前言
相信大多数小伙伴应该和我一样,之前在学习强化学习的时候,一直用的是Python,但奈何只会用java写后端,对Python的一些后端框架还不太熟悉,(以后要集成到网站上就惨了),于是就想用Java实现一下强化学习中的Q-Learning算法,来搜索求解人工智能领域较热门的问题---迷宫寻路问题。(避免以后要用的时候来不及写)。
一、强化学习简介
强化学习是机器学习中的一个领域,强调如何基于环境而行动,以取得最大化的预期利益。其灵感来源于心理学中的行为主义理论,即有机体如何在环境给予的奖励或惩罚的刺激下,逐步形成对刺激的预期,产生能获得最大利益的习惯性行为。
- 1956年Bellman提出了动态规划方法。
- 1977年Werbos提出只适应动态规划算法。
- 1988年sutton提出时间差分算法。
- 1992年Watkins 提出Q-learning 算法。
- 1994年rummery 提出Saras算法。
- 1996年Bersekas提出解决随机过程中优化控制的神经动态规划方法。
- 2006年Kocsis提出了置信上限树算法。
- 2009年kewis提出反馈控制只适应动态规划算法。
- 2014年silver提出确定性策略梯度(Policy Gradients)算法。
- 2015年Google-deepmind 提出Deep-Q-Network算法。
可见,强化学习已经发展了几十年,并不是一门新的技术。在2016年,AlphaGo击败李世石之后,融合了深度学习的强化学习技术大放异彩,成为这两年最火的技术之一。总结来说,强化学习就是一个古老而又时尚的技术。
二、强化学习方法汇总
1.Modelfree 和 Modelbased
2.基于概率 和 基于价值
3.回合更新 和 单步更新
4.在线学习 和 离线学习
三、迷宫寻路问题简介
迷宫寻路问题是人工智能中的有趣问题,给定一个M行N列的迷宫图,其中 "0"表示可通路,"1"表示障碍物,无法通行,"2"表示起点,"3"表示终点。在迷宫中只允许在水平或上下四个方向的通路上行走,走过的位置不能重复走,需要搜索出从起点到终点尽量短的路径。
地图可视化如下图所示:绿色代表道路,黑色代表墙壁,粉色代表起点,蓝色代表终点
四、强化学习求解思路
迷宫寻路问题中常用的搜索算法有A*算法、递归法、启发式算法等。但是,不管选用何种算法,如何表示状态空间和搜索路径都是寻路问题的重点。 本文基于Q-Learning算法(什么是Q-Learning?),采用Q-Table进行状态空间和动作收获值的存储。 Q-Learning又分为两大类,一种是蒙特卡罗学习,一种是时序差分学习。依旧套用围棋的例子,简单来讲蒙特卡罗学习是下完一盘棋之后,批处理更新一次分值。时序差分学习是每下一步棋,就更新一次分值。理论上讲,时序差分学习更好一些,时间性能比较好,学习的比较快。故小编采用的就是时序差分学习模式。
五、java代码
1.辅助类-Instance
用于存储地图信息,并提供初始化Q表和查询起点和终点索引的方法
package RL_Path_Planning;
import lombok.Data;
/*
** Create by: WSKH0929
Date:2021-11-10
Time:21:24
*/
@Data
public class Instance {
// 0是路 1是墙 2是起点 3是终点
private int[][] map = new int[][]{
{1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0},
{0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0},
{0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0},
{0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0},
{0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0},
{0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0},
{0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0},
{0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0},
{0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0},
{0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0},
{0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0},
{0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0},
{0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0},
{0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0},
{0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0},
{0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1},
{0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1},
{0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1},
{0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1},
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 3, 0, 0, 0, 0, 0, 0, 0},
};
private double[][] Q;
private int startIndex, endIndex;
// 获取起点和终点的序号
public void initStart_And_End_Index() {
for (int i = 0; i < map.length; i++) {
for (int j = 0; j < map[i].length; j++) {
if (map[i][j] == 2) {
startIndex = i * map[0].length + j;
} else if (map[i][j] == 3) {
endIndex = i * map[0].length + j;
}
}
}
}
// 初始化Q表
public void initQ() {
// 对应四种运动可能 上下左右
Q = new double[map.length * map[0].length][4];
for (int i = 0; i < Q.length; i++) {
for (int j = 0; j < Q[i].length; j++) {
Q[i][j] = 0.001;
}
}
}
}
2.主类-Application
这里用到了第三方库JavaFX,对地图和最终搜索到的路线进行了可视化展示。
package RL_Path_Planning;
import javafx.event.ActionEvent;
import javafx.event.EventHandler;
import javafx.geometry.VPos;
import javafx.scene.Scene;
import javafx.scene.canvas.Canvas;
import javafx.scene.canvas.GraphicsContext;
import javafx.scene.control.Alert;
import javafx.scene.control.Button;
import javafx.scene.layout.AnchorPane;
import javafx.scene.paint.Color;
import javafx.scene.text.Font;
import javafx.scene.text.TextAlignment;
import javafx.stage.Stage;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
/*
** Create by: WSKH0929
Date:2021-11-10
Time:21:09
*/
public class Application extends javafx.application.Application {
// 超出边界
public double beyondBoundaryReward = -1000000;
// 撞墙
public double hitWallReward = -1000000;
// 普通走路
public double commonReward = -0.00001;
// 到达终点
public double reachEndReward = 10000;
// 远离终点
public double awayEndReward = -0.0001;
// 奖励获取率
public double alpha = 0.6;
// reward衰减率 怕死系数 越接近1越怕死
public double gamma_pasi = 0.00001;
// reward衰减率 贪心系数 越接近1越贪心
public double gamma_tanxin = 0.00001;
// 设置邻域探索次数
public int max_N = 400;
// 学习迭代次数 999999 999999
public int T = 50000;
//
public int bestPathLen = Integer.MAX_VALUE;
// 记录最佳路线
List<Integer[]> bestPathList = new ArrayList<>();
// 保存每次迭代的路线
List<Integer[]> pathList = new ArrayList<>();
@Override
public void start(Stage primaryStage) throws Exception {
AnchorPane pane = new AnchorPane();
Instance instance = new Instance();
instance.initQ(); // 初始化Q表
instance.initStart_And_End_Index(); // 初始化起点终点索引
Canvas canvas = initCanvas(instance.getMap());
pane.getChildren().add(canvas);
// 开始学习
learn(instance, canvas);
primaryStage.setTitle("强化学习路径规划");
primaryStage.setScene(new Scene(pane, 600, 600, Color.YELLOW));
primaryStage.show();
}
// 绘制初始地图
public Canvas initCanvas(int[][] map) {
Canvas canvas = new Canvas(400, 400);
canvas.relocate(100, 100);
for (int i = 0; i < map.length; i++) {
for (int j = 0; j < map[i].length; j++) {
int m = map[i][j];
GraphicsContext gc = canvas.getGraphicsContext2D();
if (m == 0) {
gc.setFill(Color.GREEN);
} else if (m == 1) {
gc.setFill(Color.BLACK);
} else if (m == 2) {
gc.setFill(Color.PINK);
} else if (m == 3) {
gc.setFill(Color.AQUA);
}
gc.fillRect(j * 20, i * 20, 20, 20);
}
}
return canvas;
}
// 学习过程
public void learn(Instance instance, Canvas canvas) {
// 获取map
int[][] map = copyMap(instance.getMap());
// 获取Q表
double[][] Q = instance.getQ();
for (int i = 0; i < T; i++) {
// System.out.println(i);
// 设置当前点所在位置 即起点
int pos = instance.getStartIndex();
// System.out.println("起点是: "+pos);
// 循环探索移动
int n = 0;
int[][] cloneMap = copyMap(map);
pathList = new ArrayList<>();
while (pos >= 0 && pos != Integer.MAX_VALUE && n < max_N) {
int colLen = cloneMap[0].length;
int x = pos % colLen;
int y = (pos - x) / colLen;
double[] Qs = Q[pos];
int next = next(Qs[0], Qs[1], Qs[2], Qs[3]);
double tempPos = move(pos, next, cloneMap);
// 找到Q最大的值
double[] clone = Qs.clone();
Arrays.sort(clone);
double max_q = clone[clone.length - 1];
// 更改Q表
if (tempPos == Integer.MAX_VALUE) {
if (true) {
System.out.println("迭代数:" + i + ",PathLen = " + (n + 1));
}
// Q[pos][next-1] += ((1-alpha)*Q[pos][next-1]+alpha*(reachEndReward+gamma*max_q));
Q[pos][next - 1] += reachEndReward;
pathList.add(new Integer[]{pos, next - 1});
// 按照路线更新路线Q值 衰减率为gamma
if (n + 1 < bestPathLen || false) {
bestPathList = new ArrayList<>(pathList);
bestPathLen = n + 1;
// System.out.println("迭代数:"+i+",bestPathLen"+bestPathLen);
double reward = reachEndReward * gamma_tanxin / (double) (n + 1);
for (int k = pathList.size() - 1; k >= 0; k--) {
Integer[] integers = pathList.get(k);
Q[integers[0]][integers[1]] += reward;
reward = reachEndReward * gamma_tanxin;
}
}
} else if (tempPos >= 0) {
// next!=4&&next!=1
if (next != 1 && next != 3) {
// Q[pos][next-1] += ((1-alpha)*Q[pos][next-1]+alpha*(commonReward+gamma*max_q));
Q[pos][next - 1] += commonReward;
} else {
// Q[pos][next-1] += ((1-alpha)*Q[pos][next-1]+alpha*(awayEndReward+gamma*max_q));
Q[pos][next - 1] += awayEndReward;
}
pathList.add(new Integer[]{pos, next - 1});
} else {
// Q[pos][next-1] += ((1-alpha)*Q[pos][next-1]+alpha*(tempPos+gamma*max_q));
Q[pos][next - 1] += tempPos;
// 按照路线更新路线Q值 衰减率为gamma
double reward = tempPos * gamma_pasi;
for (int k = pathList.size() - 1; k >= 0; k--) {
Integer[] integers = pathList.get(k);
Q[integers[0]][integers[1]] += reward;
reward = tempPos * gamma_pasi;
}
}
pos = (int) tempPos;
// System.out.println(pos);
n++;
}
}
// 学习完毕 绘图 验证
// verification_Q(canvas,Q,instance);
// 绘制最佳路线
plotBestPath(canvas, instance);
}
// 传入四个数,按照概率返回 上下左右--->1234
public int next(double n1, double n2, double n3, double n4) {
// 累积概率数组
double[] accumulateRateArr = new double[4];
// 概率数组
double[] rateArr = new double[4];
// 获取最小值
double[] arr = {n1, n2, n3, n4};
Arrays.sort(arr);
double min = arr[0];
// 调整四个数 +1 是防止数值为0的情况 尽可能让所有方向都有一定概率试探
n1 = n1 - min + Math.abs(commonReward);
n2 = n2 - min + Math.abs(commonReward);
n3 = n3 - min + Math.abs(commonReward);
n4 = n4 - min + Math.abs(commonReward);
// 根据四个数计算概率
double sum = n1 + n2 + n3 + n4;
rateArr[0] = (double) n1 / (double) sum;
rateArr[1] = (double) n2 / (double) sum;
rateArr[2] = (double) n3 / (double) sum;
rateArr[3] = (double) n4 / (double) sum;
// 计算累积概率
accumulateRateArr[0] = rateArr[0];
accumulateRateArr[1] = rateArr[0] + rateArr[1];
accumulateRateArr[2] = rateArr[0] + rateArr[1] + rateArr[2];
accumulateRateArr[3] = rateArr[0] + rateArr[1] + rateArr[2] + rateArr[3];
// 轮盘赌
double r = new Random().nextDouble();
// System.out.println("r = "+r+","+Arrays.toString(accumulateRateArr));
// 根据轮盘赌随机数选择方向
if (r <= accumulateRateArr[0]) {
// System.out.println(Arrays.toString(new int[]{n1, n2, n3, n4})+" -> "+1);
return 1;
} else if (r <= accumulateRateArr[1]) {
// System.out.println(Arrays.toString(new int[]{n1, n2, n3, n4})+" -> "+2);
return 2;
} else if (r <= accumulateRateArr[2]) {
// System.out.println(Arrays.toString(new int[]{n1, n2, n3, n4})+" -> "+3);
return 3;
} else {
// System.out.println(Arrays.toString(new int[]{n1, n2, n3, n4})+" -> "+4);
return 4;
}
}
// 移动指令 获取移动后的坐标 如果移动返回为负数 则为惩罚 否则则说明成功移动
// 到终点则返回 Integer.MAX_VALUE
public double move(int pos, int direction, int[][] map) {
int colLen = map[0].length;
// 首先将pos转化为二维坐标
int j = pos % colLen;
int i = (pos - j) / colLen;
int tempI = 0, tempJ = 0;
switch (direction) {
case 1:
// 上移
tempI = i - 1;
tempJ = j;
// return -100000;
break;
case 2:
// 下移
tempI = i + 1;
tempJ = j;
break;
case 3:
// 左移
tempI = i;
tempJ = j - 1;
// return -100000;
break;
case 4:
// 右移
tempI = i;
tempJ = j + 1;
break;
default:
}
if (tempI >= map.length || tempJ >= map[0].length || tempI < 0 || tempJ < 0) {
// 超出边界
return beyondBoundaryReward;
} else if (map[tempI][tempJ] == 1) {
// 撞墙
return hitWallReward;
} else if (map[tempI][tempJ] == 3) {
// 到终点了
return Integer.MAX_VALUE;
}
// 走了之后 会把上一个位置变成墙 防止重复探索
map[tempI][tempJ] = 1;
return (tempI * colLen + tempJ);
}
// 绘制最佳路线
public void plotBestPath(Canvas canvas, Instance instance) {
int[][] map = instance.getMap();
System.out.println("起点为:" + instance.getStartIndex());
for (Integer[] integers : bestPathList) {
System.out.print(Arrays.toString(integers));
}
System.out.println();
for (int i = 0; i < bestPathList.size(); i++) {
int pos = bestPathList.get(i)[0];
int colLen = map[0].length;
int y = pos % colLen;
int x = (pos - y) / colLen;
GraphicsContext gc = canvas.getGraphicsContext2D();
gc.setFill(Color.GRAY);
gc.fillRect(y * 20, x * 20, 20, 20);
// 绘制文字
gc.setFill(Color.BLACK);
gc.setFont(new Font("微软雅黑", 15));
gc.setTextAlign(TextAlignment.CENTER);
gc.setTextBaseline(VPos.TOP);
gc.fillText("" + (i + 1), y * 20 + 10, x * 20);
}
System.out.println("最少步数为:" + bestPathLen);
}
// 按照Q表进行移动 每次选择Q值最大的移动方向 并且绘图
public void verification_Q(Canvas canvas, double[][] Q, Instance instance) {
System.out.println("最终Q表为:");
for (double[] q : Q) {
System.out.println(Arrays.toString(q));
}
System.out.println("地图为:");
int pos = instance.getStartIndex();
int[][] map = instance.getMap();
for (int i = 0; i < map.length; i++) {
System.out.println(Arrays.toString(map[i]));
}
int n = 0;
while (pos >= 0 && pos != Integer.MAX_VALUE && n < 500) {
int colLen = map[0].length;
int j = pos % colLen;
int i = (pos - j) / colLen;
GraphicsContext gc = canvas.getGraphicsContext2D();
gc.setFill(Color.GRAY);
gc.fillRect(j * 20, i * 20, 20, 20);
// 绘制文字
gc.setFill(Color.BLACK);
gc.setFont(new Font("微软雅黑", 15));
gc.setTextAlign(TextAlignment.CENTER);
gc.setTextBaseline(VPos.TOP);
gc.fillText("" + (n + 1), j * 20 + 10, i * 20);
pos = moveToMax(pos, Q[pos].clone(), map);
n++;
}
System.out.println("步数为:" + n);
}
// 向最小化惩罚的方向移动
public int moveToMax(int pos, double[] Qs, int[][] map) {
int colLen = map[0].length;
// 首先将pos转化为二维坐标
int j = pos % colLen;
int i = (pos - j) / colLen;
int tempI = 0, tempJ = 0;
// 获取最大化奖励的方向
int maxIndex = -1;
double maxValue = -1 * Double.MAX_VALUE;
for (int k = 0; k < Qs.length; k++) {
// &&map[i][j]!=1
if (Qs[k] > maxValue && map[i][j] != 1) {
maxValue = Qs[k];
maxIndex = k;
}
}
// 移动坐标
switch (maxIndex + 1) {
case 1:
// 上移
tempI = i - 1;
tempJ = j;
break;
case 2:
// 下移
tempI = i + 1;
tempJ = j;
break;
case 3:
// 左移
tempI = i;
tempJ = j - 1;
break;
case 4:
// 右移
tempI = i;
tempJ = j + 1;
break;
default:
}
System.out.println(pos + " " + Arrays.toString(Qs) + " maxIndex+1 = " + (maxIndex + 1));
if (tempI >= map.length || tempJ >= map[0].length || tempI < 0 || tempJ < 0) {
// 超出边界
return -100;
} else if (map[tempI][tempJ] == 1) {
// 撞墙
return -100;
} else if (map[tempI][tempJ] == 3) {
// 到终点了
return Integer.MAX_VALUE;
}
//
map[i][j] = 1;
return (tempI * colLen + tempJ);
}
// 复制map
public int[][] copyMap(int[][] map) {
int[][] tempMap = new int[map.length][map[0].length];
for (int i = 0; i < tempMap.length; i++) {
tempMap[i] = map[i].clone();
}
return tempMap;
}
public static void main(String[] args) {
launch(args);
}
}
3.运行结果展示
搜索到最优路径为:32
尝试不同位置的起点和终点:
尝试加入多个终点:
以上就是完整代码啦!如果觉得感兴趣,欢迎点赞+关注,以后会继续更新相关方面的文章!