多头自注意力机制的工作原理:数学、直觉与10+1个隐藏洞见
自注意力作为两个矩阵乘法
数学原理
考虑不带多头机制的自我点积注意力以增强可读性。给定输入 X∈Rbatch×tokens×dmodel,以及可训练权重矩阵:WQ,WK,WV∈Rdmodel×dk。
我们创建3个不同的表示(查询、键和值):
Q=XWQ,K=XWK,V=XWV,Rbatch×tokens×dk
然后定义注意力层为:
Y=Attention(Q,K,V)=softmax(dkQKT)V
点积得分计算为:
Dot-scores=(dkQKT)
直观图解
通过将所有查询放在一起,我们有一个矩阵乘法,而不是每次单个查询向量到矩阵的乘法。每个查询完全独立于其他查询进行处理。
查询-键矩阵乘法
基于内容的注意力具有不同的表示。注意力层中的查询矩阵在概念上是数据库中的"搜索"。键将决定我们查看的位置,而值将实际提供所需的内容。
直观上,键是查询(我们正在寻找的内容)和值(我们实际将获得的内容)之间的桥梁。
注意力V矩阵乘法
然后使用权重 αij 来获得最终的加权值。
原始Transformer的交叉注意力
相同的原理适用于编码器-解码器注意力(或称交叉注意力)。键和值通过对最终编码输入表示的线性投影计算,经过多个编码器块后。
多头注意力详细工作原理
将注意力分解为多个头是并行和独立计算的第二部分。原始多头注意力定义为:
MultiHead (Q,K,V)= Concat (head 1,…, head h)WO
其中 head i= Attention (QWiQ,KWiK,VWiV)
基本上,初始嵌入维度 dim 被分解为 h×dhead,每个头的计算独立进行。
自注意力独立计算的并行化
所有表示都是从相同的输入创建的,并合并在一起以产生单个输出。然而,各个 Qi,Ki,Vi 处于较低的维度 dk=dmodel/heads。
注意力机制的洞见和观察
自注意力不是对称的!
如果我们进行数学计算,就很容易理解:
dkQKT=dkXWQ(XWK)T=dkXWQWKTXT
为了使自注意力对称,我们必须对查询和键使用相同的投影矩阵:WQ=WK。
注意力作为多个本地信息的路由
多头允许头部关注输入的不同部分,但研究表明头部保留了几乎所有的内容。这使得注意力成为查询序列相对于键/值的路由算法。
编码器权重可以高效分类和剪枝
研究发现可以通过查看注意力矩阵来识别3种重要类型的头部:
- 主要关注邻居的位置头部
- 指向具有特定语法关系的标记的语法头部
- 指向句子中罕见词的头部
通过保留分类为显著类别的头部,可以在保持几乎相同BLEU分数的情况下保留48个头部中的17个。
头部共享共同投影
研究表明,即使每个头的权重矩阵的单独乘积不是低秩的,它们连接后的乘积是低秩的。这实际上意味着头部共享共同投影。
编码器-解码器注意力中的多头非常重要
研究表明,当从不同的注意力子模型中逐步剪枝头部时,编码器-解码器注意力层的性能下降得更快。剪枝超过60%的交叉注意力头部会导致显著的性能下降。
应用softmax后,自注意力是低秩的
研究表明,在应用softmax后,所有层的自注意力都是低秩的。这意味着P中包含的大部分信息可以从前几个最大奇异值中恢复。
注意力权重作为快速权重记忆系统
通过移除著名的注意力机制中的softmax,我们得到了类似的行为。值和键的外积可以被视为快速权重。
秩崩溃和标记均匀性
研究发现,没有额外组件(如MLP和跳跃连接)的情况下,注意力会指数级收敛到秩1矩阵。
层归一化:迁移学习预训练Transformer的关键成分
研究发现,层归一化可训练参数(0.1%的参数)对于在低数据机制中微调Transformer至关重要。
二次复杂度:我们到了吗?
减少二次复杂度的研究主要有两类:
- 使用数学近似完全二次全局注意力的方法
- 尝试限制和稀疏化注意力的方法
结论
经过这么多视角和观察,希望您在基于内容的注意力分析中至少获得了一个新的洞见。如此简单的想法能有如此巨大的影响和如此多的意义和洞见,真是令人惊叹。