极智AI | 谈谈 ViT 的意义和实现细节

1,373 阅读9分钟

欢迎关注我的公众号 [极智视界],获取我的更多经验分享

大家好,我是极智视界,本文来谈一谈 ViT的意义和实现细节。

Transformer 是现在火热的 AIGC预训练大模型的基础,而 ViT (Vision Transformer) 是真正意义上将自然语言处理领域的 Transformer 带到了视觉领域。从 Transformer 的发展历程就可以看出,从 Transformer 的提出到将 Transformer 应用到视觉,其实中间蛰伏了三年的时间。而从将 Transformer 应用到视觉领域 (ViT) 到 AIGC 的火爆也差不多用了两三年。其实 AIGC 的火爆,从 2022年下就开始有一些苗条,那时就逐渐有一些 AIGC 好玩的算法放出来,而到现在,AIGC好玩的项目真是层出不穷。

ViT 就是为了证明 Transformer 也能够很好地应用于视觉领域,ViT 最大的特色是在图像提特征部分(模型 backbone) 直接使用了 Transformer 中的 Encoder,而模型的输入也是在模仿 NLP 中的输入方式,把二维的图像数据输入打 patch 后展平成一维数据 (做成 sequence 的形式) 喂给模型。当然,其实如果是为了得到更加 SOTA 的结果,我想作者也没必要这么"死脑筋",完全还可以基于 Transformer Encoder 结构本身针对视觉任务做一些 "因地制宜" 的优化,这从 ViT 出来后马上涌现大量的视觉 Transformer 变体就可以看出来。那么 ViT 的作者为什么这么轴呢,为什么就不直接进一步再做优化呢,这大概率就是为了证明:原生的 Transformer 模型不做任何改变,也能在图像领域 work 得很好,而这本身带来的意义完全不亚于做出一篇性能 SOTA 的工作

为了佐证上面的观点,拿出一组当时的性能对比数据。从下面这张图像分类任务的实验对比数据也可以看出来,ViT 本身的精度提升相比于传统的 ResNet 来说,并不太明显,况且还是用最大的 14 的模型的情况下才有一些提升。所以从模型性能 SOTA 的角度来说,我觉得 ViT 表现的成绩并不够惊艳,更何况它还是用在视觉任务最简单的分类任务上 (作为对比,"丧心病狂"的 SAM(Segment Anything Model) 一出来就出在相对更加难的分割任务上(视觉任务难度系数:分割 > 检测 > 分类))。所以从这个角度来说,ViT之所以重要,它的意义根本不在性能的 SOTA,而在于它给 Transformer 的多模态应用打了个头,为 Transformer 的大一统打了个头

对于习惯了卷积神经网络的视觉同学,或者是习惯了 Transformer 的语言同学来说,可能都会对 ViT 中出现的一些超参有一些疑惑,毕竟 ViT 是将本身在语言领域已经用的很好的 Transformer 迁移到了视觉领域,所以其中势必会涉及到两个模态的专用领域知识。

下面来讲讲 ViT 的实现细节。

得益于社区众多大佬的贡献,目前使用 pytorch 调用 ViT 的方式十分的简洁,但是当你初来乍到,看到这么多似懂非懂的传参后,可能也会 "一脸懵逼",啥是 patch_size、啥又是 heads ...

下面对 ViT 中的各个主要的超参进行解释:

其中,

  • patch_size:表示输入图像被分成的小块的大小,一般为 16x16,且 image_size 需要能够被path_size 整除。需要注意的是,path_size 大小一般也被用于标识 ViT 模型的大小,比如是 ViT-Large 还是 ViT-Huge,比如上面实验数据表格中的 ViT-L/16 就表示 ViT-Large path_size 为 16。所以 patch_size 其实一定程度上能够反应模型的大小,一般来说,patch_size 越小,那么 patch 块就会越多,模型相应也会越大,而这同样适用于其他一些 Transformer 变体的预训练模型大小的表达,比如 CLIP、Swin-Transformer 等
  • num_layers:表示 transformer 模块的层数,一般为 12,就是模型结构中 Block * Nx 中 Nx 的数量;
  • embedding_dim:表示每个 patch 被嵌入到向量空间中后的维度,一般为 768,也就是sequence 序列的长度;
  • num_heads:表示 transformer 模块中多头注意力机制的头数,一般为 12,且 embedding_dim需要能够被 num_heads 整除;
  • mlp_dim:表示 transformer 模块中全连接层的维度,一般为 3072;

这样就把 ViT 的一些超参解释清楚了,但是现在你可能还不是很清楚 ViT 的计算流 (work flow) 是怎么样的,还是先来看一下 ViT 的模型结构:

从 ViT 模型结构可以分析出算法的整体流程是:图片 -> Patch Embedding(打patch) -> Postion Embedding(位置嵌入) -> Transformer Encoder (基础block:(Multi-Head Attention -> MLP) * Nx) -> MLP Head(分类头) -> 分类。但这其实也只是说了个大概,特别是对于 ViT 核心 backbone 来说 (也就是上图中的右边部分),一个更加清晰的表达如下:

在上面的计算流图中,上部分大块黄色表示 Multi-Head Attention,下部分细长黄色块表示 MLP,两个黄色块之间用残差连接进行 Add (Kaiming大佬的残差真的是神器,它能减缓梯度消失,让网络可以做的很深。残差在 Transformer 中的成功应用,说明了它不仅仅只能用于卷积神经网络,泛化能力十足)。

对于模型的实际部署来说,工程中为了得到更高的吞吐和更低的算力消耗,我们其实经常会选择采用更低比特的精度去部署,比如 int8 量化。做过 int8 量化的同学应该清楚,一般需要在网络中插入量化节点(Quant) 和 解量化节点(deQuant),对于 ViT 的量化来说,同样如此。下图展示了 ViT int8 量化的计算流图,相较于上面 fp32/fp16 的推理精度的计算流图来说,正是插入了不少的量化和解量化节点,但是整体流程还是差不多的。当然量化节点也不是乱插的,对于一般的对称量化来说,其实会选择不对非线性的激活函数进行量化处理,比如这里的 softmax。而对于计算密集型的 GEMM 矩阵乘,会是量化优化的重点算子。

接着来看看 ViT 中几个关键组件的 pytorch 代码实现。

PatchEmbed => 图片块打patch 16163->786

首先是要把图片打成 Patch,就是要将图像均分成不重叠的小块。

从代码实现来看,其实很简单,直接用一个 kernel_size 和 stride 都等于 patch_size 的二维卷积来做就行,然后用 flatten 来展平。为什么要用 kernel_size 和 stride 都等于 patch_size 的二维卷积来做 patch 呢,原因有几点:

  • (1) 卷积对于部署"友好" => 现代的AI硬件对于卷积已经做了太多成熟的优化方案,可以方便的进行推理加速;
  • (2) 卷积实现代码简洁 => 其实你仔细想想,图像均分成小块的方式肯定有多种,比如切片,那样写代码看上去就不好看了,写起来也比较复杂 (不够优雅);
class PatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        # 14,14
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        # 14 * 14 = 196
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten
        # input=3, output_dim=768, kernel_size=16, 224 * 224 * 3 -》 14 * 14 * 768(16 * 16 *3) 
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC  B*768*14*14 -》 B*196*768
        x = self.norm(x)
        return x

Attention => Multi-Head Attention

Multi-Head Attention 就跟 Transformer 中的一毛一样了,在里面会对 Q 跟 K转置 的矩阵乘后除以根号dk用以缩放。可以看到听起来比较神秘的 Q、K、V,在代码实现来说也极其简单,直接用一个大家很熟悉的 Linear + reshape 来实现。需要注意的是,这里用 Linear 来生成 Q、K、V 的原因是因为它是自注意力,可以理解为将同一个输入复制成三份分别赋给 Q、K、V,注意如果不是自注意力,就不是这么做了。比如 Transformer 中的 Decoder 模块中的 Multi-Head Attention,就不是自注意力 (K、V 来自于 Transformer Encoder 的输出,Q 来自于 Transformer Decoder 本身),就不能直接这么写了。同样反过来,正是因为 ViT 中只是用于了 Transformer 中的 Encoder,所以 ViT 中也就只有自注意力机制。

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        # 一个linear计算算qkv
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

Block => LN + Attention + dropout + LN + MLP

一个 Block = LN + Attention + dropout + LN + MLP, 中间会有 2 次残差链接,这就是 ViT Encoder (也是 Transformer Encoder) 的一个基本的结构了,而这个基本结构在 ViT 中需要循环做 12 次。

class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

ViT => 合体

下面来看 ViT 的完整的 pytorch class 实现,就是对上面几个关键的组件进行拼装而成。

class VisionTransformer(nn.Module):
    """ Vision Transformer

    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929

    Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
        - https://arxiv.org/abs/2012.12877
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None, weight_init=''):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            distilled (bool): model includes a distillation token and head as in DeiT models
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            embed_layer (nn.Module): patch embedding layer
            norm_layer: (nn.Module): normalization layer
            weight_init: (str): weight init scheme
        """
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 2 if distilled else 1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.patch_embed = embed_layer(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
                attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Representation layer
        if representation_size and not distilled:
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(embed_dim, representation_size)),
                ('act', nn.Tanh())
            ]))
        else:
            self.pre_logits = nn.Identity()

        # Classifier head(s)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = None
        if distilled:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

        self.init_weights(weight_init)

    def forward_features(self, x):
        x = self.patch_embed(x)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        if self.dist_token is None:
            x = torch.cat((cls_token, x), dim=1)
        else:
            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = self.pos_drop(x + self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        if self.dist_token is None:
            return self.pre_logits(x[:, 0])
        else:
            return x[:, 0], x[:, 1]

    def forward(self, x):
        x = self.forward_features(x)
        if self.head_dist is not None:
            x, x_dist = self.head(x[0]), self.head_dist(x[1])  # x must be a tuple
            if self.training and not torch.jit.is_scripting():
                # during inference, return the average of both classifier predictions
                return x, x_dist
            else:
                return (x + x_dist) / 2
        else:
            x = self.head(x)
        return

好了,以上分享了 谈谈 ViT 的意义和实现细节。希望我的分享能对你的学习有一点帮助。



logo_show.gif