搭建小型ViT网络构架进行分类任务(Pytorch)

1,056 阅读4分钟

一起养成写作习惯!这是我参与「掘金日新计划 · 4 月更文挑战」的第16天,点击查看活动详情

前言

  在这里我们不过多叙述原理,为了实验的便捷性,我们选择最为常见的MNIST数据集,Demo的构架我们采用我往期的动手撸个MNIST分类(CPU版本+GPU版本) 中GPU版本作为母版,详情可参考juejin.cn/post/707783…

Vision Transformers 大体框架

  为了能够使用自己的ViT模型应用到MNIST分类中去(替换class Net(nn.Module) 模块)可搭建如下框架:

class ViTNet(nn.Module)
    def __init__(self):
        super(ViTNet,self).__init__()
    
    def forward
        pass

Vision Transformers 细节架构

  在PyTorch中有大量的DL架构都提供Autograd计算,因此我们在Vision Transformers 模型中只需要着重在向前传递过程中花费精力即可;由于我在训练框架中定义了模型的优化器,PyTorch框架能够反向传播梯度并训练模型的参数。
我们将以下的五个重要步骤搭建出符合(MNIST)的网络结构:
第一步:Patchifying 和线性映射:

  Transformer 编码器一开始主要用于NLP这种序列化数据,将它用于CV领域的第一步要处理的是“序列化”图像,这里的处理方式是将一张图像分解成多个子图像,将每个子图像映射成一个向量。

  在MNIST数据集上,我们将每个(1x28x28)的图像分成7x7块,每块大小是4x4(如果不能完全整除分块,需要对图像padding填充),这样我们能从单个图像中获得49个子图像。将原图重塑成:

  (N, PxP, HxC/P x WxC/P) = (N, 7x7, 4x4) = (N, 49, 16)请注意,虽然每个子图大小为 1x4x4 ,但我们将其展平为 16 维向量。此外,MNIST只有一个颜色通道。如果有多个颜色通道,它们也会被展平到矢量中。

  我们得到展平后的patches即向量,通过一个线性映射来改变维度,线性映射可以映射到任意向量大小,我们向类构造函数添加一个hidden_d参数,用于“隐藏维度”。这里,使用隐藏维度为8,这样我们将每个 16 维patch映射到一个 8 维patch

第二步:添加分类标记

  在隐含层后,为了完成MNIST分类任务,必不可少的是添加分类的标记,在模型中添加一个参数将我们的Tensor(N,49,8)转变为Tensor(N,50,8);这里大家需要注意的一个地方是分类标记需要放在每个序列的第一个标记位。在完成MLP时,需要对应到相应的位置上。

第三步:添加位置编码

   紧接上一步,我们标记完成后需要进行添加位置编码,然而这块的理论性较强,强烈建议大家观摩transformer模型中的位置表明输出,这里我们就简化了,采用sin和cos替代。这里需要注意的地方是我们在第二部中转换完的Tensor(N,50,8),此时我们应该重复(50,8)的位置编码矩阵N次。

第四步:LN, MSA和残差连接

   这步较为复杂,我们在对tokens做归一化没然后采用多头注意力机制,最后添加一个残差连接输出。
LN:通过LN运行Tensor(N,50,8)后,每个50x8 矩阵的均值是0,标准差位1,维度保持不变。

   多头自注意力:对于每一张图像,都希望它能参与每个patch并在其中更新。在这里我不做过多注释,大家可参考MSA计算过程。

   残差连接:将添加一个残差连接,它将我们的原始Tensor(N,50,8)添加到在 LN 和 MSA 之后获得的 (N, 50, 8)。如果我们现在通过我们的模型运行MNIST的随机 (3, 1, 28, 28) 图像,我们仍然会得到形状为 (3, 50, 8) 的结果。

第五步:LN,MLP 和残差连接后进行MLP分类:

   这里就开始搭积木了。我们可以从 N 个序列中只提取分类标记(第一个标记),与添加分类标签的位置对应,并使用每个标记得到 N 个分类。

   由于我们决定每个标记是一个 8 维向量,并且由于我们有 10 个可能的数字,我们可以将分类 MLP 实现为一个简单的 8x10 矩阵,并使用 SoftMax 函数激活。