对比学习系列(三)---SimCLR

506 阅读2分钟

SimCLR通过隐藏空间的对比损失最大化相同数据在不同增广下的一致性来学习表达。SimCLR框架有四个主要的组件,分别是:数据增广,encode网络,projection head网络和对比学习函数。

1646293722316.png

对于数据xx,从同一个数据增广族中抽取两个独立的数据增广算子(tTt \sim TtT{t}' \sim T),以获得两个相关的视图x^i\hat{x}_{i}x^j\hat{x}_{j}x^i\hat{x}_{i}x^j\hat{x}_{j}是一对正样本,然后一个神经网络编码器f()f\left( \cdot \right)从增广的数据中提取特征hi=f(x^i),hj=f(x^j),h_{i}=f\left( \hat{x}_{i} \right), h_{j}=f\left( \hat{x}_{j} \right),。再然后一个小的神经网络project head g()g\left( \cdot \right)将特征映射到对比损失的空间。project head采用带有一个隐含层的MLP获取zi=g(hi)=W(2)σ(W(1)hi)z_{i} = g\left( h_{i} \right) = W^{\left( 2 \right)} \sigma \left( W^{\left( 1 \right)} h_{i}\right)

对于包含一对正样本x^i\hat{x}_{i}x^j\hat{x}_{j}的集合{x^k}\{ \hat{x}_{k} \},对比预测任务目的是对于给定的x^i\hat{x}_{i}{x^}ki\{ \hat{x} \}_{k \neq i}中识别出x^j\hat{x}_{j}。随机挑选NN个样本组成一个minibatch,这个minibatch中则有2N2N个数据样本,将其他2(N1)2\left( N - 1\right)个扩增的样本作为这个minibatch中的负样本,设sim(u,v)=uTv/uvsim\left( u, v\right) = u^{T}v / \| u\| \| v\|表示l2l_{2}正则化后你的uuvv的点积,那么对一对正样本(i,j)\left( i, j \right),损失函数如下定义:

li,j=logexp(sim(zi,zj)/τ)k=12N1[ki]exp(sim(zi,zk)/τ)l_{i,j} = - log \frac{exp\left( sim \left( z_{i}, z_{j}\right) / \tau \right)}{\sum_{k=1}^{2N} \mathbb{1}_{[ k \neq i]} exp\left( sim \left( z_{i}, z_{k}\right) / \tau \right)}

最后的损失函数计算一个minibatch中的所有的正样本对,包括(i,j)\left( i, j \right)(j,i)\left( j,i \right)。下面是SimCLR的伪代码。从伪代码中可以看出,编码器f()f\left( \cdot \right)和project head g()g\left( \cdot \right) 在训练时都会被更新参数,但是只有编码器f()f\left( \cdot \right)用于下游任务。

1646117078379.png

simCLR不采用memory bank的形式进行训练,而是加大batchsize,bacth size为8192,对于每一个正样本,将会有16382张负样本实例。增大batch size其实相当于每个minibatch时动态生成一个memory bank。论文中发现使用标准的SGD/Momentum,大batch size训练时是不稳定的,论文中采用LARS优化器。

参考

  1. The Illustrated SimCLR Framework
  2. SimCLR