【代码走读】Vision Transformer-Google-ICLR 2021

82 阅读7分钟

概述

Vision Transformer 是最早将 transformer 应用于计算机视觉领域的开山算法之一。将图像处理问题转化为序列处理问题,为后续 DETR、SegFormer、BevFormer、MapTR 等视觉感知中各模块的突破打下基础,推动端到端的检测、分割、建图等,使得当下整个辅助驾驶链路,包括感知、跟踪、地图、决策、规划等模块能够有机会形成一个真正的端到端框架,例如上海 AI Lab 提出的 UniAD 等等。

核心流程

def forward(self, img):
    # step1: 图像块特征编码
    x = self.to_patch_embedding(img)
    b, n, _ = x.shape

    # step2:添加 cls token 编码
    cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
    x = torch.cat((cls_tokens, x), dim=1)
    
    # step3:图像块位置编码
    x += self.pos_embedding[:, :(n + 1)]
    x = self.dropout(x)

    # step4:transformer 编码
    x = self.transformer(x)

    # step5:获取整张图片的表示
    x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

    # step6:分类头输出
    x = self.to_latent(x)
    return self.mlp_head(x)

ViT 主体结构

patch embedding

# 假定 forward 输入:
#    img = torch.randn(2, 3, 256, 256) 等同于 b c (h p1) (w p2)
#    patch_dim 为 3,dim 为 1024
def forward(self, img):
    # step1: 图像块特征编码
    x = self.to_patch_embedding(img)
    b, n, _ = x.shape

#########################################
    # to_patch_embedding 内部:
    self.to_patch_embedding = nn.Sequential(
        Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
        nn.LayerNorm(patch_dim),
        nn.Linear(patch_dim, dim),
        nn.LayerNorm(dim),
    )
  1. Rearrange 图像划分: 将图像划分成固定大小的 patch,得到 b 个 batch_size 下,h*w 个 patch,每个 patch 实际为高 p1、宽 p2、通道数为 c 的 tensor
  2. Rearrange 展平: 每个 patch 被展平成一维向量:空间位置的展平 (h w),patch内容的展平 (p1 p2 c)

注意这里 Rearrange 的机制:

输入 [b, c, hp1, wp2],输出 [b, hw, p1p2*c]

假定 forward 传入 img = torch.randn(2, 3, 256, 256),Rearrange 会按照 [b, c, hp1, wp2] 进行自动推导,计算得到 b、c、h、w

  1. LayerNorm 归一化原始 patch 特征
  2. Linear 线性投影: 通过线性层将 patch 映射到模型维度 dim,转换成模型所需的高维特征表示
  3. 再次 LayerNorm 归一化,输出维度 [b, num_patches, dim]

注意 transformer 系列 LayerNorm 和 CNN 系列 BatchNorm 的区别:

  1. 数据结构匹配:

  • BatchNorm 设计用于 CNN,期望输入是 [B, C, H, W] 格式

  • 而这里数据已经被重排为 [B, N, D] 格式(N是patch数,D是特征维度),LayerNorm 则更适合处理这种序列数据

  1. 归一化维度不同:

  • BatchNorm:在 batch 和空间维度上计算统计量,对每个通道单独归一化
  • LayerNorm:对每个样本的特征维度独立归一化,不依赖 batch 大小。在 transformer 中,我们希望对每个 patch 的所有特征进行归一化。

cls token

    # step2:添加 cls token 编码
    cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
    x = torch.cat((cls_tokens, x), dim=1)
    
#########################################
    # cls token 初始化为:
    self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

cls token 本质上是一组可学习的向量,含义是全局信息的聚合器,可以类比为 NLP 中在句子开头加上 "[START]" 标记。

  1. repeat: 让模型学会用同一组 token 来表征不同图片的全局特征。

这里同样的,repeat 也会从输入的 1 1 d 中自动推导 d 为 dim,与 x 的特征维度一致

  1. concat: 将 cls_tokens 和 patch_embedding 拼接,得到总体特征和 patch 特征的组合,输出维度 [b, num_patches+1, dim]

pos embedding

    # step3:图像块位置编码
    # 这里 n 是指 patch 个数
    x += self.x += self.pos_embedding[:, :(n + 1)][:, :(n + 1)]
    x = self.dropout(x)

#########################################
    # pos_embedding 初始化为:
    self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))

由 pos_embedding 初始化定义可知,其形状是 [1, num_patches + 1, dim]。[:, :(n + 1)]通过多维切片操作,得到 self.pos_embedding[:, :(n + 1)] 的维度是 [1, 0~n, dim]。

注意这个切片操作:

第三个维度(隐含的 :),虽然没写,但默认取所有元素

  1. x += pos_embedding: 增强现有特征,相当于给每个 patch 特征添加位置信息

  2. dropout: 训练 trick

transformer

    # step4:transformer 编码
    x = self.transformer(x)

#########################################
    # transformer 内部的 forward 逻辑:
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
    
        return self.norm(x)

transformer 由多个 layers 堆叠而成,每个 layer 包含一个 attention 模块 attn 和一个前馈神经网络模块 ff。forward 过程,通过残差结构,每一层都在重新叠加组合前一层的特征,通过多层堆叠,可以学习到更复杂的特征组合。最后,同样在输出前做了 LayerNorm。

attn

    def forward(self, x):
        x = self.norm(x)
    
        x = self.to_qkv(x)
        qkv = x.chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
    
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
    
        attn = self.attend(dots)
        attn = self.dropout(attn)
    
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

#########################################
    # to_qkv 定义:
    self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
    
    # inner_dim 定义:
    inner_dim = dim_head * heads
    
    # self.attend 定义:
    self.attend = nn.Softmax(dim = -1)
    
    # self.to_out 定义:
    self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()
  1. norm: 对输入进行 LayerNorm

  2. to_qkv: 将输入的 x(此时已经融合了 patch embedding、cls token、pos embedding)从 dim 维线性映射到 inner_dim * 3

    1. 这里,3 的含义是 Q、K、V,这三者的维度都是 inner_dim
    2. inner_dim 内部把每个 x(即 patch 形式的 token)的 Q、K、V 向量划分成了 heads 个子空间,每个子空间的维度是 dim_head,每个子空间对应一个 head。
    3. 此外,inner_dim 在数值上等同于 x 的维度 dim,使得 x 作为 Q、K、V 的来源。
  3. chunk: 将 tensor 沿最后一维,拆分成 3 份。

  4. q, k, v: qkv 含 3 个 tensor,即 q、k、v;用 rearrange 进行维度重排,输出维度 [b h n d],即 [batch_size,heads,num_patches,dim_head]

  5. dots、attn、out: 这几步都严格遵循多头自注意力机制的计算公式

  6. out rearrange、to_out: 负责把多头注意力输出的高维特征投影还原回到主干维度

注意 to_out:

只有单头且维度完全一致时,才可以直接输出,不需要线性变换。这是因为多头注意力会把每个 head 的输出拼接起来,得到一个更高维的向量(inner_dim),需要用线性层把它还原回主干维度(dim),否则后续残差结构无法对齐。

ff

    def forward(self, x):
        return self.net(x)
        
#########################################   
    # self.net 定义
    self.net = nn.Sequential(
        nn.LayerNorm(dim),            # 1. 归一化
        nn.Linear(dim, hidden_dim),   # 2. 升维
        nn.GELU(),                    # 3. 非线性激活
        nn.Dropout(dropout),          # 4. 防止过拟合
        nn.Linear(hidden_dim, dim),   # 5. 降维
        nn.Dropout(dropout)           # 6. 再次防止过拟合
    )

比较经典的前馈神经网络结构,最后一层不加激活,配合残差结构,保证信息和梯度的直接流动。

pool

    x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

有 mean 和 cls 两种 pool 方式:

mean:对所有 token 做平均池化,作为整张图片的表示

cls:使用第一个 token(即 cls token)作为整张图片的表示(和 BERT 类似)

输出

    x = self.to_latent(x)
    return self.mlp_head(x)
    
    
#########################################  
    # self.to_latent 定义
    self.to_latent = nn.Identity()

    # self.mlp_head 定义
    self.mlp_head = nn.Linear(dim, num_classes)
  1. to_latent: 在这边是 Identity 操作。同时也是一个预留的接口,允许用户在不修改原始代码的情况下,通过继承 ViT 类来扩展功能,比如可以修改这个 Identity 操作,在特征进入最后分类头之前进行额外处理

  2. mlp_head: 通过简单的线性层完成最终分类

总结核心流程

  1. 图像分块与特征提取:

    1. 将输入图像切分成固定大小的 patches
    2. 通过 to_patch_embedding 提取每个 patch 的特征表示(内容编码)
    3. 添加 pos_embedding 为每个 patch 注入位置信息(位置编码)
  2. 序列构建与增强:

    1. 添加可学习的 cls_token 作为全局特征聚合点;
    2. 将位置编码与特征编码相加,形成完整的 patch token;
    3. 最终得到包含空间和语义信息的序列表示
  3. Transformer 特征提取:

    1. 通过多头自注意力机制实现 patches 间的全局交互
    2. 使用 FeedForward 网络增强特征表示能力
    3. 采用残差连接和 LayerNorm 确保深层网络的训练稳定性
  4. 分类预测:

    1. 可选择使用 cls token 或平均池化获取全局特征

    2. 通过简单的线性层完成最终分类

参考

ViT 代码:github.com/lucidrains/…

ViT 论文:openreview.net/pdf?id=Yicb…

transformer 讲解:github.com/datawhalech…

BETR 讲解:book.douban.com/subject/362…

GELU 论文:arxiv.org/pdf/1606.08…