ViT模型张量维度变化过程和代码解析

93 阅读5分钟

之前挖坑说要详细讨论一下ViT的计算过程,现在填坑。


分析

以论文中的模型图为例子。

1728695300777.jpg

1. 原始图像

一张初始图片,假设是303030*30像素的,那按照示例图分为9个patches,那一个patch就是101010*10像素的。每个patch的向量长度是10103=30010*10*3=300,乘3因为图像是三通道。最后展平的就是33003*300的长度。

image.png

2. 图像展平

将patches展平之后,线性映射阶段是进行一个简单映射,不会改变张量的形状。因此图像展平线性映射这里的权重矩阵大小是300300300*300

image.png

3. tokens转换

将patches线性映射后转化为传递给Transformer的tokens。左侧可以看到因为是要添加额外的分类头[class],因此长度变为1030010*300。这里添加位置编码是使用add直接相加,不是concate拼接方式。

image.png

4. Transformer Block

image.png

Transformer norm 不会涉及维度变化。

多头注意力部分的QKV维度大小也是延续Transformer的计算方法,输入是1030010*300,假设是两个注意力头,那QKV的长度就是10(300/2)=1015010*(300/2)=10*150,ViT-BASE是八个头,那QKV的维度就是10(300/8)10*(300/8),举例不恰当,除不出整数,你们自行理解一下就行。最后将每个注意力头的结果拼接起来,就又恢复了1030010*300的维度大小。

最后的MLP部分会进行维度的放缩。一般是先放大到四倍再缩小回初始维度,因此这一步会变成10120010*1200后再回到1030010*300

image.png

5. 分类头计算

最后使用[class]进行分类。

image.png


以上就是整个计算过程的维度变化,最后可以根据论文中的计算图再捋一下。

image.png


代码

论文源码链接:vision_transformer/vit_jax/models_vit.py at main · google-research/vision_transformer (github.com)

接下来从代码部分说一下维度转化过程:


class VisionTransformer(nn.Module):
  """VisionTransformer."""

  num_classes: int
  patches: Any
  transformer: Any
  hidden_size: int
  resnet: Optional[Any] = None
  representation_size: Optional[int] = None
  classifier: str = 'token'
  head_bias_init: float = 0.
  encoder: Type[nn.Module] = Encoder
  model_name: Optional[str] = None

  @nn.compact
  def __call__(self, inputs, *, train):

    x = inputs
    # (Possibly partial) ResNet root.
    if self.resnet is not None:
      width = int(64 * self.resnet.width_factor)

      # Root block.
      x = models_resnet.StdConv(
          features=width,
          kernel_size=(7, 7),
          strides=(2, 2),
          use_bias=False,
          name='conv_root')(
              x)
      x = nn.GroupNorm(name='gn_root')(x)
      x = nn.relu(x)
      x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding='SAME')

      # ResNet stages.
      if self.resnet.num_layers:
        x = models_resnet.ResNetStage(
            block_size=self.resnet.num_layers[0],
            nout=width,
            first_stride=(1, 1),
            name='block1')(
                x)
        for i, block_size in enumerate(self.resnet.num_layers[1:], 1):
          x = models_resnet.ResNetStage(
              block_size=block_size,
              nout=width * 2**i,
              first_stride=(2, 2),
              name=f'block{i + 1}')(
                  x)

    n, h, w, c = x.shape

    # We can merge s2d+emb into a single conv; it's the same.
    x = nn.Conv(
        features=self.hidden_size,
        kernel_size=self.patches.size,
        strides=self.patches.size,
        padding='VALID',
        name='embedding')(
            x)

    # Here, x is a grid of embeddings.

    # (Possibly partial) Transformer.
    if self.transformer is not None:
      n, h, w, c = x.shape
      x = jnp.reshape(x, [n, h * w, c])

      # If we want to add a class token, add it here.
      if self.classifier in ['token', 'token_unpooled']:
        cls = self.param('cls', nn.initializers.zeros, (1, 1, c))
        cls = jnp.tile(cls, [n, 1, 1])
        x = jnp.concatenate([cls, x], axis=1)

      x = self.encoder(name='Transformer', **self.transformer)(x, train=train)

    if self.classifier == 'token':
      x = x[:, 0]
    elif self.classifier == 'gap':
      x = jnp.mean(x, axis=list(range(1, x.ndim - 1)))  # (1,) or (1,2)
    elif self.classifier in ['unpooled', 'token_unpooled']:
      pass
    else:
      raise ValueError(f'Invalid classifier={self.classifier}')

    if self.representation_size is not None:
      x = nn.Dense(features=self.representation_size, name='pre_logits')(x)
      x = nn.tanh(x)
    else:
      x = IdentityLayer(name='pre_logits')(x)

    if self.num_classes:
      x = nn.Dense(
          features=self.num_classes,
          name='head',
          kernel_init=nn.initializers.zeros,
          bias_init=nn.initializers.constant(self.head_bias_init))(x)
    return x

1. VisionTransformer 的输入和 ResNet 根网络

输入 inputs 的维度一般为 (n, h, w, c),即:

  • n:批量大小 (batch size)
  • h:图像的高度 (height)
  • w:图像的宽度 (width)
  • c:图像的通道数 (channels),通常为 3(对应 RGB 图像)

如果 resnet 不为 None,会先通过 ResNet 提取特征:

  • ResNet 会通过卷积、归一化、激活和池化操作处理输入图像。这些操作通常会减小图像的空间维度(高度 h 和宽度 w),增加通道维度 c

假设经过 ResNet 后,图像变成了 (n, h', w', c'),其中 h'w' 是经过池化后图像的高度和宽度,c' 是 ResNet 提取的特征通道。

2. 图像块划分 (Patches)

模型通过卷积层将图像划分成固定大小的块:

x = nn.Conv(features=self.hidden_size, kernel_size=self.patches.size, strides=self.patches.size, padding='VALID')(x)

这一步通过卷积操作,将图像划分为多个 self.patches.size 大小的块,每个块变成 self.hidden_size 维度的特征向量。

假设 self.patches.size = (p, p),划分后的图像块的维度为 (n, h'', w'', hidden_size),其中:

  • n:批量大小不变
  • h'' = h' / p:块的个数(高度方向)
  • w'' = w' / p:块的个数(宽度方向)
  • hidden_size:每个图像块的嵌入维度

3. 将图像块展开为序列

为了将图像块输入到 Transformer 中,必须将 2D 图像块展平成序列。因此,模型将 (h'', w'') 两个维度展平为一个维度:

x = jnp.reshape(x, [n, h * w, c])

此时,维度变为 (n, num_patches, hidden_size),其中:

  • num_patches = h'' * w'':图像块的总数
  • hidden_size:每个图像块的嵌入维度不变

4. 添加类 token

如果模型使用类 token 分类(self.classifier == 'token'),会在序列的起始位置插入一个类 token,代表全局的图像信息。添加类 token 后,序列维度变为:

  • (n, num_patches + 1, hidden_size)

5. Transformer 编码器

接下来,图像块序列输入到 Transformer 编码器中,每个 Encoder1DBlock 操作不会改变序列的长度和嵌入维度。因此,经过 Transformer 编码器后,维度依旧保持:

  • (n, num_patches + 1, hidden_size)

6. 分类

  • token 分类器:如果使用类 token 分类,模型会提取序列的第一个 token(类 token),维度变为 (n, hidden_size),然后通过全连接层输出类别,最后维度变为 (n, num_classes)
  • gap 分类器:如果使用全局平均池化(gap),会对所有的图像块做池化,维度变为 (n, hidden_size),最后输出类别为 (n, num_classes)

7. 最终输出

如果使用 representation_size,模型会进一步通过全连接层改变输出特征的维度为 representation_size,否则直接进行分类。

总结维度变化

  • 输入图像维度为 (n, h, w, c)
  • ResNet 后维度变为 (n, h', w', c')
  • 图像块划分后维度为 (n, h'', w'', hidden_size)
  • 展平图像块为序列后,维度为 (n, num_patches, hidden_size)
  • 如果添加类 token,维度变为 (n, num_patches + 1, hidden_size)
  • Transformer 编码器不改变维度,仍为 (n, num_patches + 1, hidden_size)
  • 最终输出的维度为 (n, num_classes)