概述
Vision Transformer 是最早将 transformer 应用于计算机视觉领域的开山算法之一。将图像处理问题转化为序列处理问题,为后续 DETR、SegFormer、BevFormer、MapTR 等视觉感知中各模块的突破打下基础,推动端到端的检测、分割、建图等,使得当下整个辅助驾驶链路,包括感知、跟踪、地图、决策、规划等模块能够有机会形成一个真正的端到端框架,例如上海 AI Lab 提出的 UniAD 等等。
核心流程
def forward(self, img):
# step1: 图像块特征编码
x = self.to_patch_embedding(img)
b, n, _ = x.shape
# step2:添加 cls token 编码
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
# step3:图像块位置编码
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
# step4:transformer 编码
x = self.transformer(x)
# step5:获取整张图片的表示
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
# step6:分类头输出
x = self.to_latent(x)
return self.mlp_head(x)
ViT 主体结构
patch embedding
# 假定 forward 输入:
# img = torch.randn(2, 3, 256, 256) 等同于 b c (h p1) (w p2)
# patch_dim 为 3,dim 为 1024
def forward(self, img):
# step1: 图像块特征编码
x = self.to_patch_embedding(img)
b, n, _ = x.shape
#########################################
# to_patch_embedding 内部:
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
- Rearrange 图像划分: 将图像划分成固定大小的 patch,得到 b 个 batch_size 下,h*w 个 patch,每个 patch 实际为高 p1、宽 p2、通道数为 c 的 tensor
- Rearrange 展平: 每个 patch 被展平成一维向量:空间位置的展平 (h w),patch内容的展平 (p1 p2 c)
注意这里 Rearrange 的机制:
输入 [b, c, hp1, wp2],输出 [b, hw, p1p2*c]
假定 forward 传入 img = torch.randn(2, 3, 256, 256),Rearrange 会按照 [b, c, hp1, wp2] 进行自动推导,计算得到 b、c、h、w
- LayerNorm 归一化原始 patch 特征
- Linear 线性投影: 通过线性层将 patch 映射到模型维度 dim,转换成模型所需的高维特征表示
- 再次 LayerNorm 归一化,输出维度 [b, num_patches, dim]
注意 transformer 系列 LayerNorm 和 CNN 系列 BatchNorm 的区别:
数据结构匹配:
BatchNorm 设计用于 CNN,期望输入是 [B, C, H, W] 格式
而这里数据已经被重排为 [B, N, D] 格式(N是patch数,D是特征维度),LayerNorm 则更适合处理这种序列数据
归一化维度不同:
- BatchNorm:在 batch 和空间维度上计算统计量,对每个通道单独归一化
- LayerNorm:对每个样本的特征维度独立归一化,不依赖 batch 大小。在 transformer 中,我们希望对每个 patch 的所有特征进行归一化。
cls token
# step2:添加 cls token 编码
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
#########################################
# cls token 初始化为:
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
cls token 本质上是一组可学习的向量,含义是全局信息的聚合器,可以类比为 NLP 中在句子开头加上 "[START]" 标记。
- repeat: 让模型学会用同一组 token 来表征不同图片的全局特征。
这里同样的,repeat 也会从输入的 1 1 d 中自动推导 d 为 dim,与 x 的特征维度一致
-
concat: 将 cls_tokens 和 patch_embedding 拼接,得到总体特征和 patch 特征的组合,输出维度 [b, num_patches+1, dim]
pos embedding
# step3:图像块位置编码
# 这里 n 是指 patch 个数
x += self.x += self.pos_embedding[:, :(n + 1)][:, :(n + 1)]
x = self.dropout(x)
#########################################
# pos_embedding 初始化为:
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
由 pos_embedding 初始化定义可知,其形状是 [1, num_patches + 1, dim]。[:, :(n + 1)]通过多维切片操作,得到 self.pos_embedding[:, :(n + 1)] 的维度是 [1, 0~n, dim]。
注意这个切片操作:
第三个维度(隐含的 :),虽然没写,但默认取所有元素
-
x += pos_embedding: 增强现有特征,相当于给每个 patch 特征添加位置信息
-
dropout: 训练 trick
transformer
# step4:transformer 编码
x = self.transformer(x)
#########################################
# transformer 内部的 forward 逻辑:
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return self.norm(x)
transformer 由多个 layers 堆叠而成,每个 layer 包含一个 attention 模块 attn 和一个前馈神经网络模块 ff。forward 过程,通过残差结构,每一层都在重新叠加组合前一层的特征,通过多层堆叠,可以学习到更复杂的特征组合。最后,同样在输出前做了 LayerNorm。
attn
def forward(self, x):
x = self.norm(x)
x = self.to_qkv(x)
qkv = x.chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
#########################################
# to_qkv 定义:
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
# inner_dim 定义:
inner_dim = dim_head * heads
# self.attend 定义:
self.attend = nn.Softmax(dim = -1)
# self.to_out 定义:
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
-
norm: 对输入进行 LayerNorm
-
to_qkv: 将输入的 x(此时已经融合了 patch embedding、cls token、pos embedding)从 dim 维线性映射到 inner_dim * 3
- 这里,3 的含义是 Q、K、V,这三者的维度都是 inner_dim
- inner_dim 内部把每个 x(即 patch 形式的 token)的 Q、K、V 向量划分成了 heads 个子空间,每个子空间的维度是 dim_head,每个子空间对应一个 head。
- 此外,inner_dim 在数值上等同于 x 的维度 dim,使得 x 作为 Q、K、V 的来源。
-
chunk: 将 tensor 沿最后一维,拆分成 3 份。
-
q, k, v: qkv 含 3 个 tensor,即 q、k、v;用 rearrange 进行维度重排,输出维度 [b h n d],即 [batch_size,heads,num_patches,dim_head]
-
dots、attn、out: 这几步都严格遵循多头自注意力机制的计算公式
-
out rearrange、to_out: 负责把多头注意力输出的高维特征投影还原回到主干维度
注意 to_out:
只有单头且维度完全一致时,才可以直接输出,不需要线性变换。这是因为多头注意力会把每个 head 的输出拼接起来,得到一个更高维的向量(inner_dim),需要用线性层把它还原回主干维度(dim),否则后续残差结构无法对齐。
ff
def forward(self, x):
return self.net(x)
#########################################
# self.net 定义
self.net = nn.Sequential(
nn.LayerNorm(dim), # 1. 归一化
nn.Linear(dim, hidden_dim), # 2. 升维
nn.GELU(), # 3. 非线性激活
nn.Dropout(dropout), # 4. 防止过拟合
nn.Linear(hidden_dim, dim), # 5. 降维
nn.Dropout(dropout) # 6. 再次防止过拟合
)
比较经典的前馈神经网络结构,最后一层不加激活,配合残差结构,保证信息和梯度的直接流动。
pool
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
有 mean 和 cls 两种 pool 方式:
mean:对所有 token 做平均池化,作为整张图片的表示
cls:使用第一个 token(即 cls token)作为整张图片的表示(和 BERT 类似)
输出
x = self.to_latent(x)
return self.mlp_head(x)
#########################################
# self.to_latent 定义
self.to_latent = nn.Identity()
# self.mlp_head 定义
self.mlp_head = nn.Linear(dim, num_classes)
-
to_latent: 在这边是 Identity 操作。同时也是一个预留的接口,允许用户在不修改原始代码的情况下,通过继承 ViT 类来扩展功能,比如可以修改这个 Identity 操作,在特征进入最后分类头之前进行额外处理
-
mlp_head: 通过简单的线性层完成最终分类
总结核心流程
-
图像分块与特征提取:
- 将输入图像切分成固定大小的 patches
- 通过 to_patch_embedding 提取每个 patch 的特征表示(内容编码)
- 添加 pos_embedding 为每个 patch 注入位置信息(位置编码)
-
序列构建与增强:
- 添加可学习的 cls_token 作为全局特征聚合点;
- 将位置编码与特征编码相加,形成完整的 patch token;
- 最终得到包含空间和语义信息的序列表示
-
Transformer 特征提取:
- 通过多头自注意力机制实现 patches 间的全局交互
- 使用 FeedForward 网络增强特征表示能力
- 采用残差连接和 LayerNorm 确保深层网络的训练稳定性
-
分类预测:
-
可选择使用 cls token 或平均池化获取全局特征
-
通过简单的线性层完成最终分类
-
参考
ViT 代码:github.com/lucidrains/…
ViT 论文:openreview.net/pdf?id=Yicb…
transformer 讲解:github.com/datawhalech…
BETR 讲解:book.douban.com/subject/362…
GELU 论文:arxiv.org/pdf/1606.08…