声明
对深度强化学习感兴趣的同学,强烈推荐王树森的视频课程!!
本篇笔记参考了另一篇博客和百度飞桨强化学习公开课,欢迎访问。
强化学习初印象
Reinforcement learning 是机器学习中的一个领域,强调如何基于环境而行动,以获得最大化的预期利益。
核心思想:智能体agent 在 环境environment 中学习,根据 环境的状态state(或者observation),执行 动作action,并根据环境的 反馈reward(奖励)来指导更好的动作。
强化学习的经典算法: Q-learning、Sarsa、DQN、Policy Gradient、A3C、DDPG、PPO
环境分类:离散型控制场景(输出控制可数)、连续控制场景(输出动作不可数)。
表格型方法求解RL
Sarsa
state-action-reward-state'-action' 的缩写。目的是学习特定的state下,特定action的价值Q,最终建立和优化一个Q表格,以state为行,action为列,根据与环境交互得到的reward来更新Q表格,更新公式为:
其中,α为学习率,γ为奖励性衰变系数
Q-learning
Q-learning也是采用Q表格的方式存储Q值(状态动作价值),决策部分与Sarsa是一样的,采用ε-greedy方式增加探索。
Q-learning跟Sarsa不一样的地方是更新Q表格的方式。
- Sarsa是on-policy的更新方式,先做出动作再更新。
- Q-learning是off-policy的更新方式,更新learn()时无需获取下一步实际做出的动作next_action,并假设下一步动作是取最大Q值的动作。
Q-learning的更新公式为:
其中,α为学习率,γ为奖励性衰变系数
Q-learning算法流程
案例
agent:
QL.py
import numpy as np
import pandas as pd
class QL:
def __init__(self, actions, learning_rate=0.05, reward_decay=0.9, e_greedy=0.9):
self.actions = actions #初始化可以进行的各种行为,传入为列表
self.lr = learning_rate #学习率,用于更新Q_table的值
self.gamma = reward_decay #当没有到达终点时,下一环境对当前环境的影响
self.epsilon = e_greedy #随机选择几率为1-e_greedy,当处于e_greedy内时,不随机选择。
self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64) #生成q_table,列向量为columns
def choose_action(self,observation):
self.check_observation(observation) #检测是否到达过这个点,如果没到达过,在Q表中增加这个节点
action_list = self.q_table.loc[observation,:] #取出当前observation所在的不同方向
if(np.random.uniform() < self.epsilon): #如果在epsilon几率内
action = np.random.choice(action_list[action_list == np.max(action_list)].index) #选出当前observation中Q值最大的方向
else:
action = np.random.choice(self.actions) #如果不在epsilon内,则随机选择一个动作
return action #返回应当做的action
def learn(self,observation_now,action,score,observation_after,done):
self.check_observation(observation_after) #检查是否存在下一环境对应的方向状态
q_predict = self.q_table.loc[observation_now,action] #获得当前状态下,当前所作动作所对应的预测得分
if done:
q_target = score #如果完成了则q_target为下一个环境的实际情况得分,本例子中此时score为1
else:
q_target = score + self.gamma * self.q_table.loc[observation_after, :].max() #如果未完成则取下一个环境若干个动作中的最大得分作为这个环境的价值传递给当前环境
#根据所处的当前环境对各个动作的预测得分和下一步的环境的实际情况更新当前环境的q表
self.q_table.loc[observation_now, action] += self.lr * (q_target - q_predict)
def check_observation(self,observation):
if observation not in self.q_table.index: #如果不存在
self.q_table = self.q_table.append( #则通过series函数生成新的一列
pd.Series(
[0]*len(self.actions),
index=self.actions,
name=observation,)
)
Envirnment
Env.py
import numpy as np
import pandas as pd
import time
class Env:
def __init__(self,column,maze_column):
self.column = column #表示地图的长度
self.maze_column = maze_column - 1 #宝藏所在的位置
self.x = 0 #初始化x
self.map = np.arange(column) #给予每个地点一个标号
self.count = 0 #用于技术一共走了多少步
def draw(self):
a = []
for j in range(self.column) : #更新图画
if j == self.x:
a.append('o')
elif j == self.maze_column:
a.append('m')
else:
a.append('_')
interaction = ''.join(a)
print('\r{}'.format(interaction),end = '')
def get_observation(self):
return self.map[self.x] #返回现在在所
def get_terminal(self):
if self.x == self.maze_column: #如果得到了宝藏,则返回已经完成
done = True
else:
done = False
return done
def update_place(self,action):
self.count += 1 #更新的时候表示已经走了一步
if action == 'right':
if self.x < self.column - 1:
self.x += 1
elif action == 'left': #left
if self.x > 0:
self.x -= 1
def get_target(self,action):
if action == 'right': #获得下一步的环境的实际情况
if self.x + 1 == self.maze_column:
score = 1
pre_done = True
else:
score = 0
pre_done = False
return self.map[self.x + 1],score,pre_done
elif action == 'left': #left
if self.x - 1 == self.maze_column:
score = 1
pre_done = Ture
else:
score = 0
pre_done = False
return self.map[self.x - 1],score,pre_done
def retry(self): #初始化
self.x = 0
self.count = 0
执行主程序
run_this.py
from Env import Env
from QL import QL
import time
LONG = 6 #总长度为6
MAZE_PLACE = 6 #宝藏在第六位
TIMES = 15 #进行15次迭代
people = QL(['left','right']) #生成QLearn主体的对象,包含left和right
site = Env(LONG,MAZE_PLACE) #生成测试环境
for episode in range(TIMES):
state = site.get_observation() #观察初始环境
site.draw() #生成图像
time.sleep(0.3) #暂停
while(1):
done = site.get_terminal() #判断当前环境是否到达最后
if done: #如果到达,则初始化
interaction = '\n第%s次世代,共使用步数:%s。'%(episode+1 ,site.count)
print(interaction)
# print(people.q_table)
site.retry()
time.sleep(2)
break
action = people.choose_action(state) #获得下一步方向
state_after,score,pre_done = site.get_target(action) #获得下一步的环境的实际情况
people.learn(state,action,score,state_after,pre_done) #根据所处的当前环境对各个动作的预测得分和下一步的环境的实际情况更新当前环境的q表
site.update_place(action) #更新位置
state = state_after #状态更新
site.draw() #更新画布
time.sleep(0.3)
print(people.q_table)