【CaiT】如何才能使VIT网络往更深层发展

900 阅读3分钟

论文: Going deeper with Image Transformers

代码: Deit

前言

  近些天综合看CNN 领域内的文章以及VIT领域内的文章,在比对这两大类模型设计的特点时有一篇文章提醒到我了,它与CNN领域内的里程碑网络结构RESNET有异曲同工之妙。《Going deeper with Image Transformers》这篇论文的出现,将我们带入了一个更加深入的探索,试图回答一个重要问题:我们是否能够用Transformer模型在图像处理中取得类似于在自然语言处理中的突破?

基础回顾:残差

  残差连接(Residual Connections)是一种在深度神经网络中引入的重要技术。其主要目的是用于解决深度神经网络训练中的梯度消失和梯度爆炸等问题,同时使得深层网络的训练更加容易和高效。

  在传统的深度神经网络中,信息从网络的输入一直传递到输出,经过多个隐藏层的变换。每个隐藏层都是通过非线性激活函数(如ReLU)对上一层输出进行变换,以学习抽取更高级别的特征表示。然而,随着网络层数的增加,梯度在反向传播过程中可能会逐渐变得非常小,导致梯度消失的问题,使得深层网络难以训练。

  残差连接的核心思想是引入跳跃连接,将上一层的输入直接与当前层的输出相加,而不是直接对上一层的输出进行变换。数学上,对于一个层 的输出y和该层的输入x,残差连接可以表示为:y = F(x) + x,其中F(x) 是该层的变换操作。如果F(x) 能够逼近一个恒等映射(即输出等于输入),那么这个层就成为了一个恒等映射。

残差连接在深度神经网络中的意义在于:

  1. 缓解了梯度消失问题,使得深层网络能够更容易训练。
  2. 提高了网络的训练效率,加速了收敛过程。
  3. 允许网络层数更深,从而有助于提取更丰富和抽象的特征表示。
  4. 降低了过拟合的风险,使得网络更具泛化能力。

CaiT核心贡献

  由于《Going deeper with Image Transformers》一文中借鉴了残差连接的思想,那么如何规范化或初始化残差区域的模块适用于VIT网络就成了这篇文章的一个核心点。作者在文中分析了不同初始化之间的相互作用之后,提出了一种有效的方法。形式上,在输出的每个残差块初始化接近于0(但不是等于0)。添加这个简单的层在每个残差块提高训练的动态性后,就可以训练得到更深更高容量的VIT网络,受益于深度,在文中指的是这种方法为LayerScale

image.png

# classes
class LayerScale(nn.Module):
    def __init__(self, dim, fn, depth):
        super().__init__()
        if depth <= 18:  # epsilon detailed in section 2 of paper
            init_eps = 0.1
        elif depth > 18 and depth <= 24:
            init_eps = 1e-5
        else:
            init_eps = 1e-6

        scale = torch.zeros(1, 1, dim).fill_(init_eps)
        self.scale = nn.Parameter(scale)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale

  在Transformers中作者提出了第二个核心贡献,如下图所示。它类似于编码器/解码器的体系结构,其中分离了涉及自注意的Transformers层。在补丁之间,从专用于提取conc的类注意层中将处理过的补J放入单个向量,以便将其输入线性。分类器这种明确的分离避免了指导的矛盾目标。处理类嵌入时的注意过程。

image.png

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.mix_heads_pre_attn = nn.Parameter(torch.randn(heads, heads))
        self.mix_heads_post_attn = nn.Parameter(torch.randn(heads, heads))

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context = None):
        b, n, _, h = *x.shape, self.heads

        context = x if not exists(context) else torch.cat((x, context), dim = 1)

        qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        dots = einsum('b h i j, h g -> b g i j', dots, self.mix_heads_pre_attn)    # talking heads, pre-softmax

        attn = self.attend(dots)
        attn = self.dropout(attn)

        attn = einsum('b h i j, h g -> b g i j', attn, self.mix_heads_post_attn)   # talking heads, post-softmax

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

结语

  从这篇论文中给大家展示了残差连接在VIT网络中的价值,也从另一个角度辅证了:想要网络更深,残差连接是一个不错的手段,以及在残差连接上做一些优化改进。