CURL 论文阅读

113 阅读6分钟

0.论文信息和个人感想

1.背景信息

  • 论文背景: 运用深度神经网络的表达能力和长期的分值分配能力,已经可以实现从高维观测 (如像素) 中执行复杂控制任务的智能代理。然而,从原始像素进行强化学习的样本效率较低,而从物理状态特征进行学习的样本效率要高得多。因此,提高从高维观测进行强化学习的样本效率对于发展智能自主代理至关重要。
  • 已有的方案: 为了解决深度强化学习算法的样本效率问题,已经提出了许多方法。这些方法可以分为两类研究方向:
    • (i) 在代理的感知观测上进行辅助任务。
    • (ii) 预测未来的世界模型。虽然这些方法在加速模型无关强化学习方法的学习过程方面取得了一定的成功,但构建显式的预测模型或引入额外的超参数等因素会增加算法的复杂性。
  • 本文的动机: 本研究的动机是通过学习有用的语义表示来提高样本效率。最近几年,自监督表示学习在语言和视觉领域取得了巨大进展。这些目标揭示的表示改善了任何监督学习系统的性能,特别是在下游任务的标记数据非常有限的情况下。本研究从对比预训练的成功中获得启示,提出了一种对比学习方法,旨在改善能够从在线交互中有效学习控制的代理。

图 1. Contrastive Unsupervised Representations for Reinforcement Learning (CURL) 结合了实例对比学习和强化学习。CURL通过使用对比损失来确保观测 oo 的数据增强版本 oqo_qoko_k 的嵌入匹配,从而训练一个视觉表示编码器。查询观测值 oqo_q 被视为锚点,而关键观测值 oko_k 包含正值和负值,它们都是从RL更新采样的小批量数据中构造的。这些键是用查询编码器的动量平均版本来进行编码。强化学习策略和 (或) 值函数建立在查询编码器之上,该编码器与对比和强化学习目标联合训练。CURL 是一个通用框架,可以插入到任何依赖于从高维图像中学习表示的强化学习算法中。

2.CURL 对应的方法论

  CURL 是将对比学习与强化学习相结合的通用框架。原则上,可以在 CURL 的管道中使用任何强化学习算法,无论是策略型还是非策略型。我们使用广泛采用的Soft Actor Critic (SAC) 进行连续控制基准 (DM Control) 和 Rainbow DQN 用于离散控制基准 (Atari)。下面对于他们与对比学习进行回顾。

2.1 SAC

  SAC 是一种非策略型的强化学习算法,它优化了随机策略以最大化预期轨迹回报。与其他最先进的端到端的强化学习算法一样,SAC 在从状态观察解决任务时是很有效的,但无法从像素中学习有效的策略。SAC 是一种 actor-critic method (演员-评论家方法) 方法,它学习策略 πψ\pi_{\psi} 和评论家 Qϕ1Q_{\phi_1}Qϕ2Q_{\phi_2}。通过最小化贝尔曼误差来学习参数 ϕi\phi_i :

L(ϕi,B)=EtB[(Qϕi(o,a)(r+γ(1d)T))2]\mathcal{L}\left(\phi_i, \mathcal{B}\right)=\mathbb{E}_{t \sim \mathcal{B}}\left[\left(Q_{\phi_i}(o, a)-(r+\gamma(1-d) \mathcal{T})\right)^2\right]

  其中 t=(o,a,o,r,d)t=\left(o, a, o^{\prime}, r, d\right) 是一个元组,观察 oo,行动 aa,奖励 rr 和完成信号 ddB\mathcal{B} 是重放缓冲区,T\mathcal{T} 是目标,其定义为 :

T=(mini=1,2Qϕi(o,a)αlogπψ(ao))\mathcal{T}=\left(\min _{i=1,2} Q_{\phi_i}^*\left(o^{\prime}, a^{\prime}\right)-\alpha \log \pi_\psi\left(a^{\prime} \mid o^{\prime}\right)\right)

  在上面的目标方程中,QϕiQ^*_{\phi_i} 表示 QϕiQ_{\phi_i} 的参数的指数移动平均 (EMA)。使用 EMA 的经验表明可以提高非策略型强化学习算法的训练稳定性。参数 α\alpha 是一个正熵系数,它决定了熵最大化对值函数优化的优先级。

  虽然评论家由 QϕiQ_{\phi_i} 给出,但演员从策略 πψ\pi_{\psi} 中采样动作,并通过最大化其动作的预期回报来训练,如下所示 :

L(ψ)=Eaπ[Qπ(o,a)αlogπψ(ao)]\mathcal{L}(\psi)=\mathbb{E}_{a \sim \pi}\left[Q^\pi(o, a)-\alpha \log \pi_\psi(a \mid o)\right]

  其中动作是从策略 aψ(o,ξ)tanh(μψ(o)+σψ(o)ξ)a_\psi(o, \xi) \sim \tanh \left(\mu_\psi(o)+\sigma_\psi(o) \odot \xi\right) 中随机采样的,ξN(0,I)\xi \sim \mathcal{N}(0, I) 是标准的归一化噪声向量。

2.2 Rainbow DQN

  Rainbow DQN 是基于 Q-learning 发展出的 DQN 的迭代版本,在此不多赘述,后面我会专门再补一篇进行相应的介绍。

2.3 对比学习

  CURL 的另一个关键组成部分是使用对比无监督学习学习高维数据的丰富表示的能力。对比学习可以理解为学习可微字典查找任务。给定一个查询 qq 和键 K={k0,k1,}\mathbb{K}=\left\{k_0, k_1, \ldots\right\} 和显式已知的 K\mathbb{K} 分区 (关于 q) P(K)=({k+},K\{k+})P(\mathbb{K})=\left(\left\{k_{+}\right\}, \mathbb{K} \backslash\left\{k_{+}\right\}\right),对比学习的目标是确保 qqk+k_+ 匹配相对多于 K\{k+}\mathbb{K} \backslash\left\{k_{+}\right\} 中的任何键。q,K,k+q, \mathbb{K}, k_+K\{k+}\mathbb{K} \backslash\left\{k_{+}\right\} 在对比学习方面也称为锚、目标、正样本、负样本。锚点和目标之间的相似性最好用点积 (qTkq^T k) 或双线性乘积 (qTWkq^T W k) 建模,同时欧几里得距离等其他形式也很常见。为了学习尊重这些相似关系的嵌入,提出了 InfoNCE 损失 :

Lq=logexp(qTWk+)exp(qTWk+)+i=0K1exp(qTWki)\mathcal{L}_q=\log \frac{\exp \left(q^T W k_{+}\right)}{\exp \left(q^T W k_{+}\right)+\sum_{i=0}^{K-1} \exp \left(q^T W k_i\right)}

  上面损失函数可以解释为标签为 $k_+# 的 K 路 softmax 分类器的对数损失。

3.CURL 对应的细节补充

图 2. CURL 的架构 : 从重放缓冲区中采样一批转换数据。然后对观测值进行两次数据扩充,形成查询和关键观测值,然后分别用查询编码器和关键编码器对它们进行编码。查询被传递给强化学习算法,而查询键对被传递给对比学习目标。在梯度更新步骤中,只更新查询编码器。关键编码器权重是与MoCo类似的查询权重的移动平均(EMA)。

3.1 架构概述

  CURL 使用与 SimCLR、MoCo 和 CPC 相似的实例判别。大多数深度强化学习架构使用一堆时间连续的帧作为输入进行操作。因此,与单个图像实例相比,跨帧堆栈执行实例区分。我们对类似于 MoCo 的目标使用动量编码程序,我们发现它对强化学习表现更好。最后,对于 InfoNCE 评分函数,我们使用类似于 CPC 的双线性内积,我们发现它比 MoCo 和 SimCLR 中使用的单位范数向量乘积效果更好。对比表示与强化学习算法联合训练,潜在代码接收来自对比目标和 Q 函数的梯度。架构的概述如图 2 所示。

3.2 区别的目标

  对比表示学习的一个关键组成部分是相对于锚点的正样本和负样本的选择。基于对比预测编码 (CPC) 的管道使用由精心选择的空间偏移分隔的图像块组用于锚点和正样本,而负样本来自图像中的其他补丁和其他图像。虽然补丁是将空间和时间区分结合在一起的强大方法,但它们引入了额外的超参数和架构设计选择,这些选择可能难以适应新的问题。SimCL 和 Moco 选择了一个更简单的设计,其中没有补丁提取。

  区分转换后的图像实例,而不是同一图像中的图像块,使用 InfoNCE 损失函数优化了更简单的实例区分目标,并且需要最少的体系结构调整。出于两个原因,最好在强化学习设置中选择更简单的判别目标。首先,考虑到强化学习算法的脆弱性,复杂的划分与区分可能会破坏强化学习目标的稳定。其次,由于强化学习算法是在动态生成的数据集上训练的,因此复杂的判别目标可能会显着增加挂钟训练时间。因此,CURL 使用实例判别而不是补丁判别。人们可以将 SimCLR 和 MoCo 等对比实例判别设置视为最大化图像与其增强版本之间的互信息。

3.3 Query-Key 对生成

  与图像设置中的实例区分类似,锚点和正样本观察是同一图像的两个不同增强,而负样本来自其他图像。CURL 主要依赖于随机裁剪数据增强,其中从原始渲染中裁剪出一个随机方块。

  强化学习和计算机视觉设置之间的一个显著区别是,无模型强化学习算法从像素操作的实例不仅仅是单个图像,而是一堆帧。例如,通常在 Atari 实验中提供 4 帧的堆栈和 DMControl 中的 3 帧的堆栈。这样,对帧堆栈执行实例区分允许 CURL 学习空间和时间判别特征。我们在批次上应用随机增强,但在每个帧堆栈中始终保持一致,以保留有关观察时间结构的信息。增强过程如图3所示。

图 3. 直观地说明了使用随机随机随机裁剪生成锚点及其正的过程。我们的裁剪纵横比为 0.84,即我们从 100×100100\times100 模拟渲染图像中裁剪 84×8484\times84 张图像。在堆栈中的所有帧中应用相同的随机裁剪坐标确保了时间一致的空间抖动。

3.4 相似性判别

  判别目标的另一个决定因素是用于衡量查询键对之间的一致性的内积。CURL 采用双线性内积 sim(q,k)=qTWk\operatorname{sim}(q, k)=q^T W k ,其中 $W 是学习参数矩阵。我们发现这种相似性度量优于最近最先进的对比学习方法中使用的归一化点积。

3.5 使用动量的目标编码

  在 CURL 中使用对比学习的动机是训练从高维像素映射到更多语义潜在的编码器。InfoNCE 是一种无监督损失函数,它学习编码器 fqf_qfkf_k,将原始锚点 (查询) xqx_q 和目标 (键) xkx_k 映射到潜在 q=fq(xq)q=f_q(x_q)k=fk(xk)k=f_k(x_k),在其上应用相似性点积。在锚点映射和目标映射之间共享相同的编码器是很常见的,即有 fq=fkf_q=f_k

  从将对比学习视为在高维实体上构建可微字典查找的角度来看,增加字典的大小并丰富否定集有助于学习丰富的表示。He等人提出了动量对比 (MoCo),它使用查询编码器 fqf_q 的指数移动平均 (动量平均) 版本来编码 K\mathbb{K} 中的键。给定由 θq\theta_q 参数化的 fqf_q 和由 θk\theta_k 参数化的 fkf_k,MoCo 执行更新 θk=mθk+(1m)θq\theta_k=m \theta_k+(1-m) \theta_q,并使用 SG(fk(xk))\operatorname{SG}\left(f_k\left(x_k\right)\right) 对任何目标 xkx_k 进行编码。

  CURL 在对比学习期间将帧堆栈实例判别与目标的动量编码相结合,强化学习在编码器特征之上执行。

3.6 CURL 对比学习的 pytorch 版本伪代码

# f_q, f_k: encoder networks for anchor 
# (query) and target (keys) respectively. 
# loader: minibatch sampler from ReplayBuffer 
# B-batch_size, C-channels, H,W-spatial_dims 
# x : shape : [B, C, H, W] 
# C = c * num_frames; c=3 (R/G/B) or 1 (gray)
# m: momentum, e.g. 0.95 
# z_dim: latent dimension 
f_k.params = f_q.params 
W = rand(z_dim, z_dim) # bilinear product. 
for x in loader: # load minibatch from buffer 
  x_q = aug(x) # random augmentation 
  x_k = aug(x) # different random augmentation 
  z_q = f_q.forward(x_q) 
  z_k = f_k.forward(x_k) 
  z_k = z_k.detach() # stop gradient 
  proj_k = matmul(W, z_k.T) # bilinear product 
  logits = matmul(z_q, proj_k) # B x B 
  # subtract max from logits for stability 
  logits = logits - max(logits, axis=1) 
  labels = arange(logits.shape[0]) 
  loss = CrossEntropyLoss(logits, labels) 
  loss.backward() update(f_q.params) # Adam 
  update(W) # Adam 
  f_k.params = m*f_k.params+(1-m)*f_q.params

4.实验

  略

5.读后感想

  感觉这篇工作尝试将对比学习引入强化学习中来对原有的强化学习进行改进,借鉴对比学习的几篇经典工作,其对于 oko_koqo_q 进行对比学习,并只使用 oqo_q 来进行强化学习以提升性能。