多头自注意力机制深度解析与11个关键洞见

111 阅读4分钟

多头自注意力机制的工作原理:数学、直觉与10+1个隐藏洞见

自注意力作为两个矩阵乘法

数学原理

考虑不带多头机制的自我点积注意力以增强可读性。给定输入 XRbatch×tokens×dmodel\textbf{X} \in \mathcal{R}^{batch \times tokens \times d_{model}},以及可训练权重矩阵:WQ,WK,WVRdmodel×dk\textbf{W}^{Q}, \textbf{W}^{K}, \textbf{W}^{V} \in \mathcal{R}^{d_{\text{model}} \times d_{k}}

我们创建3个不同的表示(查询、键和值): Q=XWQ,K=XWK,V=XWV,Rbatch×tokens×dk\textbf{Q} = \textbf{X} \textbf{W}^Q, \textbf{K} = \textbf{X} \textbf{W}^K, \textbf{V} = \textbf{X} \textbf{W}^V , \mathcal{R}^{batch \times tokens \times d_{k}}

然后定义注意力层为: Y=Attention(Q,K,V)=softmax(QKTdk)V\textbf{Y} = \operatorname{Attention}(\textbf{Q}, \textbf{K}, \textbf{V})=\operatorname{softmax}\left(\frac{\textbf{Q} \textbf{K}^{T}}{\sqrt{d_{k}}}\right) \textbf{V}

点积得分计算为: Dot-scores=(QKTdk)\operatorname{Dot-scores} = \left(\frac{\textbf{Q} \textbf{K}^{T}}{\sqrt{d_{k}}}\right)

直观图解

通过将所有查询放在一起,我们有一个矩阵乘法,而不是每次单个查询向量到矩阵的乘法。每个查询完全独立于其他查询进行处理。

查询-键矩阵乘法

基于内容的注意力具有不同的表示。注意力层中的查询矩阵在概念上是数据库中的"搜索"。键将决定我们查看的位置,而值将实际提供所需的内容。

直观上,键是查询(我们正在寻找的内容)和值(我们实际将获得的内容)之间的桥梁。

注意力V矩阵乘法

然后使用权重 αij\alpha_{ij} 来获得最终的加权值。

原始Transformer的交叉注意力

相同的原理适用于编码器-解码器注意力(或称交叉注意力)。键和值通过对最终编码输入表示的线性投影计算,经过多个编码器块后。

多头注意力详细工作原理

将注意力分解为多个头是并行和独立计算的第二部分。原始多头注意力定义为:

 MultiHead (Q,K,V)= Concat (head 1,, head h)WO\text { MultiHead }(\textbf{Q}, \textbf{K}, \textbf{V}) =\text { Concat (head }_{1}, \ldots, \text { head } \left._{\mathrm{h}}\right) \textbf{W}^{O}

其中  head i= Attention (QWiQ,KWiK,VWiV)\text { head }_{\mathrm{i}} =\text { Attention }\left(\textbf{Q} \textbf{W}_{i}^{Q}, \textbf{K} \textbf{W}_{i}^{K},\textbf{V} \textbf{W}_{i}^{V}\right)

基本上,初始嵌入维度 dimdim 被分解为 h×dheadh \times d_{head},每个头的计算独立进行。

自注意力独立计算的并行化

所有表示都是从相同的输入创建的,并合并在一起以产生单个输出。然而,各个 Qi,Ki,ViQ_{i}, K_{i}, V_{i} 处于较低的维度 dk=dmodel/headsd_k = d_{model}/heads

注意力机制的洞见和观察

自注意力不是对称的!

如果我们进行数学计算,就很容易理解: QKTdk=XWQ(XWK)Tdk=XWQWKTXTdk\frac{\textbf{Q} \textbf{K}^{T}}{\sqrt{d_{k}}} = \frac{\textbf{X} \textbf{W}_Q (\textbf{X} \textbf{W}_K)^{T}}{\sqrt{d_{k}}} = \frac{\textbf{X} \textbf{W}_Q \textbf{W}_K^{T} \textbf{X}^T }{\sqrt{d_{k}}}

为了使自注意力对称,我们必须对查询和键使用相同的投影矩阵:WQ=WK\textbf{W}_Q = \textbf{W}_K

注意力作为多个本地信息的路由

多头允许头部关注输入的不同部分,但研究表明头部保留了几乎所有的内容。这使得注意力成为查询序列相对于键/值的路由算法。

编码器权重可以高效分类和剪枝

研究发现可以通过查看注意力矩阵来识别3种重要类型的头部:

  • 主要关注邻居的位置头部
  • 指向具有特定语法关系的标记的语法头部
  • 指向句子中罕见词的头部

通过保留分类为显著类别的头部,可以在保持几乎相同BLEU分数的情况下保留48个头部中的17个。

头部共享共同投影

研究表明,即使每个头的权重矩阵的单独乘积不是低秩的,它们连接后的乘积是低秩的。这实际上意味着头部共享共同投影。

编码器-解码器注意力中的多头非常重要

研究表明,当从不同的注意力子模型中逐步剪枝头部时,编码器-解码器注意力层的性能下降得更快。剪枝超过60%的交叉注意力头部会导致显著的性能下降。

应用softmax后,自注意力是低秩的

研究表明,在应用softmax后,所有层的自注意力都是低秩的。这意味着PP中包含的大部分信息可以从前几个最大奇异值中恢复。

注意力权重作为快速权重记忆系统

通过移除著名的注意力机制中的softmax,我们得到了类似的行为。值和键的外积可以被视为快速权重。

秩崩溃和标记均匀性

研究发现,没有额外组件(如MLP和跳跃连接)的情况下,注意力会指数级收敛到秩1矩阵。

层归一化:迁移学习预训练Transformer的关键成分

研究发现,层归一化可训练参数(0.1%的参数)对于在低数据机制中微调Transformer至关重要。

二次复杂度:我们到了吗?

减少二次复杂度的研究主要有两类:

  • 使用数学近似完全二次全局注意力的方法
  • 尝试限制和稀疏化注意力的方法

结论

经过这么多视角和观察,希望您在基于内容的注意力分析中至少获得了一个新的洞见。如此简单的想法能有如此巨大的影响和如此多的意义和洞见,真是令人惊叹。