之前挖坑说要详细讨论一下ViT的计算过程,现在填坑。
分析
以论文中的模型图为例子。
1. 原始图像
一张初始图片,假设是像素的,那按照示例图分为9个patches,那一个patch就是像素的。每个patch的向量长度是,乘3因为图像是三通道。最后展平的就是的长度。
2. 图像展平
将patches展平之后,线性映射阶段是进行一个简单映射,不会改变张量的形状。因此图像展平线性映射这里的权重矩阵大小是。
3. tokens转换
将patches线性映射后转化为传递给Transformer的tokens。左侧可以看到因为是要添加额外的分类头[class]
,因此长度变为。这里添加位置编码是使用add直接相加,不是concate拼接方式。
4. Transformer Block
Transformer norm 不会涉及维度变化。
多头注意力部分的QKV维度大小也是延续Transformer的计算方法,输入是,假设是两个注意力头,那QKV的长度就是,ViT-BASE是八个头,那QKV的维度就是,举例不恰当,除不出整数,你们自行理解一下就行。最后将每个注意力头的结果拼接起来,就又恢复了的维度大小。
最后的MLP部分会进行维度的放缩。一般是先放大到四倍再缩小回初始维度,因此这一步会变成后再回到。
5. 分类头计算
最后使用[class]
进行分类。
以上就是整个计算过程的维度变化,最后可以根据论文中的计算图再捋一下。
代码
论文源码链接:vision_transformer/vit_jax/models_vit.py at main · google-research/vision_transformer (github.com)
接下来从代码部分说一下维度转化过程:
class VisionTransformer(nn.Module):
"""VisionTransformer."""
num_classes: int
patches: Any
transformer: Any
hidden_size: int
resnet: Optional[Any] = None
representation_size: Optional[int] = None
classifier: str = 'token'
head_bias_init: float = 0.
encoder: Type[nn.Module] = Encoder
model_name: Optional[str] = None
@nn.compact
def __call__(self, inputs, *, train):
x = inputs
# (Possibly partial) ResNet root.
if self.resnet is not None:
width = int(64 * self.resnet.width_factor)
# Root block.
x = models_resnet.StdConv(
features=width,
kernel_size=(7, 7),
strides=(2, 2),
use_bias=False,
name='conv_root')(
x)
x = nn.GroupNorm(name='gn_root')(x)
x = nn.relu(x)
x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding='SAME')
# ResNet stages.
if self.resnet.num_layers:
x = models_resnet.ResNetStage(
block_size=self.resnet.num_layers[0],
nout=width,
first_stride=(1, 1),
name='block1')(
x)
for i, block_size in enumerate(self.resnet.num_layers[1:], 1):
x = models_resnet.ResNetStage(
block_size=block_size,
nout=width * 2**i,
first_stride=(2, 2),
name=f'block{i + 1}')(
x)
n, h, w, c = x.shape
# We can merge s2d+emb into a single conv; it's the same.
x = nn.Conv(
features=self.hidden_size,
kernel_size=self.patches.size,
strides=self.patches.size,
padding='VALID',
name='embedding')(
x)
# Here, x is a grid of embeddings.
# (Possibly partial) Transformer.
if self.transformer is not None:
n, h, w, c = x.shape
x = jnp.reshape(x, [n, h * w, c])
# If we want to add a class token, add it here.
if self.classifier in ['token', 'token_unpooled']:
cls = self.param('cls', nn.initializers.zeros, (1, 1, c))
cls = jnp.tile(cls, [n, 1, 1])
x = jnp.concatenate([cls, x], axis=1)
x = self.encoder(name='Transformer', **self.transformer)(x, train=train)
if self.classifier == 'token':
x = x[:, 0]
elif self.classifier == 'gap':
x = jnp.mean(x, axis=list(range(1, x.ndim - 1))) # (1,) or (1,2)
elif self.classifier in ['unpooled', 'token_unpooled']:
pass
else:
raise ValueError(f'Invalid classifier={self.classifier}')
if self.representation_size is not None:
x = nn.Dense(features=self.representation_size, name='pre_logits')(x)
x = nn.tanh(x)
else:
x = IdentityLayer(name='pre_logits')(x)
if self.num_classes:
x = nn.Dense(
features=self.num_classes,
name='head',
kernel_init=nn.initializers.zeros,
bias_init=nn.initializers.constant(self.head_bias_init))(x)
return x
1. VisionTransformer 的输入和 ResNet 根网络
输入 inputs
的维度一般为 (n, h, w, c)
,即:
n
:批量大小 (batch size)h
:图像的高度 (height)w
:图像的宽度 (width)c
:图像的通道数 (channels),通常为 3(对应 RGB 图像)
如果 resnet
不为 None
,会先通过 ResNet 提取特征:
- ResNet 会通过卷积、归一化、激活和池化操作处理输入图像。这些操作通常会减小图像的空间维度(高度
h
和宽度w
),增加通道维度c
。
假设经过 ResNet 后,图像变成了 (n, h', w', c')
,其中 h'
和 w'
是经过池化后图像的高度和宽度,c'
是 ResNet 提取的特征通道。
2. 图像块划分 (Patches)
模型通过卷积层将图像划分成固定大小的块:
x = nn.Conv(features=self.hidden_size, kernel_size=self.patches.size, strides=self.patches.size, padding='VALID')(x)
这一步通过卷积操作,将图像划分为多个 self.patches.size
大小的块,每个块变成 self.hidden_size
维度的特征向量。
假设 self.patches.size = (p, p)
,划分后的图像块的维度为 (n, h'', w'', hidden_size)
,其中:
n
:批量大小不变h'' = h' / p
:块的个数(高度方向)w'' = w' / p
:块的个数(宽度方向)hidden_size
:每个图像块的嵌入维度
3. 将图像块展开为序列
为了将图像块输入到 Transformer 中,必须将 2D 图像块展平成序列。因此,模型将 (h'', w'')
两个维度展平为一个维度:
x = jnp.reshape(x, [n, h * w, c])
此时,维度变为 (n, num_patches, hidden_size)
,其中:
num_patches = h'' * w''
:图像块的总数hidden_size
:每个图像块的嵌入维度不变
4. 添加类 token
如果模型使用类 token 分类(self.classifier == 'token'
),会在序列的起始位置插入一个类 token,代表全局的图像信息。添加类 token 后,序列维度变为:
(n, num_patches + 1, hidden_size)
5. Transformer 编码器
接下来,图像块序列输入到 Transformer
编码器中,每个 Encoder1DBlock
操作不会改变序列的长度和嵌入维度。因此,经过 Transformer 编码器后,维度依旧保持:
(n, num_patches + 1, hidden_size)
6. 分类
token
分类器:如果使用类 token 分类,模型会提取序列的第一个 token(类 token),维度变为(n, hidden_size)
,然后通过全连接层输出类别,最后维度变为(n, num_classes)
。gap
分类器:如果使用全局平均池化(gap
),会对所有的图像块做池化,维度变为(n, hidden_size)
,最后输出类别为(n, num_classes)
。
7. 最终输出
如果使用 representation_size
,模型会进一步通过全连接层改变输出特征的维度为 representation_size
,否则直接进行分类。
总结维度变化
- 输入图像维度为
(n, h, w, c)
。 - ResNet 后维度变为
(n, h', w', c')
。 - 图像块划分后维度为
(n, h'', w'', hidden_size)
。 - 展平图像块为序列后,维度为
(n, num_patches, hidden_size)
。 - 如果添加类 token,维度变为
(n, num_patches + 1, hidden_size)
。 - Transformer 编码器不改变维度,仍为
(n, num_patches + 1, hidden_size)
。 - 最终输出的维度为
(n, num_classes)
。