策略梯度 PPO
相关代码已开源至github:github.com/sdycodes/RL…
基本理论
回顾REINFORCE,使用神经网络拟合policy,优化目标为,并进行了相应的改良。
当时我们提出问题,就是更新的时候必须一个完整的episode才能做,因为它用来采样数据的policy和他更新的policy是一样的。如果中途更新了策略,那采样数据的策略和当前策略就不一致,这些数据就没法再用了。所以这样效率很低。
于是提问,能不能让更新的策略和采样不是一个策略?这就需要看看能否对优化目标进行一些变形。
第1行是期望的定义,第2行是恒等变换,第3行是期望定义的逆变换。注意加了角标和表示使用哪个policy在采样。 最后一行分母上带撇号表示在另外一个策略下进行采样的概率。但是要注意虽然policy不一样但分子和分母的和是一样的,因为第2行恒等变换的时候是一个。
这个东西很复杂,分子分母上有一堆概率项并且还有连乘。 直观看关于环境的概率项似乎都可以约分,这个确实合理,但保险起见一般还是会显式地假设:不同策略下环境的表现都差不多,所以直接把和分母上对应项约分了。
接下来优化目标变成了
现在只剩连乘比较麻烦了。这就体现出off-policy的一个优势了,那就是我们不需要非得把一个episode看作一个整体去考虑。因为现在采样的策略和我们更新的策略不一样,本来我们更新的时候带入的和就不是我们这个策略得到的,所以分batch或者单条数据优化policy没有什么问题。因此我们再一次简化优化目标
现在来看优化目标就很简单了,接下求梯度,这次和REINFORCE不一样,因为期望里面已经显式的有一个了,那就不用凑log了。
然后用均值代替期望,把REINFORCE那些梯度赋权重的思想引入进来,优化目标就变成
注意现在我们没有显式的写出对episode和每个step的求和过程,也是说明没有必要再一个episode看成整体去更新了。基线b写成了关于状态的函数,这个之前已经说明,很多时候会估计一个函数作为使用。PPO就是这样干的。两个的比值被称为重要性系数,采样过程叫重要性采样。
接下来考虑最后一个问题,那就是采样和优化的策略不一样,看上去只有好处没有坏处,那是不可能的,所以这么做一定有代价,代价是什么呢?这时候就要放出李宏毅机器学习课上使用的经典例子:
假设x的实际分布是满足p的,那么的期望应该是负数,因为p的绝大部分概率都集中左侧,左侧的f(x)为负数。但如果此时使用重要性采样,用q去采样,那采出来的结果肯定是正数,再加上重要性系数是个概率比值,也是正数,所以最后结果连符号都和原来不一样了,说明重要性采样有时候会非常不准确。
严谨的推导方差和直觉一致,结论都是如果p和q相差太大,那估计会很不准确。回到我们的优化目标上,就是不希望优化策略和采样策略差太多。一个很直观的方法是使用KL散度量化这个区别,并作为惩罚项加入优化目标中。
因为是最大化优化目标所以惩罚项是在原有基础上减。
这就是PPO1算法。还有另外一种思想来解决这个问题,称之为clamp,这种情况下优化目标为
看着很复杂,其实很简单。先看第2行,clamp函数长这样
相当于对重要性采样的大小做限制。
然后第1行再取min。如果是负数,此时clamp的下限发挥作用,那说明你想降低这个概率没问题,但你想把概率压得太低离采样策略差距太大,这我不鼓励,最多就是。
如果是正数时上限发挥作用,原理类似。
重新审视,如果用来代替,前面的可以看作某个特定动作的期望累积收益,那他也可以表示为,这个被称为advantage function。
也很直观说明了在某个状态下,采取动作能够比预计的平均收益高出多少。
实现细节
这个实现需要注意几个点
- 1 关于期望累积收益。现在一方面可以用蒙特卡洛的方法估计。和以前一样倒序求和。但是因为用critic网络了,所以也可以用时序差分的方法。引入时序差分的一个优势是不必等到episode结束再倒序求和,可以在收集一个batch就更新,计算累计收益的时候用critic网络估计后面的reward和。
- 2 采样模型的问题。采样模型一般是定期从policy复制而来。一条episode中,policy model可以更新多次,采样模型也可以随之更新。但要注意采样模型一旦更新,之前采样的数据就不能再用了,不然重要性系数计算就不对。
- 3 critic的更新问题,一条episode内可以更新多次,但是如果更新了,advantage的计算就不对了,所以应该在采样的时候提前计算好。
- 4 计算importance ratio需要求概率的比值,离散动作还比较容易,而对于连续动作,可以使用torch.distributional.Normal.log_prob来快速计算。
policy网络的结构、choose action和REINFORCE一样,不再赘述。 数据采集过程发生变化,主要是因为现在可以在episode内batch base更新了。 update函数的实现细节如下
def update(self, states: Union[List[np.array], np.array], actions: Union[List[np.array], np.array],
rewards: Union[List[np.array], np.array], next_states: Union[List[np.array], np.array],
dones: Union[List[np.array], np.array], infos: List[Any]) -> Any:
batch_size = len(states)
old_state_values = np.array([each[0] for each in infos])
sample_prob_dist = np.array([each[1] for each in infos])
for i in range(0, batch_size, 200):
left, right = i, min(i + batch_size, batch_size)
batch_rewards = torch.Tensor(np.array(rewards[left:right])).to(self.device).view(-1)
batch_states = torch.Tensor(np.array(states[left: right])).to(self.device)
batch_actions = torch.Tensor(np.array(actions[left: right])).view(-1, 1).to(self.device, dtype=torch.long)
# calculate advantage
batch_old_state_values = torch.Tensor(old_state_values[left: right]).to(self.device).view(-1)
advantages = batch_rewards - batch_old_state_values
# calculate critic loss
critic_loss = (self.critic(batch_states) - batch_rewards).square().mean()
# calculate actor loss
# calculate importance ratio
probs_policy_single = self.policy.log_probs(batch_states, batch_actions).exp().view(-1)
probs_policy = self.policy(batch_states)
batch_sample_prob_dist = torch.Tensor(sample_prob_dist[left: right]).to(self.device)
probs_sample = batch_sample_prob_dist.gather(dim=1, index=batch_actions).view(-1)
important_ratio = probs_policy_single / probs_sample
raw_actor_targets = (important_ratio * advantages)
if self.strategy == "ppo1":
actor_target = raw_actor_targets.mean() - \
self.beta * torch.kl_div(probs_policy.log(), batch_sample_prob_dist).sum(1).mean()
else:
# ppo2 continuous force to use ppo2
clamp_ratio = torch.clamp(important_ratio, 1 - self.eps, 1 + self.eps)
actor_target = torch.min(raw_actor_targets, clamp_ratio * advantages).mean()
loss: torch.Tensor = -actor_target + critic_loss
self.critic_optimizer.zero_grad()
self.policy_optimizer.zero_grad()
loss.backward()
self.policy_optimizer.step()
self.critic_optimizer.step()
# finish update reset sampler
self.reset_sampler()
进阶到连续动作
和REINFORCE也是一样的,不过要注意pendulum有个小坑,就是这个游戏其实是无限长的,无限长问题计算累积收益即使是最后一步,也需要加上critic的估计值。鉴于PPO能够在一条episode内多次更新的优势,PPO处理连续动作要好很多,可以看到pendulum可以很稳定的倒立。
相关代码已开源至github:github.com/sdycodes/RL…
实验结果如图