第03章:Scaled Dot-Product Attention——那个√d_k到底在防什么?

0 阅读2分钟

第03章:Scaled Dot-Product Attention——那个√d_k到底在防什么?

论文链接Attention Is All You Need (Vaswani et al., NIPS 2017)
本章对应:Section 3.2.1, Footnote 4

核心困惑

为什么Attention公式里要除以dk\sqrt{d_k}?不除会怎样?

这个dk\sqrt{d_k}看起来不起眼,但它在防一个致命问题。如果你在面试时被问"Scaled Dot-Product Attention的Scaled是什么意思",答不出来就直接挂了。

完整的Attention公式是: Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

为什么是dk\sqrt{d_k}而不是dkd_k或者dk2d_k^2?这背后有严格的数学推导。


前置知识补给站

1. 向量点积的几何意义

两个向量qqkk的点积: qk=i=1dkqiki=qkcosθq \cdot k = \sum_{i=1}^{d_k} q_i k_i = \|q\| \|k\| \cos\theta

几何意义

  • 点积衡量两个向量的相似度
  • 点积越大,向量越相似(夹角越小)
  • 点积可以是负数(夹角 > 90°)

2. 随机变量的方差

对于随机变量XXVar(X)=E[(XE[X])2]=E[X2](E[X])2\text{Var}(X) = E[(X - E[X])^2] = E[X^2] - (E[X])^2

方差的性质

  • Var(aX)=a2Var(X)\text{Var}(aX) = a^2 \text{Var}(X)
  • 如果XXYY独立:Var(X+Y)=Var(X)+Var(Y)\text{Var}(X + Y) = \text{Var}(X) + \text{Var}(Y)

3. Softmax函数的饱和区

Softmax函数: softmax(xi)=exijexj\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}

饱和区问题

  • 当某个xix_i远大于其他xjx_j时,softmax(xi)1\text{softmax}(x_i) \approx 1,其他位置0\approx 0
  • 此时梯度softmaxxj0\frac{\partial \text{softmax}}{\partial x_j} \approx 0(对于所有jj
  • 这叫"饱和",会导致梯度消失

论文精读:为什么需要缩放?

原论文的解释

Section 3.2.1

"We call our particular attention 'Scaled Dot-Product Attention'. The input consists of queries and keys of dimension dkd_k, and values of dimension dvd_v. We compute the dot products of the query with all keys, divide each by dk\sqrt{d_k}, and apply a softmax function to obtain the weights on the values."

Footnote 4(关键):

"We suspect that for large values of dkd_k, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients. To counteract this effect, we scale the dot products by 1dk\frac{1}{\sqrt{d_k}}."

翻译成人话:

  1. dkd_k很大时,点积的值会很大
  2. 大的点积会让Softmax进入饱和区
  3. 饱和区的梯度极小,导致梯度消失
  4. 除以dk\sqrt{d_k}可以把点积的值控制在合理范围

但原论文没有证明"为什么是dk\sqrt{d_k}"。我们来严格推导。


第一性原理推导:为什么是√d_k?

推导1:随机向量点积的方差

假设(理想化数学模型):

  • qqkkdkd_k维随机向量
  • q=xWQq = xW^Qk=xWKk = xW^K,其中WQW^QWKW^K在初始化时独立
  • 每个分量qi,kiq_i, k_i的均值为0,方差为1
  • qq的各分量不相关,kk的各分量不相关

:实践中这些假设通过LayerNorm近似满足。

目标:计算点积qkq \cdot k的方差。

推导: 首先计算期望: qk=i=1dkqikiE[qk]=i=1dkE[qiki]\begin{aligned} q \cdot k &= \sum_{i=1}^{d_k} q_i k_i \\ E[q \cdot k] &= \sum_{i=1}^{d_k} E[q_i k_i] \end{aligned}

在初始化时,WQW^QWKW^K独立初始化,因此对于固定的输入xxqiq_ikik_i条件独立: E[qikix]=E[qix]E[kix]E[q_i k_i | x] = E[q_i | x] \cdot E[k_i | x] 由于投影后的向量均值为0(假设),因此: E[qk]=0E[q \cdot k] = 0

接下来计算方差: Var(qk)=E[(qk)2](E[qk])2=E[(qk)2]=E[(i=1dkqiki)2]=E[i=1dkqi2ki2+ijqikiqjkj]\begin{aligned} \text{Var}(q \cdot k) &= E[(q \cdot k)^2] - (E[q \cdot k])^2 \\ &= E[(q \cdot k)^2] \\ &= E\left[\left(\sum_{i=1}^{d_k} q_i k_i\right)^2\right] \\ &= E\left[\sum_{i=1}^{d_k} q_i^2 k_i^2 + \sum_{i \neq j} q_i k_i q_j k_j\right] \end{aligned}

对于iji \neq j的交叉项,由于qqkk来自不同的投影矩阵: E[qikiqjkj]=E[(qiqj)(kikj)]=E[qiqj]E[kikj]E[q_i k_i q_j k_j] = E[(q_i q_j)(k_i k_j)] = E[q_i q_j] \cdot E[k_i k_j] 如果qq的各分量不相关且均值为0,则E[qiqj]=Cov(qi,qj)+E[qi]E[qj]=0E[q_i q_j] = \text{Cov}(q_i, q_j) + E[q_i]E[q_j] = 0iji \neq j)。同理E[kikj]=0E[k_i k_j] = 0。因此交叉项期望为0。

对于对角项: i=1dkE[qi2ki2]=i=1dkE[qi2]E[ki2](qk独立)=i=1dkVar(qi)Var(ki)(均值为0)=i=1dk11=dk\begin{aligned} \sum_{i=1}^{d_k} E[q_i^2 k_i^2] &= \sum_{i=1}^{d_k} E[q_i^2] E[k_i^2] \quad \text{(}q \text{和} k \text{独立)} \\ &= \sum_{i=1}^{d_k} \text{Var}(q_i) \cdot \text{Var}(k_i) \quad \text{(均值为0)} \\ &= \sum_{i=1}^{d_k} 1 \cdot 1 \\ &= d_k \end{aligned}

因此: Var(qk)=dk\text{Var}(q \cdot k) = d_k 结论:点积qkq \cdot k的方差是dkd_k
标准差Var(qk)=dk\sqrt{\text{Var}(q \cdot k)} = \sqrt{d_k}

归一化:如果我们除以dk\sqrt{d_k}Var(qkdk)=1dkVar(qk)=1dkdk=1\text{Var}\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = \frac{1}{d_k} \text{Var}(q \cdot k) = \frac{1}{d_k} \cdot d_k = 1 这就是为什么是dk\sqrt{d_k}:它把点积的方差归一化到1,无论dkd_k多大。


推导2:Softmax饱和区的数值演示

问题:为什么点积过大会导致Softmax饱和?

Softmax的梯度softmax(xi)xj={softmax(xi)(1softmax(xi))if i=jsoftmax(xi)softmax(xj)if ij\frac{\partial \text{softmax}(x_i)}{\partial x_j} = \begin{cases} \text{softmax}(x_i)(1 - \text{softmax}(x_i)) & \text{if } i = j \\ -\text{softmax}(x_i) \text{softmax}(x_j) & \text{if } i \neq j \end{cases}

数值示例: 假设有3个位置,点积为[x1,x2,x3][x_1, x_2, x_3]

  • 情况1:点积适中(dk=64d_k=64,已缩放) x=[1.0,0.5,0.5]x = [1.0, 0.5, -0.5] softmax(x)=[0.506,0.307,0.187]\text{softmax}(x) = [0.506, 0.307, 0.187] softmax(x1)x1=0.506×(10.506)=0.250\frac{\partial \text{softmax}(x_1)}{\partial x_1} = 0.506 \times (1 - 0.506) = 0.250 梯度正常,可以学习。

  • 情况2:点积过大(dk=64d_k=64,未缩放) x=[8.0,4.0,4.0]x = [8.0, 4.0, -4.0] softmax(x)=[0.9997,0.0003,0.0000]\text{softmax}(x) = [0.9997, 0.0003, 0.0000] softmax(x1)x1=0.9997×(10.9997)=0.0003\frac{\partial \text{softmax}(x_1)}{\partial x_1} = 0.9997 \times (1 - 0.9997) = 0.0003 梯度几乎为0,无法学习。

可视化对比

dkd_k标准差3σ3\sigma 范围Softmax状态梯度
64(已缩放)1[-3, 3]✅ 正常~0.25
64(未缩放)8[-24, 24]❌ 饱和~0.0003
512(未缩放)22.6[-68, 68]❌ 极饱和~0.0

推导3:为什么不是其他缩放因子?

候选方案

  1. 除以 dkd_kVar(qk/dk)=1/dk\text{Var}(q \cdot k / d_k) = 1/d_k,方差随维度增加而消失,分布太集中。
  2. 除以 dk\sqrt{d_k}Var(qk/dk)=1\text{Var}(q \cdot k / \sqrt{d_k}) = 1,方差刚好稳定。
  3. 除以 logdk\log d_kVar(qk/logdk)=dk/(logdk)2\text{Var}(q \cdot k / \log d_k) = d_k / (\log d_k)^2,方差仍然随 dkd_k 快速增长。

只有dk\sqrt{d_k}能把方差归一化到1


Scaled vs Unscaled的实验对比

原论文没有直接对比Scaled和Unscaled的实验(因为缩放的必要性可以通过数学推导证明)。但Table 3 row (B)的消融实验说明:在缩放了的情况下,dkd_k的维度选择仍然重要。

原论文Table 3 row (B)

  • dk=16d_k=16: PPL 5.16, BLEU 25.1
  • dk=32d_k=32: PPL 5.01, BLEU 25.4
  • dk=64d_k=64 (base): PPL 4.92, BLEU 25.8

解读

  • dkd_k越大,效果越好(在缩放了的情况下)。
  • 这说明更大的dkd_k能提供更丰富的表示能力。
  • 但如果不缩放,大dkd_k会导致Softmax饱和,效果反而变差。

Dot-Product Attention vs Additive Attention

原论文提到了两种Attention机制:

1. Dot-Product Attention(原论文使用)

score(q,k)=qk\text{score}(q, k) = q \cdot k 优点:计算高效(矩阵乘法);并行性好。 缺点:需要缩放(否则方差随dkd_k增长)。

2. Additive Attention(Bahdanau et al., 2015)

score(q,k)=vTtanh(Wqq+Wkk)\text{score}(q, k) = v^T \tanh(W_q q + W_k k) 优点:不需要缩放(tanh自带归一化);理论表达力强。 缺点:计算慢(两次矩阵乘法 + tanh);参数量大。

结论:缩放后的Dot-Product快且效果相当,是工业界的首选。


完整的Scaled Dot-Product Attention流程

  1. 计算相似度QKTQK^T,得到 n×mn \times m 的score矩阵。
  2. 缩放:除以 dk\sqrt{d_k},把方差归一化到 1。
  3. Mask(可选):先Scale后Mask是工程习惯(Mask值通常为 109-10^9)。
  4. 归一化:Softmax,把score转为概率分布。
  5. 加权求和:用概率加权 VV,得到输出。

2026年的批判性视角

  1. 缩放因子的理论假设:LayerNorm 是保证 q,kq, k 符合均值0、方差1假设的关键。
  2. 其他缩放方案:如 T5 使用的可学习缩放,或 ALiBi 使用的位置偏置。
  3. Softmax 替代品:ReLU Attention 或 Linear Attention。
  4. 确定性 vs 统计性:相比 BatchNorm,除以 dk\sqrt{d_k} 是确定性的,在推理时更稳定。

面试追问清单

  1. 为什么Attention要除以dk\sqrt{d_k} (提示:点积方差推导)
  2. 如果不除以dk\sqrt{d_k}会怎样? (提示:Softmax饱和、梯度消失)
  3. 证明:若 qi,ki(0,1)q_i, k_i \sim (0, 1),则 Var(qk)=dk\text{Var}(q \cdot k) = d_k (提示:展开并利用独立性)
  4. Dot-Product 相比 Additive Attention 的优势? (提示:计算效率)
  5. LayerNorm 在这里起到了什么作用? (提示:维持分布假设)

下一章预告:第04章将深入拆解Multi-Head Attention,回答"八个头,八个视角,还是八份低秩分解?"