持续创作,加速成长!这是我参与「掘金日新计划 · 10 月更文挑战」的第16天,点击查看活动详情
Introduction
前两天群里发了这个,用的stable diffusion的东西, 其中一部分用的是CLIP模型。
论文原文中有一段伪代码,写的很简洁,但是概括了CLIP模型的整个训练的工作流程,安安在梳理CLIP模型,我就分担一下伪代码这一部分的内容。
CLIP
在讲代码之前还是先看一下CLIP模型。
左边图是训练过程,右边是训练之后如何拿去做分类任务。
-
训练过程
可以看到左边是选择一个batch size的图片-文本对(text-image pair),将二者放进孪生网络里,进行对比学习。通过文本编码器获得对应的文本特征,图像也通过图像编码器获得对应的图像特征,二者组合起来之后形成一个组合特征矩阵,对角线上的元素是相对应的图像文本对,将其作为正样本;矩阵对角线元素之外的其他元素是不相关的图像和文本,我们将其作为负样本。
-
分类器
因为使用对比学习进行训练,因此CLIP出来的特征是没有分类头的,作者为此提出了一个解决办法。
首先是用prompt,将一个自定义的单词列表放进template中形成一个句子,然后使用训练好的编码器获得对应的一组文本特征。然后是将待分类的图片放入图片编码器,获得对应的特征。之后使用cosine similarly计算二者的相似度,获得对应的结果。
伪代码解析
# image_encoder - ResNet or Vision Transformer
# text_encoder - CBOW or Text Transformer
# t - learned temperature parameter
对比学习的孪生网络两个编码器:
-
图像编码器使用的ResNet或者vision transformer
-
文本编码器使用的是CBOW或者transformer
-
t就是一个温度参数,这里不用过多了解
# I[n, h, w, c] - minibatch of aligned images
# T[n, l] - minibatch of aligned texts
这里是模型接收的输入,其中是batch size的大小。
所以这里是:
-
接收一个mini-batch的图像,图像张量的形状为
[n,h,w,c]
-
n:batch size
-
h:图片像素高
-
w:图片像素宽
-
c:图像通道数
-
-
接收一个mini-batch的文本,文本矩阵的形状为
[n,l]
-
n:batch size
-
l:文本的长度
-
# extract feature representations of each modality
I_f = image_encoder(I) #[n, d_i]
T_f = text_encoder(T) #[n, d_t]
拿到输入和编码器之后我们就可以进行特征抽取获得对应的特征(features)。
-
图像特征
T_f
-
文本特征
T_f
# joint multimodal embedding [n, d_e]
I_e = l2_normalize(np.dot(I_f, W_i), axis=1)
T_e = l2_normalize(np.dot(T_f, W_t), axis=1)
这一步就是获得合并为多模态的嵌入表示。先做了一个projection,之后继续一个归一下获得对应的embedding表示。
# W_i[d_i, d_e] - learned proj of image to embed
# W_t[d_t, d_e] - learned proj of text to embed
这里的两个映射用到的是w_i
和w_t
,主要作用就是让单模态的文本和数据转变为多模态。
# scaled pairwise cosine similarities [n, n]
logits = np.dot(I_e, T_e.T) * np.exp(t)
之后我们使用两个embedding计算cosine similarly获得logits去做分类。
# symmetric loss function
labels = np.arange(n)
loss_i = cross_entropy_loss(logits, labels, axis=0)
loss_t = cross_entropy_loss(logits, labels, axis=1)
loss = (loss_i + loss_t)/2
获得我们预测的logits之后我们需要一个ground truth。拿着logits和ground truth去算交叉熵目标函数(cross entropy loss),对模型进行学习即可。
其他
论文地址:[2103.00020] Learning Transferable Visual Models From Natural Language Supervision (arxiv.org)
代码地址:openai/CLIP: Contrastive Language-Image Pretraining (github.com)