对比学习系列(四)--- BYOL

868 阅读6分钟

BYOL

MoCo,SimCLR等对比学习方法都依赖于负样本,BYOL不需要负样本也能在ImageNet上取得74.3%的top-1分类准确率。BYOL使用两个神经网络,online网络和targets网络。online网络的参数设为θ\theta,由三部分组成,编码器fθf_{\theta},projectorgθg_{\theta}和projectorqθq_{\theta};target网络与online网络有相同的架构,但是拥有不同的参数ξ\xi。target网络提供回归目标训练online网络,但是target网络的参数ξ\xi采用EMA(指数滑动平均)公式(与MoCo的动量更新公式相似)进行更新。公式为ξτξ+(1τ)θ\xi \leftarrow \tau \xi + \left( 1 - \tau \right)\theta

1645768169183.png

对于位于图像集DD中的一张图像xx,使用两种数据增广方式对图像xx进行扩增得到两种扩增视图vvv{v}',然后第一个扩增视图vv经过online网络得到表示yθ=fθ(v)y_{\theta}=f_{\theta}\left( v \right)和映射zθ=gθ(v)z_{\theta} = g_{\theta}\left( v \right),将第二个扩增视图v{v}'喂入target网络也分别得到表示yξ=fξ(v){y}'_{\xi}=f_{\xi}\left( {v}' \right)和映射输出zξ=gξ(v){z}'_{\xi} = g_{\xi}\left( {v}' \right)。online网络比target网络多了预测输出qθ(zθ)q_{\theta}\left( z_{\theta} \right),分别对qθ(zθ)q_{\theta}\left( z_{\theta} \right)zξ{z}'_{\xi}进行l2正则化得到qˉθ(zθ)=qθ(zθ)/qθ(zθ)2\bar{q}_{\theta}\left( z_{\theta} \right) = q_{\theta}\left( z_{\theta} \right) / \| q_{\theta}\left( z_{\theta} \right) \|_{2}zˉξ=zξ/zξ2{\bar{z}}'_{\xi} = {z}'_{\xi} / \| {z}'_{\xi} \|_{2},最后定义正则化的预测输出qˉθ\bar{q}_{\theta}和target网络的映射输出zˉξ{\bar{z}}'_{\xi} 的MSE函数作为损失函数。

Lθ,ξ=qˉθ(zθ)zˉξ22=22qθ(zθ),zˉξqθ(zθ)2zξ2L_{\theta, \xi} = \| \bar{q}_{\theta}\left( z_{\theta} \right) - {\bar{z}}'_{\xi} \|_{2}^{2} = 2 - 2 \cdot \frac{\left \langle q_{\theta}\left( z_{\theta} \right), {\bar{z}}'_{\xi} \right \rangle}{\| q_{\theta}\left( z_{\theta} \right) \|_{2} \cdot \| {z}'_{\xi} \|_{2} }

通过将v{v}'喂入online网络和将vv喂入targets网络得到对称损失函数L~θ,ξ\tilde{L}_{\theta, \xi}, BYOL总的损失函数为Lθ,ξBYOL=Lθ,ξ+L~θ,ξL_{\theta, \xi}^{BYOL} = L_{\theta, \xi} + \tilde{L}_{\theta, \xi}。模型训练过程中只有参数θ\theta根据梯度进行更新,参数ξ\xi依据EMA公式进行更新。

1646361045852.png

算法对比

SimCLR算法将某一输入图像xx的两个不同views喂入相同的网络得到两个projection输出zzz{z}',SimCLR对比损失的宗旨是:最大化输入图像xx的两个projection输出zzz{z}'之间的相似性,同时最小化与同一batch中的其他图像的projection之间的相似性。MoCo V2借鉴了SimCLR中的数据增广和projection层,减少batch size并提高性能。不同于SimCLR,MoCo V2中应用了动量更新,编码器分成了online 网络和momentum网络,online网络训练时通过SGD进行更新,momentum网络则是基于online网络参数的滑动平均进行更新。BYOL借鉴了MoCo的momentum网络,添加了一个MLP层qθq_{\theta}zz预测z{z}',p=qθ(z)p = q_{\theta}\left( z \right)。BYOL用正则化后的目标z{z}'和正则化后映射输出pp的L2损失函数替代了对比损失函数。BYOL中不需要负样本,所以BYOL不需要memory bank。

  • SimCLR中的MLP,每个线性层后都带有BN
  • MoCo V2中的MLP,不使用BN
  • BYOL中的MLP,只在第一个线性层之后带有BN

1646640393440.png

1648023718845.png

1646640529283.png

由于这三个算法中的MLP中BN的使用不同,导致使用MoCo中的MLP去训练BYOL时,模型的表现和模型随机的效果接近。参考[3]的博客作者发现:即使在损失函数中没有负样本,BN也隐含地在BYOL中引入对比学习。显然,这和BYOL算法不使用负样本的思想相悖。

为了驳回BYOL需要BN防止模型坍塌,是因为BN提供隐含的负样本,这一假设,BYOL论文作者又做了大量实验,探索在SimCLR和BYOL不同组件(编码器,SimCLR和BYOL的projector和BYOL的predictor)中使用不同正则化(BN or LN)和移除正则化的影响。实验结果如下图所示。从下图中可以看出,删除BYOL中的所有BN实例,BYOL的表现和随机模型效果差不多。这对BYOL是独有的,因为SimCLR在相同的情况下,表现良好。然而,仅在BYOL中的编码中加入BN实例就足以使得BYOL取得较好的性能。

当使用不用统计信息的LN替换BN时,BYOL模型坍塌,但是SimCLR却表现良好,BYOL和SimCLR模型性能的差异似乎更支持BYOL需要BN中提供的隐含负样本防止模型坍塌这一假设。然而,BN似乎主要作用于编码器。已知标准初始化会导致条件变差,BYOL在创建自己的目标时,更容易受到不正确初始化的影响,因此,论文作者假设BN在BYOL中的主要作用是补偿不正确的初始化,而不是BN隐含提供负样本的假设。

1648027155149.png

为了验证上述的假设,作者设计了不带BN的适当的初始化实验模拟BN对初始规模和动态训练的影响。1000个epoch之后,初始化实验的线性验证取得了65.7%的top-1准确率,相比于baseline模型的74.3%top-1准确率有所下降,但是也能证明BYOL在不使用BN的情况下,可以学习到非坍塌表示,而且也能证明BN的另一个作用是提供更好地初始规模和训练动态。使用GN代替所有的BN,并通过WS方案对卷积和线性参数进行权重标准化。由于GN和WS都不计算批处理统计信息,所以使用GN+WS的BYOL无法处理批处理中的元素,因此它同样无法实现批处理隐式对比机制。实验证明,没有BN的BYOL算法也能保持其大部分性能。

1648089104919.png

参考

  1. Bootstrap Your Own Latent A New Approach to Self-Supervised Learning
  2. BYOL works even without batch statistics
  3. Understanding Self-Supervised and Contrastive Learning with "Bootstrap Your Own Latent" (BYOL)

附录

指数滑动平均

EMA(Exponential Moving Average)指数滑动平均,是一种给与最近数据更高权重的平均方法,也被称为权重滑动平均(weighted moving average)。在深度学习优化过程中,θt\theta_{t}tt时刻的模型参数,vtv_{t}tt时刻的影子参数,在模型训练过程中,θt\theta_{t}的参数的更新与梯度下降有关θt=θn1αθJ(θ)\theta_{t} = \theta_{n-1} - \alpha \cdot \triangledown_{\theta} J\left( \theta\right)(其中α\alpha为学习率,J(θ)J\left( \theta\right)为损失函数关于参数θ\theta的偏导数),vtv_{t}的更新与θt\theta_{t}有关,公式如下:vt=βvt1+(1β)θtv_{t}= \beta \cdot v_{t-1} + \left( 1 - \beta \right) \cdot \theta_{t}

为了简单起见,将学习率设为1,t1t-1时刻的梯度为gt1g_{t-1},设θt=θt1gt1\theta_{t} = \theta_{t-1} - g_{t-1},对其进行变换,如下所示:

θn=θt1gt1=θt2gt1gt2  =θ1i=1n1gi\begin{matrix} \theta_{n} &= \theta_{t-1} - g_{t-1} \quad \quad \quad \\ &= \theta_{t-2} - g_{t-1} - g_{t-2} \; \\ & = \theta_{1} - \sum_{i=1}^{n-1}g_{i} \quad \quad \quad \end{matrix}

同样对影子权重vtv_{t}进行变换,如下所示:

vt=βvt1+(1β)θt  =β(βvt2(1β)θt1)+(1β)θt=βnv0+(1β)(θn+βθn1++βn1θn)=θ1i=1n1(1βni)gi\begin{matrix} v_{t} &= \beta \cdot v_{t-1} + \left( 1 - \beta \right) \cdot \theta_{t} \quad \quad \quad \quad \quad \quad \quad \quad \quad \; \\ &= \beta\left( \beta \cdot v_{t-2}\left( 1 - \beta \right) \cdot \theta_{t-1} \right) + \left( 1 - \beta \right) \cdot \theta_{t} \quad \quad \\ &= \beta^{n}v_{0} + \left( 1 - \beta \right) \left( \theta_{n} + \beta \theta_{n-1} + \cdots + \beta^{n-1} \theta_{n} \right) \\ & = \theta_{1} - \sum_{i=1}^{n-1} \left( 1- \beta^{n-i}\right)g_{i} \quad \quad \quad \quad \quad \quad \quad \quad \quad \end{matrix}

对于tt时刻的参数θt\theta_{t}vtv_{t}vtv_{t}θt\theta_{t}更新地更缓慢。

Group normalization

对于维度为(N,H,W,C)\left( N, H, W, C \right)的张量XX,GN首先将通道切分成GG个大小相同的组,然后在大小为(1,H,W,C/G)\left( 1, H, W, C/G \right)不相交的切片上计算平均值和标准差从而规范化激活。当G=1G=1时,GN相当于LN(Layer Normalization);当G=CG=C时,GN相当于IN(Instance Normalization)。重要的是,GN在每个批次元素上独立运行,因而它不依赖批次统计信息。

Weight standardization

WS使用权重统计信息规范化每个激活对应的权重。权重矩阵WW中每一行被正则化得到一个新的权重矩阵W^\hat{W}。只有归一化权重W^\hat{W}用于计算卷积输出,但是损失根据非标准化权重WW进行区分。

W^i,j=Wi,jμiσi,withμi=1Ij=1IWi,jandσi=ε+1Ij=1I(Wi,jμ)2\begin{matrix} \hat{W}_{i,j} = \frac{W_{i,j} - \mu_{i}}{\sigma_{i}} , \\ with \quad \mu_{i} = \frac{1}{I} \sum_{j=1}^{I} W_{i,j} \\ and \quad \sigma_{i}= \sqrt{ \varepsilon + \frac{1}{I} \sum_{j=1}^{I} \left( W_{i,j} - \mu \right)^{2}} \end{matrix}

其中II是输入维度,ε=104\varepsilon = 10^{-4}