深度学习理论-直观理解 Attention

361 阅读1分钟

本文首先介绍 Attention 的原始公式,然后以 Self-Attention 为例,简化后逐步分析 Attention 计算结果表达的含义

Attention

Attention 公式如下:

Attention=softmax(QKTdk)VAttention = softmax(\frac{Q \cdot K^T}{\sqrt{d_k}}) \cdot V

其中 softmax 作用是归一化,公式如下:

softmax(x)=exi=1nexisoftmax(x) = \frac{e^x}{\sum_{i=1}^n{e^{x_i}}}

我们将 QKTdk\frac{Q \cdot K^T}{\sqrt{d_k}} 称为 attention score,归一化后 softmax(QKTdk)softmax(\frac{Q \cdot K^T}{\sqrt{d_k}}) 称为 attention weight

Self-Attention

在 Self-Attention 中,输入为 XX ,乘以不同的权重矩阵,就得到了不同的 QQKKVV

Q=XWqQ = X \cdot W_q

K=XWkK = X \cdot W_k

V=XWvV = X \cdot W_v

为了方便理解,我们先做简化,把权重矩阵 Wq,Wk,WvW_q, W_k, W_v 和缩放因子 dk\sqrt{d_k} 都假设为 1

简化后,Self-Attention 长这样

softmax(XXT)Xsoftmax(X\cdot X^T) \cdot X

1. Attention Score

首先来看 XXTX \cdot X^T 的含义,我们先复习一下,向量内积表示的是两个向量的相关性

假设 XXn×dn \times d 的矩阵,nn 是输入的数量,dd 是特征的维度, XXTX \cdot X^Tn×nn \times n 的矩阵,表示输入的每个元素,与其它元素的相关性

2. Attention Weight

softmax 就是做归一化,使得权重的和为 1,表达的含义跟 score 一致,相关性高的权重也高,通过非线性函数 exe^x 后,变成了概率分布

3. Attention Value

用相关性矩阵乘以输入向量,得到了 n×dn \times d 的矩阵,跟输入 XX 的尺度一致,含义也一致,依然表示 nn 个输入变量 的 dd 维特征,但这个特征已经是经过注意力加权的特征,相关性更高的元素响应更高。

看到这里,相信你对 attention 机制已经有了直观的理解。下面就把之前简化的细枝末节加回来。

4. 权重矩阵

权重矩阵 Wq,Wk,WvW_q, W_k, W_v 都是可训练的参数,具有以下作用

  • 使用不同的权重矩阵,可以提升模型的表达能力。
  • 通过调整 Wq,WkW_q, W_k​,模型可以学习到不同的注意力模式,使得某些输入 token 之间的关联更强或更弱。
  • 权重矩阵 WvW_v​ 影响模型如何聚合信息,使得某些 token 在最终表示中占更重要的比重。

5. 缩放因子

dk​​ 作为缩放因子有如下两个作用:

5.1 防止数值过大,避免梯度消失或梯度爆炸

  • QKTQK^T 是两个向量的点积,其值范围随着维度 dkd_k​ 增大而增大。
  • 如果 不除以 dk\sqrt{d_k}​​,那么较大维度时,点积的数值会变得非常大,导致 softmax 结果变得极端(接近 0 或 1),从而导致梯度消失,影响训练稳定性。
  • 除以 dkdk​​ 后,使得点积值的范围保持在适当区间,从而让 softmax 更平滑。

5.2 保持不同维度下的数值稳定性

  • 在深度学习中,通常希望输入数据的方差保持在一个稳定范围,否则网络难以收敛。
  • 设 Query 和 Key 向量的分量服从均值为 0、方差为 1 的标准正态分布,则点积的期望和方差为: E[QKT]=0,Var[QKT]=dkE[QK^T]=0,Var[QK^T]=d_k
  • 这意味着 dk\sqrt{d_k}​​​ 越大,点积值的方差也会随之增大,从而影响 softmax 的输出。
  • 除以 dk\sqrt{d_k}​​ 后,点积值的方差变为 1,保持了数值稳定性,使不同维度的注意力机制都能较好地工作。

复杂度

假设 QQ KK VV 的维度分别为 N×dN \times dM×dM \times dM×dM \times d

  • 时间复杂度O(NMd)O(NMd)

    • 计算 Q, K, V: O(Nd2)O(Nd^2)  O(Md2)O(Md^2) O(Md2)O(Md^2)(线性变换)
    • 计算 Attention-Score QKTQK^TO(NMd)O(NMd)
    • 计算 Softmax: O(NM)O(NM)
    • 计算加权求和 softmax(QKT)Vsoftmax(QK^T)VO(NMd)O(NMd)
    • 总体上,主要瓶颈是 QKTQK^T 的计算和加权求和,因此时间复杂度为 O(NMd)O(NMd)
  • 空间复杂度O(NM)O(NM)

    • 由于需要存储 QKTQK^T(一个 N×MN×M 的矩阵),因此空间复杂度是 O(NM)O(NM)

对于 Self-Attention,由于 N=MN=M,时间复杂度为O(N2d)O(N^2d),空间复杂度为 O(N2)O(N^2)

思考

如果想屏蔽某些特征,应该如何做?mask 是怎样实现的?

Google T5 不除以 dk\sqrt{d_k} 为什么也能够收敛?

参考资料