CLIP模型伪代码详细解析

2,529 阅读3分钟

持续创作,加速成长!这是我参与「掘金日新计划 · 10 月更文挑战」的第16天,点击查看活动详情

Introduction

前两天群里发了这个,用的stable diffusion的东西, 其中一部分用的是CLIP模型。

image.png

论文原文中有一段伪代码,写的很简洁,但是概括了CLIP模型的整个训练的工作流程,安安在梳理CLIP模型,我就分担一下伪代码这一部分的内容。

CLIP

在讲代码之前还是先看一下CLIP模型。

image.png

左边图是训练过程,右边是训练之后如何拿去做分类任务。

  • 训练过程

    可以看到左边是选择一个batch size的图片-文本对(text-image pair),将二者放进孪生网络里,进行对比学习。通过文本编码器获得对应的文本特征,图像也通过图像编码器获得对应的图像特征,二者组合起来之后形成一个组合特征矩阵,对角线上的元素是相对应的图像文本对,将其作为正样本;矩阵对角线元素之外的其他元素是不相关的图像和文本,我们将其作为负样本。

  • 分类器

    因为使用对比学习进行训练,因此CLIP出来的特征是没有分类头的,作者为此提出了一个解决办法。

    首先是用prompt,将一个自定义的单词列表放进template中形成一个句子,然后使用训练好的编码器获得对应的一组文本特征T1,...,TNT_1,...,T_N。然后是将待分类的图片放入图片编码器,获得对应的特征I1I_1。之后使用cosine similarly计算二者的相似度,获得对应的结果。

伪代码解析

image.png

# image_encoder - ResNet or Vision Transformer 
# text_encoder - CBOW or Text Transformer 
# t - learned temperature parameter 

对比学习的孪生网络两个编码器:

  • 图像编码器使用的ResNet或者vision transformer

  • 文本编码器使用的是CBOW或者transformer

  • t就是一个温度参数,这里不用过多了解

# I[n, h, w, c] - minibatch of aligned images 
# T[n, l] - minibatch of aligned texts 

这里是模型接收的输入,其中nn是batch size的大小。

所以这里是:

  • 接收一个mini-batch的图像,图像张量的形状为[n,h,w,c]

    • n:batch size

    • h:图片像素高

    • w:图片像素宽

    • c:图像通道数

  • 接收一个mini-batch的文本,文本矩阵的形状为[n,l]

    • n:batch size

    • l:文本的长度

# extract feature representations of each modality 
I_f = image_encoder(I) #[n, d_i] 
T_f = text_encoder(T) #[n, d_t] 

拿到输入和编码器之后我们就可以进行特征抽取获得对应的特征(features)。

  • 图像特征T_f

  • 文本特征T_f

# joint multimodal embedding [n, d_e] 
I_e = l2_normalize(np.dot(I_f, W_i), axis=1) 
T_e = l2_normalize(np.dot(T_f, W_t), axis=1) 

这一步就是获得合并为多模态的嵌入表示。先做了一个projection,之后继续一个L2L_2归一下获得对应的embedding表示。

# W_i[d_i, d_e] - learned proj of image to embed 
# W_t[d_t, d_e] - learned proj of text to embed 

这里的两个映射用到的是w_iw_t,主要作用就是让单模态的文本和数据转变为多模态。

# scaled pairwise cosine similarities [n, n] 
logits = np.dot(I_e, T_e.T) * np.exp(t) 

之后我们使用两个embedding计算cosine similarly获得logits去做分类。

# symmetric loss function 
labels = np.arange(n) 
loss_i = cross_entropy_loss(logits, labels, axis=0) 
loss_t = cross_entropy_loss(logits, labels, axis=1) 
loss = (loss_i + loss_t)/2

获得我们预测的logits之后我们需要一个ground truth。拿着logits和ground truth去算交叉熵目标函数(cross entropy loss),对模型进行学习即可。

其他

论文地址:[2103.00020] Learning Transferable Visual Models From Natural Language Supervision (arxiv.org)

代码地址:openai/CLIP: Contrastive Language-Image Pretraining (github.com)