深度强化学习(DRL)算法 5 —— Twin Delayed Deep Deterministic Policy Gradient (TD3)

58 阅读2分钟

回顾

深度强化学习(DRL)算法 4 —— Deep Deterministic Policy Gradient (DDPG) - 掘金 (juejin.cn) 文章里提到了 DDPG 存在的三个缺点。 高估问题target Q 网络和参数更新顺序问题DPG 的方式增大了方差(可以采取类似期望 Sarsa 的方式减小方差)

算法描述

我们先来看问题 1 和 3,TD3 提出了 Clipped Double Q-Learning 和 Target Policy Smoothing Regularization 来解决,针对问题 2,提出了 Delayed Policy Updates。

Clipped Double Q-Learning

我们知道 DDPG 里面的 target Q 网络是:

r+γNNtarget(st+1,NNaction_target(st+1,θ),w)r + \gamma NN_{target}(s_{t+1}, NN_{action\_target}(s_{t+1}, \theta), w) NNtargetNN_{target} 预测是有可能高估的,一个网络可能高估,那么我们可以来两个(更多个,理论上效果更好),取最小的那个作为预测,就可以一定程度上缓解高估问题。 target Q1: r+γNNtargetq1(st+1,NNaction_target(st+1,θ),w1)r + \gamma NN_{target_{q1}}(s_{t+1}, NN_{action\_target}(s_{t+1}, \theta), w1)

target Q2: r+γNNtargetq2(st+1,NNaction_target(st+1,θ),w2)r + \gamma NN_{target_{q2}}(s_{t+1}, NN_{action\_target}(s_{t+1}, \theta), w2)

min(target Q1, target Q2) 作为最终的 target_Q。

Target Policy Smoothing Regularization

之前的文章 深度强化学习(DRL)算法 附录 3 —— 蒙特卡洛方法(MC)和时序差分(TD) - 掘金 (juejin.cn) 介绍了期望 Sarsa 的方法,是一种比 Sarsa(DDPG 是本质上是 Sarsa)更稳定的方法,因为取了期望,相当于考虑了更多的动作,所以也可以用到 DDPG 里。

lossq=(r+γNNtarget(st+1,NNaction_target(st+1,θ),w)tdtargetNN(st,NNaction(st+1,θ)+ϵ,w))2loss_q = (\underbrace{r + \gamma NN_{target}(s_{t+1}, NN_{action\_target}(s_{t+1}, \theta), w)}_{td-target} - NN(s_{t}, NN_{action}(s_{t+1}, \theta') + \epsilon , w'))^{2}

lossq 实际上是平均平方误差(MSE,之前只是以一个采样为例子,所以没求平均),所以引入探索性后,就类似期望 Sarsa 那样。 我们知道 DDPG 动作网络也引入了随机扰动 σ\sigma

NN(st,NNaction(st+1,θ)+σ)NN(s_{t}, NN_{action}(s_{t+1}, \theta') + \sigma)

我们可以直接把这个扰动加到 target_Q 里面,来增强探索性吗? 当然不可以,这两个随机扰动的目的不同: 前者是为了经验回放采样数据的时候,获得更多的探索性,所以可以随心所欲的探索。 而后者是为了减小 target_Q 预测的方差,所以不应该随心所欲的探索,只是在当前状态动作附近的探索,因而应该给 σ\sigma 加上限制:

ϵclip(N(0,σ),c,c)\epsilon' \sim \operatorname{clip}(\mathcal{N}(0, \sigma),-c, c)

所以更新为: target Q1: r+γNNtargetq1(st+1,NNaction_target(st+1,θ)+ϵ,w1)r + \gamma NN_{target_{q1}}(s_{t+1}, NN_{action\_target}(s_{t+1}, \theta) + \epsilon', w1)

targetQ2: r+γNNtargetq2(st+1,NNaction_target(st+1,θ)+ϵ,w2)r + \gamma NN_{target_{q2}}(s_{t+1}, NN_{action\_target}(s_{t+1}, \theta) + \epsilon', w2)

lossq1=(r+γNNtargetq1(st+1,NNaction_target(st+1,θ),w1)tdtargetNN(st,NNaction(st+1,θ)+ϵ,w1))2loss_{q1} = (\underbrace{r + \gamma NN_{target_{q1}}(s_{t+1}, NN_{action\_target}(s_{t+1}, \theta), w1)}_{td-target} - NN(s_{t}, NN_{action}(s_{t+1}, \theta') + \epsilon , w1'))^{2}

lossq2=(r+γNNtargetq2(st+1,NNaction_target(st+1,θ),w2)tdtargetNN(st,NNaction(st+1,θ)+ϵ,w2))2loss_{q2} = (\underbrace{r + \gamma NN_{target_{q2}}(s_{t+1}, NN_{action\_target}(s_{t+1}, \theta), w2)}_{td-target} - NN(s_{t}, NN_{action}(s_{t+1}, \theta') + \epsilon , w2'))^{2}

lossq=lossq1+lossq2loss_q = loss_{q1} + loss_{q2}

Delayed Policy Updates

这个很简单,看名字就能猜出来干了什么,就是 ω\omegaθ\theta 不同时更新,且 θ\theta 更新在 ω\omega 更新几轮后,如果同时更新,相当于每次更新后,对相同的 state 产生了不同的 q 值,想当于引入新的残差 q(s, a) - q(s,a')。所以为了减小这种误差,θ\theta 更新在 ω\omega 更新几轮后。所以 TD3 的软更新变成了(这里和文章不一样,文章说的 Delayed Policy Updates 是包括参数软更新也延迟了,我这里参数软更新没有延迟,需要做实验验证一下,理论上应该差距不大) :

ω1τω1+1τω1\omega1 \leftarrow \tau\omega1' + (1-\tau)\omega1

ω2τω2+1τω2\omega2 \leftarrow \tau\omega2' + (1-\tau)\omega2

延迟更新:

θτθ+1τθ\theta \leftarrow \tau\theta' + (1-\tau)\theta

缺点

ϵ\epsilon是高斯分布,和 Q 没有关系,而且 Q 使用的是确定性策略,实际上针对连续动作空间,采用随机策略,并且随机扰动和 Q 相关,才是更合理的,因为不但增加了探索性,而且探索性和 Q 的大小是相关的,Q 越大随机策略的探索性应该越小。

改进

针对这一缺点,下篇文章对 SAC 进行介绍,感谢阅读。

参考

arxiv.org/pdf/1802.09…