CLIP源码解析篇

344 阅读8分钟

本文为稀土掘金技术社区首发签约文章,30天内禁止转载,30天后未获授权禁止转载,侵权必究!

🍊作者简介:秃头小苏,致力于用最通俗的语言描述问题

🍊专栏推荐:深度学习网络原理与实战

🍊近期目标:写好专栏的每一篇文章

🍊支持小苏:点赞👍🏼、收藏⭐、留言📩

CLIP源码解析篇

写在前面

Hello,大家好,我是小苏👦👦👦

在上一小节中,我已经为大家介绍了CLIP的原理,还不清楚的点击下文链接查看:

那么这一篇我将来为大家介绍介绍CLIP的代码,感兴趣的就随我一起往下看叭。🍭🍭🍭

在CLIP的论文中,给出了源码的地址,如下:CLIP源码🥗🥗🥗,此外还给了Google Colab的代码,直接可以运行,感兴趣的可以去看看。

但由于源码训练时间较长,本机也只能CPU跑,故在Github上找了一个简化版的使用mnist手写数字数据集的CLIP代码,写的比较简介,也能表达出CLIP的思想,因此本文将以此代码为例,为大家介绍CLIP在代码中是如何实现的。此代码的Github链接如下:mnist-clip🍄🍄🍄

CLIP源码

先来说该代码的注意事项,源码中dataset.py文件中将数据集转成Tensor格式用的PILToTensor,需要导入from torchvision.transforms.v2 import PILToTensor,Compose。但是我的torchvision版本不对,因此我还是用之前版本转Tensor格式的方法,修改如下图所示:

image-20241005210416163

还有一点,在train.py中,加载数据集时设置了num_workers=10,persistent_workers=True,在GPU上训练可以这么设置,但是我用的CPU,因此这里需要设置num_workers=0,persistent_workers=False


调整了这两处,你就可以直接训练代码了,以下是我训练的loss变换:

image-20241005210805983

训练完成后,会得到model.pth权重文件,可用于后续预测。


在CLIP的论文中,给出了训练的伪代码,我们可以先一起来看看,如下:

# image_encoder - ResNet or Vision Transformer
# text_encoder - CBOW or Text Transformer
# I[n, h, w, c] - minibatch of aligned images
# T[n, l] - minibatch of aligned texts
# W_i[d_i, d_e] - learned proj of image to embed
# W_t[d_t, d_e] - learned proj of text to embed
# t - learned temperature parameter

# extract feature representations of each modality
# 提取图像特征
I_f = image_encoder(I) #[n, d_i]
# 提取文本特征
T_f = text_encoder(T) #[n, d_t]

# joint multimodal embedding [n, d_e]
# 分别将图像和文本特征投影到统一维度,并进行归一化
I_e = l2_normalize(np.dot(I_f, W_i), axis=1) # [n, d_i] * [d_i, d_e] = [n, d_e]
T_e = l2_normalize(np.dot(T_f, W_t), axis=1) # [n, d_t] * [d_t, d_e] = [n, d_e]

#计算图片-文字向量的余弦相似度
logits = np.dot(I_e, T_e.T) * np.exp(t) # [n, n]

#计算Loss
labels = np.arange(n)
loss_i = cross_entropy_loss(logits, labels, axis=0)
loss_t = cross_entropy_loss(logits, labels, axis=1)
loss = (loss_i + loss_t)/2

从上述的伪代码中可以看出,训练过程大概可以分为4步,如下:

  1. 提取图像和文本特征:提取图像特征需要图像编码器image_encoder,通常可以是ResNet或者Vision Transformer;提取文本特征需要文本编码器text_encoder,通常可以是CBOWTransformer
  2. 图像文本特征统一并归一化:利用投影矩阵W_i([d_i, d_e])将原来图像的维度由[n, d_i]变成[n, d_e],利用投影矩阵W_t([d_t, d_e])将原来图像的维度由[n, d_t]变成[n, d_e]。即经过投影后的特征维度是一致的,方便后面进行损失计算。这步骤进行完之后,还会进行一个L2归一化操作,这步是让每个向量的长度为1,进而方便后面计算图片文字向量的余弦相似度。
  3. 计算图片-文字向量的余弦相似度:通过计算图片向量和文字向量的点积来度量两个向量的相似程度,这个点积实际上就是两个向量的余弦相似度。因为余弦相似度的公式为:cosine_similarity(A,B)=ABA∥∥Bcosine\_similarity(A,B)=\frac{A⋅B}{∥A∥∥B∥}。即两个向量的点积除以两个向量模长的乘积,由于我们第二步进行了归一化,一次向量模长为1,所以这里的点积表示余弦相似度。这里还乘了np.exp(t),这里t 是一个学习到的温度参数。温度参数的作用是控制相似度的敏感度。当 t 越小,相似度差异会越显著;当 t 越大,相似度差异会减小。
  4. 计算Loss:首先会生成一个标签张量,形状为[n],值为[0, 1, 2, ..., n-1],这个标签表示图像和文本是配对的(即第 i 张图像与第 i 个文本应该匹配)。然后分别在图像维度和文本维度来计算交叉熵损失并将总损失设为两个维度损失的平均值。

对于第四点损失计算部分,我想大家还是有点云里雾里,这里在详细的介绍一下这两句:

loss_i = cross_entropy_loss(logits, labels, axis=0)
loss_t = cross_entropy_loss(logits, labels, axis=1)

首先cross_entropy_loss(logits, labels, axis=0)是在图像的维度上计算交叉熵损失,即每一行代表一个图像与所有文本的相似度分布。模型希望第 i 个图像与第 i 个文本的相似度最大,其他文本的相似度尽量小。因此,在图像的维度上,交叉熵损失将优化使得第 i 行的 i 位置(正确的文本)的概率最大化。

同样cross_entropy_loss(logits, labels, axis=1)是在文本的维度上计算交叉熵损失,类似地,表示第 i 个文本与所有图像的相似度分布。模型希望第 i 个文本与第 i 个图像的相似度最大化。因此,在文本的维度上,交叉熵损失优化使得每列的第 i 个位置(正确的图像)的概率最大化。


介绍完论文中的伪代码,我们来看看mnist-clip中的代码,首先是图像编码器和文本编码器的设置,如下:

# 图像编码器
class ResidualBlock(nn.Module):
    def __init__(self,in_channels,out_channels,stride):
        super().__init__()
        self.conv1=nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=3,padding=1,stride=stride)
        self.bn1=nn.BatchNorm2d(out_channels)
        
        self.conv2=nn.Conv2d(in_channels=out_channels,out_channels=out_channels,kernel_size=3,padding=1,stride=1)
        self.bn2=nn.BatchNorm2d(out_channels)
        
        self.conv3=nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=1,padding=0,stride=stride)
    
    def forward(self,x):
        y=F.relu(self.bn1(self.conv1(x)))
        y=self.bn2(self.conv2(y))
        z=self.conv3(x)
        return F.relu(y+z)
        

class ImgEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.res_block1=ResidualBlock(in_channels=1,out_channels=16,stride=2) # (batch,16,14,14)
        self.res_block2=ResidualBlock(in_channels=16,out_channels=4,stride=2) # (batch,4,7,7)
        self.res_block3=ResidualBlock(in_channels=4,out_channels=1,stride=2) # (batch,1,4,4)
        self.wi=nn.Linear(in_features=16,out_features=8)
        self.ln=nn.LayerNorm(8)
        
    def forward(self,x):
        x=self.res_block1(x)
        x=self.res_block2(x)
        x=self.res_block3(x)
        x=self.wi(x.view(x.size(0),-1))
        x=self.ln(x)
        return x

图像编码器的结构很简单,就由一些残差层和全连接层构成,并且加了一个层归一化层,这样后面在第2步时就不需要归一化了。

# 文本编码器
class TextEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb=nn.Embedding(num_embeddings=10,embedding_dim=16)
        self.dense1=nn.Linear(in_features=16,out_features=64)
        self.dense2=nn.Linear(in_features=64,out_features=16)
        self.wt=nn.Linear(in_features=16,out_features=8)
        self.ln=nn.LayerNorm(8)
    
    def forward(self,x):
        x=self.emb(x)
        x=F.relu(self.dense1(x))
        x=F.relu(self.dense2(x))
        x=self.wt(x)
        x=self.ln(x)
        return x

文本编码器这里设置的也很简单,就是一个Embedding层和一个全连接层,当然最后也加了一个层归一化操作。

最终的CLIP模型如下,就是计算图像-文本特征的余弦相似度,得到logits

class CLIP(nn.Module):
    def __init__(self,):
        super().__init__()
        self.img_enc=ImgEncoder()
        self.text_enc=TextEncoder()

    def forward(self,img_x,text_x):
        img_emb=self.img_enc(img_x)
        text_emb=self.text_enc(text_x)
        return img_emb@text_emb.T

最后,我们来看看训练损失计算的关键代码,如下:

logits=model(imgs.to(DEVICE),labels.to(DEVICE))
targets=torch.arange(0,TARGET_COUNT).to(DEVICE)
loss_i=F.cross_entropy(logits,targets)
loss_t=F.cross_entropy(logits.permute(1,0),targets)
loss=(loss_i+loss_t)/2

其实你可以发现,这里的代码和伪代码是完全一致的,只是这里用的是logits.permute(1,0)将矩阵转置了,和之前设置axis=1效果一致。

Mnist-CLIP预测效果

我们来看看上述代码的预测效果,首先可以由图片预测出类别,即分类任务,主要代码如下:

'''
1、对图片分类
'''
image,label=dataset[1]
print('正确分类:',label)
plt.imshow(image.permute(1,2,0))
plt.show()

targets=torch.arange(0,10)  #10种分类
logits=model(image.unsqueeze(0).to(DEVICE),targets.to(DEVICE)) # 1张图片 vs 10种分类
print(logits)
print('CLIP分类:',logits.argmax(-1).item())

'''

这部分其实就和我们原理部分介绍的完全一致,拿图片去和所有类别计算相似度,然后相似度最大对应的类别即是所需,我们可以看看输出结果,如下:【dataset[1]是取数据集中第一张图片,这里图片是0】

image-20241005221519986

image-20241005221550844

可以看到,CLIP预测的类别和正确标签是一致的。


除了做图像分类之外,还可以做其它任务,如相似图像查找,关键代码如下:

'''
2、图像相似度
'''
other_images=[]
other_labels=[]
for i in range(1,101):
    other_image,other_label=dataset[i]
    other_images.append(other_image)
    other_labels.append(other_label)

# 其他100张图片的向量
other_img_embs=model.img_enc(torch.stack(other_images,dim=0).to(DEVICE))

# 当前图片的向量
img_emb=model.img_enc(image.unsqueeze(0).to(DEVICE))

# 计算当前图片和100张其他图片的相似度
logtis=img_emb@other_img_embs.T
values,indexs=logtis[0].topk(5) # 5个最相似的

plt.figure(figsize=(15,15))
for i,img_idx in enumerate(indexs):
    plt.subplot(1,5,i+1)
    plt.imshow(other_images[img_idx].permute(1,2,0))
    plt.title(other_labels[img_idx])
    plt.axis('off')
plt.show()

这段代码也比较简单,就是查找某张图片最相近的5张图片,并画图展示出来,预测结果如下:

可以看到,CLIP模型找出了最相近的5个0,这与预期是相符的。🍚🍚🍚

小结

呼呼呼~~~到这里就把CLIP源码详解篇介绍完啦,如果还有不懂的地方欢迎评论区留言讨论,我们下期再见。🌼🌼🌼

如若文章对你有所帮助,那就🛴🛴🛴

         一键三连 (1).gif