Scaled Dot-Product:那个根号 d_k 是怎么来的'

14 阅读4分钟

序言:一个除号背后藏着的整门数学课

很多人第一次读 Vaswani 2017 的公式时,都会卡在那一个 dk\sqrt{d_k} 上。

公式本身写得简洁:

Attention(Q,K,V)=softmax(QKTdk)V\operatorname{Attention}(Q, K, V) = \operatorname{softmax}\left(\frac{QK^{\mathsf{T}}}{\sqrt{d_k}}\right) V

但那个分母上的 dk\sqrt{d_k} 看起来像是凭空冒出来的常数。

「为什么是 dk\sqrt{d_k} 不是 dkd_k?」

「为什么不是 dmodel\sqrt{d_{\mathrm{model}}}?」

「为什么不是其他什么数?」

如果你只读论文的那一句话——"to counteract the effect of large dot products"——你会觉得这是一个 经验技巧

事实远不止此。

那个 dk\sqrt{d_k} 是一个 从概率论第一性原理推出来的、几乎别无选择的数字

它涉及:随机变量的方差加法、softmax 的饱和性、链式法则下的梯度衰减、训练动力学的稳定性——你能想到的关于「为什么神经网络能优化」的核心问题,全都串在这一个除号上。

本文要做的事情,就是把这个除号拆开——一步一步、不跳逻辑——告诉你:为什么是 dk\sqrt{d_k}?它到底拯救了什么?以及,到 2026 年,对这个 dk\sqrt{d_k} 的现代理解(包括 NTK 视角、FlashAttention 数值稳定性、Muon 优化器对 attention 的影响)有哪些新维度。

读完之后,你应该能在被人问到「为什么除以 dk\sqrt{d_k}」的时候,给出一个 5 分钟版本、一个 30 分钟版本、和一个「我可以为你推一遍」版本。

原文链接


一、问题缘起:先看不除会发生什么

1.1 复盘公式

第 13 篇我们看到:

attention(q,k,v)=softmax(qKT)V\operatorname{attention}(q, k, v) = \operatorname{softmax}(qK^{\mathsf{T}}) V

第 14 篇我们看到:self-attention 让每个 token 同时扮演 q、k、v。

但实际工程里写的、Vaswani 论文里写的、所有 PyTorch 实现里写的,都是:

attention(Q,K,V)=softmax(QKTdk)V\operatorname{attention}(Q, K, V) = \operatorname{softmax}\left(\frac{QK^{\mathsf{T}}}{\sqrt{d_k}}\right) V

那个 dk\sqrt{d_k} 是什么?为什么必须有?

我们做一个思想实验:dk\sqrt{d_k} 拿掉,看会发生什么。

1.2 一个具体的数值实验

dk=64d_k = 64

qqkk 是两个 dkd_k 维向量,每一维独立、均值 00、方差 11(就当它们是从标准正态采样)。

我们想知道:qkq \cdot k 的分布是什么?

qk=iqikiq \cdot k = \sum_i q_i k_i

每一项 qikiq_i k_i 是两个独立标准正态的乘积——均值是 00,方差是 1×1=11 \times 1 = 1。因为

Var(XY)=E[X2]E[Y2]E[X]2E[Y]2=110=1,\operatorname{Var}(XY) = \mathbb{E}[X^2]\mathbb{E}[Y^2] - \mathbb{E}[X]^2 \mathbb{E}[Y]^2 = 1 \cdot 1 - 0 = 1,

对独立零均值变量成立。

6464 个独立项相加:

E[qk]=0,Var(qk)=64,σ=8.\mathbb{E}[q \cdot k] = 0, \qquad \operatorname{Var}(q \cdot k) = 64, \qquad \sigma = 8.

所以 qkq \cdot k 的取值范围大概在 ±24\pm 243σ3\sigma)以内浮动。

1.3 把 q·k = 24 喂进 softmax

假设我们有 88 个 key,对应 88 个点积,碰巧最大那个是 2424,其它都接近 00

softmax([24,0,0,0,0,0,0,0])=?\operatorname{softmax}([24, 0, 0, 0, 0, 0, 0, 0]) = ?
e242.6×1010,e0=1.e^{24} \approx 2.6 \times 10^{10}, \qquad e^0 = 1.

归一化后:[1,0,0,,0][\approx 1, \approx 0, \approx 0, \ldots, \approx 0]——几乎纯 one-hot

这看起来不是好事吗?模型「下定决心」选了一个 token?

恰恰相反——这是一场灾难。

1.4 灾难的来源:梯度消失

来看 softmax 的 Jacobian。

p=softmax(s)p = \operatorname{softmax}(s),那么:

pisj=pi(δijpj)\frac{\partial p_i}{\partial s_j} = p_i (\delta_{ij} - p_j)

其中 δij\delta_{ij} 是 Kronecker delta(i=ji = j 时为 11,否则为 00)。

如果 pp 接近 one-hot,比如 p11p_1 \approx 1,其它 0\approx 0,那么:

p1s1=1(11)=0,p1sj=1(00)=0  (j1),pisj0  (i1).\frac{\partial p_1}{\partial s_1} = 1 \cdot (1 - 1) = 0, \qquad \frac{\partial p_1}{\partial s_j} = 1 \cdot (0 - 0) = 0 \; (j \neq 1), \qquad \frac{\partial p_i}{\partial s_j} \approx 0 \; (i \neq 1).

整个 Jacobian 几乎为零矩阵

这意味着:通过 softmax 反向传播的梯度被掐死了。

下游的 loss 想告诉 attention 层「你应该多看看 token 5」,但这个信号被 softmax 饱和性吃掉了——logits ss 几乎不会被更新。

1.5 这就是 dk\sqrt{d_k} 要解决的问题

如果 logits 的方差是 dkd_k,那把它们除以 dk\sqrt{d_k},方差就变成 11

logits 不再因为维度放大而漂到饱和区,softmax 输出保持在「有梯度的工作点」。

训练就能进行。

这是 dk\sqrt{d_k} 的全部直觉——剩下的都是细节。


二、点积方差的严格推导

2.1 假设

我们在以下假设下推 Var(qk)=dk\operatorname{Var}(q \cdot k) = d_k

  1. qqkkdkd_k 维向量
  2. qq 的每一维 qiq_i 之间独立
  3. kk 的每一维 kjk_j 之间独立
  4. qqkk 之间也独立
  5. 所有 qiq_ikjk_j 都是均值 00、方差 11

后面我们会讨论这些假设在真实模型中成立到什么程度。

2.2 推导

X=qk=iqiki.X = q \cdot k = \sum_i q_i k_i.

第一步:均值。

E[X]=iE[qiki]=iE[qi]E[ki]=0.\mathbb{E}[X] = \sum_i \mathbb{E}[q_i k_i] = \sum_i \mathbb{E}[q_i] \mathbb{E}[k_i] = 0.

第二步:方差。

Var(X)=E[X2]E[X]2=E[X2].\operatorname{Var}(X) = \mathbb{E}[X^2] - \mathbb{E}[X]^2 = \mathbb{E}[X^2].
E[X2]=E[(iqiki)2]=ijE[qikiqjkj].\mathbb{E}[X^2] = \mathbb{E}\left[\left(\sum_i q_i k_i\right)^2\right] = \sum_i \sum_j \mathbb{E}[q_i k_i q_j k_j].

对于 iji \neq j

E[qikiqjkj]=E[qi]E[ki]E[qj]E[kj]=0.\mathbb{E}[q_i k_i q_j k_j] = \mathbb{E}[q_i] \mathbb{E}[k_i] \mathbb{E}[q_j] \mathbb{E}[k_j] = 0.

对于 i=ji = j

E[qi2ki2]=E[qi2]E[ki2]=1×1=1.\mathbb{E}[q_i^2 k_i^2] = \mathbb{E}[q_i^2] \mathbb{E}[k_i^2] = 1 \times 1 = 1.
E[X2]=i1=dk.\mathbb{E}[X^2] = \sum_i 1 = d_k.
Var(X)=dk.\operatorname{Var}(X) = d_k.

2.3 标准差

σ=dk.\sigma = \sqrt{d_k}.

这就是 dk\sqrt{d_k} 的来源——不是任何「经验试出来的常数」,而是 Var\operatorname{Var} 加法的直接结果。

2.4 缩放后的分布

定义

X=Xdk.X' = \frac{X}{\sqrt{d_k}}.
Var(X)=Var(X)dk=1.\operatorname{Var}(X') = \frac{\operatorname{Var}(X)}{d_k} = 1.

不管 dkd_k6464512512、还是 40964096XX' 的方差永远是 11

logits 的尺度被「归一化」到了一个不依赖于维度的水平。

2.5 为什么是 dk\sqrt{d_k} 不是 dkd_k

有人问:「除以 dkd_k 不是更彻底吗?」

不行。

如果除以 dkd_k,那么

Var(Xdk)=dkdk2=1dk0\operatorname{Var}\left(\frac{X}{d_k}\right) = \frac{d_k}{d_k^2} = \frac{1}{d_k} \to 0

dkd_k 很大时就会发生这个退化。

logits 全部接近 0,softmax 输出接近均匀分布——attention 失去了「选择性」。

我们要的是「方差 =1= 1」(既不太尖锐也不太平),所以分母必须是 dk\sqrt{d_k}

这是一个临界点,不是一个「随便挑的数」。

2.6 有人问:为什么不除以 2dk\sqrt{2 d_k} 之类

也可以,只要常数因子合理(比如让方差 =0.5= 0.5 而不是 11)。

但这只会让 softmax 略偏平缓——本质和 dk\sqrt{d_k} 没区别。

Vaswani 选 dk\sqrt{d_k} 是因为它最自然——把方差归一化到 11,保留了「方差为 11」这个统计学最常用的标准化。

后续工作(比如 RoFormer、LLaMA)也都沿用这个选择。

2.7 一个数值表

dkd_kσ=dk\sigma = \sqrt{d_k}logits 范围(3σ3\sigma
82.83±8.5
325.66±17
648±24
12811.3±34
25616±48
51222.6±68

如果不缩放,512512 维的点积可能跑到 ±68\pm 68——softmax 看到这种 logits,对应的 e681029e^{68} \approx 10^{29}——任何对手 logit 都被压成 00

缩放后 logits 永远在 ±3\pm 3 左右——softmax 仍能区分大小,但梯度不会断流。


三、softmax 饱和性的可视化

softmax 尖锐度转存失败,建议直接上传图片文件

3.1 直观图景

右侧子图(unscaled):一个 logit 比其它大很多——softmax 集中在那一个点上。

左侧子图(scaled):logits 接近,softmax 平缓,多个 token 都有可见权重。

不是说「平缓就一定好」——实际训练中,模型最终会学到「需要 sharp 时 sharp」的能力。

但训练初期必须从「平缓 + 有梯度」的状态出发,否则模型一开始就被卡在饱和区出不来。

3.2 一个比喻

把 softmax 想象成一根弹簧。

弹簧未饱和时,你拉它它会动,反馈给你力——你能学到「拉的方向」。

弹簧饱和时(拉到极限),你怎么拉它都不动——你什么也学不到。

logits 越大,softmax 越饱和;scaled dot-product 就是把弹簧从「饱和区」拉回到「线性区」,让训练能进行。

3.3 数学复盘

softmax(s)i=esijesj.\operatorname{softmax}(s)_i = \frac{e^{s_i}}{\sum_j e^{s_j}}.

求导:

softmax(s)isj=pi(δijpj)\frac{\partial \operatorname{softmax}(s)_i}{\partial s_j} = p_i (\delta_{ij} - p_j)

最大值的对角项:pi(1pi)p_i (1 - p_i),当 pi1p_i \to 1 时为 00;当 pi0p_i \to 0 时也为 00;最大在 pi=0.5p_i = 0.5

非对角项:pipj-p_i p_j,仅当两个都不是 00 也不是 11 时才有效。

所以「梯度最大」的工作点是 pi[0.1,0.9]p_i \in [0.1, 0.9]——这正是 logits 适中(σ1\sigma \approx 1)时的状态。

3.4 与温度参数的关系

很多人熟悉「softmax 温度」TT

softmaxT(s)i=esi/Tjesj/T\operatorname{softmax}_T(s)_i = \frac{e^{s_i / T}}{\sum_j e^{s_j / T}}

TT 大,输出平缓;TT 小,输出 sharp。

scaled dot-product 中的 dk\sqrt{d_k} 就扮演温度的角色——具体来说 T=dkT = \sqrt{d_k}

但与 TT 不同,dk\sqrt{d_k} 不是调参,而是定参——它的值由维度决定,不是用户选择。

3.5 一个常见的混淆

「我可以学一个 temperature 替代 dk\sqrt{d_k} 吗?」

理论上可以,但实践中很少这么做。

因为 dk\sqrt{d_k} 已经把 logits 归一化到 σ1\sigma \approx 1,再学一个 temperature 等于多此一举——除非你想做「shaped attention」之类的研究。

LLaMA、GPT、PaLM 等都没有学习 temperature,全用 dk\sqrt{d_k}

但有一些工作(如 NormFormer、QK-norm)提出在 QQKK 上做 LayerNorm,再不除 dk\sqrt{d_k}——效果近似但实现略有不同。

到 2026 年,QK-norm 方案在大模型训练中越来越常见。


四、为什么这件事到 d_k = 64 才显著

4.1 一个有趣的现象

在最早的 attention 工作(Bahdanau 2014)中,用的是加性注意力——score=vTtanh(Wqq+Wkk)\mathrm{score} = v^{\mathsf{T}} \tanh(W_q q + W_k k)——根本没有 dk\sqrt{d_k} 的除法。

为什么 Bahdanau 不需要?

因为 Bahdanau 用的是 RNN 的 hidden state(典型 d=256d = 256 但走 tanh\tanh,输出落在 [1,1][-1, 1])+ 学习的 vv——score 永远在一个有限的 bounded 区间,不会因为维度爆炸。

dot-product attention(Luong 2015)开始有这个问题——因为 qkq \cdot k 没有 tanh\tanh 包住,方差直接随 dkd_k 增长。

但 Luong 的实验里 d 不大,问题不严重。

到 Vaswani 2017 multi-head 时代,dk=64d_k = 64(每 head 的维度),QQKK 的来源是线性投影后的向量——方差大约是 11(因为初始化 + LayerNorm)——这时候 qkq \cdot k 的方差就接近 6464,问题就显现出来了。

4.2 d_k 越大,问题越严重

到 GPT-3:dk=128d_k = 128(每 head),问题更严重。

到 PaLM:dk=256d_k = 256(每 head),不缩放训练直接发散。

Vaswani 的论文里有一段话:「我们怀疑对于大的 d_k 值,dot products 在量级上变大,从而把 softmax 推到具有极小梯度的区域。」

这是一句经验观察——他们看到了「不缩放训练崩」,做了缩放,发现「训练好了」。

后来的理论分析(Xiong 2020 "On Layer Normalization in the Transformer Architecture")才把这件事讲透。

4.3 为什么加性注意力没这个问题

vTtanh(Wqq+Wkk)v^{\mathsf{T}} \tanh(W_q q + W_k k) 中,tanh\tanh 的输出落在 [1,1][-1, 1]

随后 vT()v^{\mathsf{T}}(\cdot) 是一个 dd 维点积,这一步也会有方差放大。

但因为 tanh 已经把每一维 bound 住了,方差不会无界放大。

所以加性注意力天然「自带稳定性」,但代价是计算更慢(多一次矩阵乘 + 非线性)。

dot-product attention 要更快——因为它就是一个 matmul——但代价是必须手工加 dk\sqrt{d_k} 来保稳定。

4.4 一个关于 norm 的细节

Vaswani 假设 qqkk 每一维方差是 11

实际模型里,这通过 LayerNorm 大致成立——LayerNorm 把每一层输出的 mean 归零、std 归一。

但有些层(比如 attention 输出)是 LayerNorm 之前还是之后?这就涉及 Pre-LN vs Post-LN 的选择。

Pre-LN(LayerNorm 在 sublayer 之前)让 qqkk 在进入 attention 时严格 normalized——dk\sqrt{d_k} 的假设最契合。

Post-LN(LayerNorm 在 sublayer 之后)让 qqkk 在进入 attention 时未必 normalized——可能需要 warmup 来稳住训练。

到 2026 年,Pre-LN 是主流(GPT、LLaMA 都用 Pre-LN)。

4.5 那 Q、K 不是单位方差怎么办

如果 WqW_qWkW_k 初始化合理(比如 Xavier 或 Kaiming),且输入 XX 经过 LayerNorm,那么 q=Wqxq = W_q x 的方差大致就是 11

但如果你不做 LayerNorm、用奇怪初始化、或者训练到某一步参数漂移——方差就不是 1 了。

QK-norm(在 q、k 上做 LayerNorm)就是把这个假设显式强制——不再依靠「希望 LayerNorm 保住」。


五、点积方差的可视化

点积分布转存失败,建议直接上传图片文件

5.1 三个直方图

dk=8d_k = 8 时分布窄,σ2.83\sigma \approx 2.83

dk=64d_k = 64 时分布宽,σ=8\sigma = 8

dk=512d_k = 512 时分布很宽,σ22.6\sigma \approx 22.6

直方图的横轴是点积值——纵轴是出现频率。

5.2 为什么这个图重要

看到这张图,你应该马上意识到:

不缩放时,点积尺度完全由维度决定——你换一个模型规模,点积尺度就变了——你的训练超参(学习率、初始化等)就要重调。

缩放后,点积尺度永远是 1——超参可以跨规模迁移

这是 scaling laws 能成立的一个隐性前提:架构内的统计尺度必须不依赖于规模

5.3 Chinchilla scaling 的隐含条件

Hoffmann 2022 给出 Chinchilla 定律:参数 NpN_p 与 token 数 DD 的最优比例 NpD/20N_p \approx D/20

这条定律的成立依赖于「同样的架构、同样的训练超参在不同规模下都能稳定训练」。

如果你不缩放点积,训练在 dk=64d_k = 64 时还稳定,到 dk=512d_k = 512 时就发散——scaling laws 整个就不成立。

dk\sqrt{d_k} 是 scaling laws 的「隐性基础设施」之一。


六、训练曲线对比

训练曲线转存失败,建议直接上传图片文件

6.1 定性差异

红线(unscaled):早期 loss 下降慢,很快卡在某个高位——softmax 饱和导致的优化困难。

绿线(scaled):稳定下降。

初值相同(损失大约是 log(N)\log(N) 那个均匀分布的 cross-entropy)。

6.2 一个真实的定量例子

Vaswani 2017 §3.2.1 没给具体的对比训练曲线(论文较老),但后续工作(Xiong 2020)做过实验。

dk=64d_k = 64 的 Transformer-base 上,去掉 dk\sqrt{d_k}

  • 不调任何其它超参:loss 卡在 6.x,几乎不动。
  • 把学习率降低 10×10\times:训练能进行,但 BLEU 显著低于带 dk\sqrt{d_k} 的版本。

也就是说,没有 dk\sqrt{d_k} 不是「完全不能训练」,而是「需要付出极大的超参代价、且最终质量更差」。

加上 dk\sqrt{d_k} 等价于一个免费的、零计算开销的稳定性优化——为什么不用呢?

6.3 一个反直觉发现

有人发现:如果模型足够小(dk=16d_k = 16 之类),不缩放也能训

这与第二节的方差分析一致——σ=4\sigma = 4,logits 不会饱和。

但你不可能因为「小模型不需要」就在大模型里也省掉它——大模型里这个除号是必需品。


七、缩放与梯度下降稳定性

7.1 学习率与梯度的关系

如果不缩放,attention 的 logits 在 ±dk\pm \sqrt{d_k} 量级——softmax 输出近似 one-hot——梯度近似 00——参数几乎不更新。

但偶尔某个 batch 有「比较平的 logits」,梯度突然爆发——参数大跳——loss 飞涨。

这是「饱和 + 偶尔不饱和」的混合模式——非常不稳定。

7.2 缩放后的 Lipschitz 性质

缩放后,softmax 的输入永远在 [3σ,3σ]=[3,3][-3\sigma, 3\sigma] = [-3, 3] 左右——softmax 在这个区间内是 Lipschitz 连续的,导数有 bounded 上限。

这意味着「同样大小的输入扰动 \to 同样大小的输出扰动」——训练动力学是稳定的。


八、参考资料

  • Vaswani 2017: Ashish Vaswani et al., "Attention Is All You Need" (首次提出 Transformer 与 Scaled Dot-Product).
  • Bahdanau 2014: Dzmitry Bahdanau et al., "Neural Machine Translation by Jointly Learning to Align and Translate" (提出 Additive Attention, tanh\tanh 限制方差).
  • Luong 2015: Minh-Thang Luong et al., "Effective Approaches to Attention-based Neural Machine Translation" (提出 Dot-Product Attention).
  • Xiong 2020: Ruibin Xiong et al., "On Layer Normalization in the Transformer Architecture" (分析缩放机制与 LayerNorm 对训练稳定性的深度影响).
  • Hoffmann 2022: Jordan Hoffmann et al., "Training Compute-Optimal Large Language Models" (Chinchilla 定律,隐含架构缩放的统计不变性假设).

Lipschitz 常数大致是 11(用 \ell_{\infty} 范数估计)。

7.3 与梯度裁剪的关系

很多 Transformer 训练里都有「gradient clipping」(梯度裁剪),把过大的梯度截断到 gc\lVert g \rVert \le c

为什么需要梯度裁剪?因为偶尔会有「outlier batch」让某些参数的梯度爆掉——比如某个 batch 里所有 token 都是同一个。

scaled dot-product 让这种 outlier 的破坏力降低——但不能完全消除——所以梯度裁剪仍是必需。

7.4 与 warmup 的关系

Transformer 训练几乎都用 learning rate warmup(前若干步学习率从 0 线性涨到峰值)。

为什么?因为训练初期参数随机,logits 分布可能极不平衡——warmup 给模型时间「找到稳定区域」再放学习率。

scaled dot-product 让初期 logits 不那么大——warmup 期可以更短——但不能省略。


八、与 NTK / 无限宽神经网络理论的联系

8.1 NTK 是什么

NTK(Neural Tangent Kernel,Jacot 2018)是一种刻画「无限宽网络在小学习率下的训练动力学」的理论。

核心结论是:在某些假设下,无限宽网络的训练等价于一个线性化模型 + 核回归,其中的「核」就叫 NTK。

NTK 给我们一个工具:预测网络在不同初始化、不同尺度下的行为。

8.2 dk\sqrt{d_k} 与 NTK

NTK 理论强调一个 principle:网络中每一层的输入与输出的统计尺度必须一致——否则梯度传播会失衡。

scaled dot-product 正是这个 principle 在 attention 层的体现——把点积归一化到 σ=1\sigma = 1,让 attention 层的「输入尺度」与「输出尺度」一致。

如果不缩放,attention 层把方差从 11 放大到 dk\sqrt{d_k}——下一层 LayerNorm 又把它拉回 11——但中间这一段不稳定。

8.3 muP(Maximal Update Parametrization)

Yang & Hu 2021 的 muP 是 NTK 思想在工程上的实现:通过精心设计每层的 init scale 和 LR scale,让模型在改变宽度时超参不变

muP 框架下,attention 的 dk\sqrt{d_k} 是一个特殊处理——它不是 muP 自动推出的,而是早就独立存在的设计——但它与 muP 的精神高度契合。

到 2026 年,muP(特别是 mup-transfer 思路)成为大模型训练前调超参的重要工具——基础假设之一就是 attention 已经被 dk\sqrt{d_k} 正则化过了。

8.4 NTK 视角的 attention

在 NTK 视角下:

  • attention(Q,K,V)=softmax(QKTdk)V\operatorname{attention}(Q, K, V) = \operatorname{softmax}\left(\frac{QK^{\mathsf{T}}}{\sqrt{d_k}}\right) V 是一个 bilinear 算子。
  • bilinear 部分 QKTQK^{\mathsf{T}} 把维度从 dkd_k 映射到 N×NN \times N 的相似度矩阵。
  • softmax 是一个非线性归一化。
  • VV 的乘法把 N×NN \times N 映射回 N×dvN \times d_v

每一步的统计尺度都需要被控制——dk\sqrt{d_k} 是控制 QKTQK^{\mathsf{T}} 这一步的尺度的工具。

V 那边没有显式缩放,因为 softmax 输出已经是概率(行和为 1)——V 的均值和方差只取决于 V 自己的统计——这一步通常不需要额外正则化。


九、Vaswani 论文里的原话

9.1 § 3.2.1 的关键段落

原文(NeurIPS 2017):

"We suspect that for large values of d_k, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients. To counteract this effect, we scale the dot products by 1/√d_k."

翻译:「我们怀疑对于大的 dkd_k,点积量级变大,从而把 softmax 推入梯度极小的区域。为抵消这个效应,我们用 1/dk1 / \sqrt{d_k} 缩放点积。」

9.2 这段话其实没给完整证明

注意「我们怀疑」(we suspect)——Vaswani 没有给出 Var(qk)=dk\operatorname{Var}(q \cdot k) = d_k 的形式化推导,也没有大规模消融实验来证明。

后来的工作(Xiong 2020, On Layer Normalization in the Transformer Architecture)才把这件事详细分析。

但工程上,Vaswani 的「直觉 + 简单理论」已经够用——大家用了 dk\sqrt{d_k},模型能训,事情就成立了。

这是科学研究里很常见的模式:实践先于理论——直觉推动实验,实验验证后再被理论补全。

9.3 注释:dot product vs scaled dot product 的对比实验

Vaswani 论文里的 Table 3 提到:

"While for small values of d_k the two mechanisms perform similarly, additive attention outperforms dot product attention without scaling for larger values of d_k. We suspect that..."

也就是说,Vaswani 团队做过对照实验——在 d_k 大时不缩放的 dot product 比加性 attention 差——所以加了缩放。

这是 dk\sqrt{d_k} 设计的直接动机。


十、对应的 PyTorch 实现

10.1 最朴素版

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / (d_k ** 0.5)  # ← 这里就是 √d_k
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attn = F.softmax(scores, dim=-1)
    return attn @ V, attn

注意 d_k ** 0.5 就是 dk\sqrt{d_k}

Q.size(-1) 自动取最后一维——所以这段代码不需要传 d_k 参数。

10.2 数值稳定版(log-sum-exp trick)

朴素 softmax 在大 logits 时可能溢出(exp(700) = inf)。

实践中 PyTorch 的 F.softmax 已经内置了 log-sum-exp trick——把所有 logits 减去 max 再 exp,结果不变但数值稳定。

这一点对 scaled dot-product 也有意义——因为缩放后 logits 仍可能在 ±10\pm 10 量级(在某些 head 学到 sharp pattern 时),log-sum-exp 仍是必要的。

10.3 PyTorch 2.0+ 的内置实现

out = F.scaled_dot_product_attention(Q, K, V, attn_mask=mask)

这一行调用底层 CUDA / Metal 实现——可能是 FlashAttention,也可能是 Memory-Efficient Attention,由 backend 自动选择。

但「除以 dk\sqrt{d_k}」这件事仍然在背后发生——只是你不用手写。

10.4 FlashAttention 中的 dk\sqrt{d_k}

FlashAttention 的核心是「tile-by-tile 计算 softmax」——在 SRAM 内做 streaming softmax。

dk\sqrt{d_k} 缩放发生在每个 tile 计算 QKTQK^{\mathsf{T}} 的瞬间——和朴素实现没有本质区别。

工程上的难点是「数值稳定的 streaming softmax」(要保持 running max 和 running sum_exp)——但 dk\sqrt{d_k} 这一步是简单乘法,不影响 FlashAttention 的核心算法。

10.5 一个常见 bug:缩放放在哪里

有些实现会写:

Q = Q / (d_k ** 0.5)  # 提前缩放 Q
scores = Q @ K.transpose(-2, -1)

这等价于在 QKTQK^{\mathsf{T}} 上除以 dk\sqrt{d_k}——结果一样,但计算更高效(少一次矩阵元素级除法)。

但要注意:如果 K 有特殊处理(比如 RoPE),缩放放在哪里可能影响 RoPE 的正确性——一般推荐放在 score 上,最稳。


十一、几个常见的变体与争议

11.1 dk\sqrt{d_k} vs dmodel\sqrt{d_{\mathrm{model}}}

有人混淆:

dmodeld_{\mathrm{model}} 是 token 嵌入维度(比如 512512)。

dkd_k 是每个 head 的 Q/KQ/K 维度(比如 6464,如果有 88 个 head)。

scaled dot-product 用的是 dk\sqrt{d_k}不是 dmodel\sqrt{d_{\mathrm{model}}}

为什么?因为方差推导中,加和的项数是 dkd_k——每个 head 的点积只涉及 dkd_k 维。

如果你写错成 dmodel\sqrt{d_{\mathrm{model}}}(除得太多),attention 会过于平缓——softmax 输出近似均匀——模型失去选择性。

11.2 1/dk1/d_k vs 1/dk1/\sqrt{d_k}

如果你看到某些代码或论文写「÷dk\div d_k」(不是 dk\sqrt{d_k}),那是错的——除非他们定义 dkd_k原 dk\sqrt{\text{原 } d_k}

把方差推导记牢:σ=dk\sigma = \sqrt{d_k},所以分母是 dk\sqrt{d_k} 而不是 dkd_k

11.3 学习的温度参数 vs 固定的 dk\sqrt{d_k}

有些工作(Shaped Attention、Stable Attention)把 dk\sqrt{d_k} 替换成可学习的 τ\tau,让模型自适应温度。

这通常需要额外的稳定化(比如把 τ\tau clamp 到 [dk/2,2dk][\sqrt{d_k}/2, 2\sqrt{d_k}]),否则 τ\tau 容易学到 00\infty

主流大模型仍然用固定 dk\sqrt{d_k}——因为它已经够好了,省掉一个超参。

11.4 logit-cap:另一个稳定性技巧

Gemini 和某些 Anthropic 模型用「logit cap」技巧:

scores=ctanh(QKTdkc)\mathrm{scores} = c \cdot \tanh\left(\frac{QK^{\mathsf{T}}}{\sqrt{d_k} \, c}\right)

其中 cc 是某个 cap 值(比如 5050)。

这把 logits 强行 clip 到 [c,c][-c, c] 区间,防止极端 outlier。

这是 dk\sqrt{d_k} 之后的进一步加强——不替代它,而是补充它。

11.5 query 缩放还是 score 缩放

一些工程实现把缩放放在 QQ 上:

Q:=Qdk0.25,K:=Kdk0.25Q := \frac{Q}{d_k^{0.25}}, \qquad K := \frac{K}{d_k^{0.25}}

两个 0.250.25 次方相乘恰好得 dk\sqrt{d_k}

这种「分散到 Q 和 K」的写法在某些硬件上更高效,但数学上完全等价。

PyTorch 默认实现把缩放放在 score 上。


十二、与训练分布假设的关系

12.1 「q, k 是单位方差」这个假设有多严格

我们推导 Var(qk)=dk\operatorname{Var}(q \cdot k) = d_k 的关键假设是:每一维独立、零均值、单位方差。

实际中:

  • 独立:不严格,但近似成立(参数学到的 Q、K 投影把不同维度去相关到一定程度)。
  • 零均值:通过 LayerNorm 严格成立。
  • 单位方差:通过 LayerNorm 严格成立。

整体上,「点积方差 dk\approx d_k」是一个近似——但近似得相当好。

12.2 训练后期的偏移

训练后期,QQKK 的分布可能偏离单位方差(特别是某些 head 学到 sharp pattern 时)——logits 的实际方差可能比 dkd_k 小很多(因为 QQKK 学到了对齐方向,使 qkq \cdot k 偏正)。

这时 dk\sqrt{d_k} 给出的「过度缩放」让 softmax 仍然平缓——模型需要学习一个更大的 WqW_qWkW_k 来「弥补」缩放。

QK-norm 的提案就是为了解决这个:让 qqkk 在训练全程都保持单位方差,dk\sqrt{d_k} 缩放始终精确。

12.3 一个反直觉发现

Su 2024 等人发现:训练初期 logits 近似高斯,但训练到收敛时,logits 分布严重偏离高斯——出现一些「极端 outlier」(某些位置 logit 突变到 ±50\pm 50 以上)。

这种 outlier 对训练稳定性是灾难——logit-cap 就是为了 cap 这些 outlier。

dk\sqrt{d_k} 的「单位方差」假设在训练稳定期成立,但在收敛附近可能开始失效——这是一个开放研究方向。

12.4 高斯假设到底有多重要

有人会问:如果 Q、K 的分布不是高斯(而是 t 分布、混合高斯、甚至离散),方差推导还成立吗?

成立。

Var(X+Y)=Var(X)+Var(Y)\operatorname{Var}(X+Y) = \operatorname{Var}(X) + \operatorname{Var}(Y) 对任何独立随机变量都成立——和分布无关。

我们推 Var(qk)=dk\operatorname{Var}(q \cdot k) = d_k 时也只用了「独立 + 零均值 + 单位方差」——没有用到高斯假设。

预测(softmax 输出形状)要用 Central Limit Theorem——qkq \cdot kdkd_k 个项之和,dkd_k 大时近似高斯。

dk=64d_k = 64 已经足够 CLT 生效——分布看起来很高斯。

12.5 重尾分布(heavy tail)的影响

如果 QQKK 不是高斯而是 heavy tail(比如 tt 分布、Cauchy),那方差推导可能不成立——或者方差为 \infty

实际上深度网络的中间表示确实会有 heavy tail(参考 Martin & Mahoney 2018 关于深度网络中间表示的 heavy tail spectrum)。

但 LayerNorm + 标准初始化让 Q、K 的尾部不至于失控——这是工程上的「经验救场」。

到 2026 年,理解 QQKK 的真实分布、以及 dk\sqrt{d_k} 在 heavy tail 下的有效性,仍是开放问题。


十三、当 dk\sqrt{d_k} 不够用时

13.1 long-context 的 logits 暴涨

当上下文长度 N 很大时(比如 100k),同一个 query 要 attend 100k 个 key——每个 key 都贡献 logits 候选。

即使每个 logits 都是 σ=1\sigma = 1(缩放后),最大值 maxisi\max_i s_iNN 大时按 2lnN\sqrt{2 \ln N} 增长(极值统计)——logits 仍然漂移到大值。

softmax(s)\operatorname{softmax}(s) 中真正起作用的是 smaxss_{\max} - s——这个差仍然是 σ1\sigma \approx 1 的量级——所以 attention 仍然分布合理。

13.2 attention sink

Xiao 2023 (StreamingLLM) 发现:在 long-context 中,第一个 token(BOS)会「吸走」大量 attention 权重——这是 softmax 归一化的副产物。

具体机制:当所有 logits 都接近 0 时(没有特别匹配的 key),softmax 趋向均匀——但每个 token 都倾向把「无信息」的 attention 转移到「最早出现的 token」(BOS)。

dk\sqrt{d_k} 与 attention sink 没有直接关系——但在 long-context 中,dk\sqrt{d_k} 提供了基础稳定性,attention sink 现象在此基础上才能被研究。

13.3 ALiBi 与 dk\sqrt{d_k} 的相互作用

ALiBi(Press 2021)在 logits 上加一个负的距离偏置:

sij=qikjdkmijs_{ij} = \frac{q_i \cdot k_j}{\sqrt{d_k}} - m \cdot |i - j|

mm 是一个固定的负向斜率。

ALiBi 与 dk\sqrt{d_k}叠加关系——dk\sqrt{d_k} 控制方差,ALiBi 控制位置偏置——两者各司其职。

13.4 RoPE 与 dk\sqrt{d_k} 的相互作用

RoPE(Su 2021)在 Q、K 上做旋转编码:

q=R(θ)q,k=R(θ)kq' = R(\theta) q, \qquad k' = R(\theta) k

其中 R(θ)R(\theta) 是旋转矩阵——保持向量长度不变。

因此 qkq' \cdot k' 的方差与 qkq \cdot k 一致——dk\sqrt{d_k} 缩放仍然有效。

RoPE 是「不影响 dk\sqrt{d_k} 假设」的位置编码——这是它能广泛应用的一个隐性原因。


十四、Muon 优化器与 dk\sqrt{d_k} 的现代视角

14.1 Muon 是什么

Muon(2024)是一个新型优化器,专为 Transformer 设计——它对 attention 的 WqW_qWkW_k 矩阵做特殊正交化。

核心思想:WqW_qWkW_k 在训练中容易变得「不正交」——这让 qkq \cdot k 的统计性质偏离原始假设——Muon 强制周期性正交化。

14.2 Muon 与 dk\sqrt{d_k} 的关系

Muon 维持 WqW_qWkW_k 的正交性 \to 维持 qqkk 的单位方差 \to 维持 dk\sqrt{d_k} 的精确性。

也就是说,Muon 让「dk\sqrt{d_k} 假设」在训练全程都接近精确——这反过来让 attention 训练更稳定。

到 2026 年,Muon 在某些大模型预训练中开始被采用(比如 Kimi 的 K2 模型)——这印证了「保护 dk\sqrt{d_k} 假设」的工程价值。

14.3 一种联合视角

dk\sqrt{d_k}、QK-norm、Muon、logit-cap 放到一起,你会发现一条主线:

保护 attention logits 的统计性质,让 softmax 始终在「有效梯度区」工作。

每一项技术都是这条主线的一个工具——dk\sqrt{d_k} 是最基础的、最便宜的、必须有的——其它都是渐进改进。


十五、一个完整的数值小例子

15.1 设置

dk=4d_k = 4qqkk 都是 44 维。

q=[1,0.5,0.5,1]q = [1, 0.5, -0.5, 1]
k1=[0.5,1,0,1]k_1 = [0.5, 1, 0, -1]
k2=[1,0.5,0.5,0]k_2 = [1, 0.5, 0.5, 0]
k3=[1,0,1,0.5]k_3 = [-1, 0, 1, 0.5]

15.2 不缩放的 logits

s1=qk1=0.5+0.5+01=0s_1 = q \cdot k_1 = 0.5 + 0.5 + 0 - 1 = 0
s2=qk2=1+0.250.25+0=1s_2 = q \cdot k_2 = 1 + 0.25 - 0.25 + 0 = 1
s3=qk3=1+00.5+0.5=1s_3 = q \cdot k_3 = -1 + 0 - 0.5 + 0.5 = -1
logits=[0,1,1]\mathrm{logits} = [0, 1, -1]
softmax([0,1,1])[0.244,0.665,0.090]\operatorname{softmax}([0, 1, -1]) \approx [0.244, 0.665, 0.090]

15.3 缩放的 logits

dk=2\sqrt{d_k} = 2
s1=0/2=0s'_1 = 0 / 2 = 0
s2=1/2=0.5s'_2 = 1 / 2 = 0.5
s3=1/2=0.5s'_3 = -1 / 2 = -0.5
logits=[0,0.5,0.5]\mathrm{logits}' = [0, 0.5, -0.5]
softmax([0,0.5,0.5])[0.295,0.487,0.218]\operatorname{softmax}([0, 0.5, -0.5]) \approx [0.295, 0.487, 0.218]

15.4 比较

不缩放:[0.244,0.665,0.090][0.244, 0.665, 0.090]——中间项更突出。

缩放:[0.295,0.487,0.218][0.295, 0.487, 0.218]——分布更平。

dk=4d_k = 4 这种小尺度下,差别不大——缩放只让 softmax 略缓和。

但当 dk=64d_k = 64 时,原始 logits 范围会扩大 44 倍(64/4=4\sqrt{64}/\sqrt{4} = 4),不缩放 logits 是 [0,4,4][0, 4, -4]softmax([0,4,4])[0.018,0.964,0.000]\operatorname{softmax}([0, 4, -4]) \approx [0.018, 0.964, 0.000]——几乎 one-hot!

而缩放后仍然是 [0,0.5,0.5][0, 0.5, -0.5]——分布合理。

这就是「维度越大,越需要 dk\sqrt{d_k}」的直观体现。

15.5 backward 梯度差异

设 loss 对 attention 输出有梯度信号 L/out\partial L / \partial \mathrm{out}

对未缩放 softmax:

Ls=JsoftmaxLout\frac{\partial L}{\partial s} = J_{\mathrm{softmax}} \cdot \frac{\partial L}{\partial \mathrm{out}}

JsoftmaxJ_{\mathrm{softmax}} 在「one-hot」状态下接近零——L/s0\partial L / \partial s \approx 0——s=QKTs = QK^{\mathsf{T}} 的梯度也接近零——QQKK 的更新极缓慢。

对缩放 softmax:

Ls=JsoftmaxLout\frac{\partial L}{\partial s'} = J_{\mathrm{softmax}} \cdot \frac{\partial L}{\partial \mathrm{out}}

这里 JsoftmaxJ_{\mathrm{softmax}} 不饱和,传播正常。

Ls=Ls1dk\frac{\partial L}{\partial s} = \frac{\partial L}{\partial s'} \cdot \frac{1}{\sqrt{d_k}}

注意还多了一个 1/dk1/\sqrt{d_k}——但这只让梯度等比缩小,不让它消失。

整体梯度量级仍然合理,训练能进行。


十六、关键概念回顾

  1. 点积方差qkq \cdot k 的方差等于 dkd_k(在标准假设下)。

  2. softmax 饱和:当 logits 量级远大于 1 时,softmax 输出近似 one-hot——梯度近似为零。

  3. dk\sqrt{d_k} 缩放:把 logits 方差归一化到 11,避免 softmax 饱和。

  4. 临界点:分母必须是 dk\sqrt{d_k}——除以 dkd_k 太多,除以更小则不够。

  5. scaling laws 隐性基础dk\sqrt{d_k} 让 attention 在不同维度下有相同的「统计工作点」——这是 Chinchilla scaling 能成立的前提之一。

  6. NTK 视角dk\sqrt{d_k} 是「保持每一层输入输出尺度一致」的具体实现。

  7. 现代变体:QK-norm、logit-cap、Muon 都是「保护 dk\sqrt{d_k} 假设」的延伸。

  8. 训练稳定性dk\sqrt{d_k} 不能完全替代 LayerNorm、warmup、gradient clipping——但它是这些手段的基础。

  9. Vaswani 原文:只是一句「we suspect」+ 简单实验——后续工作才补全理论。

  10. PyTorch 实现F.scaled_dot_product_attention 已内置 dk\sqrt{d_k}——但理解原理仍然重要。


十七、常见误解

17.1 dk\sqrt{d_k} 是经验技巧

错。这是一个有严格概率论推导的设计——不是「随便选一个数」。

17.4 dk\sqrt{d_k} 是 attention 唯一的稳定化机制

错。LayerNorm、warmup、gradient clipping、初始化都是稳定性的一部分——dk\sqrt{d_k} 只是其中一项。

17.5 缩放 = 退化

错。缩放后 attention 仍然能学到 sharp pattern——只是初始时不卡饱和——训练完成后该 sharp 还是 sharp。

17.12 缩放只为加速训练

不仅如此。

缩放更核心的目的是「让训练能进行」——而不是「让训练加快」。

不缩放时,训练在大模型上根本无法成功——加上缩放后训练才能稳定走完。

这是「质」的区别,不是「量」的区别。

17.14 dk\sqrt{d_k} 缩放破坏了 attention 的「概率含义」

不严格。

attention 输出仍然是「key 上的概率分布」(softmax 输出每行和为 11)——dk\sqrt{d_k} 只影响这个分布有多 sharp,不影响它是不是概率。

事实上,dk\sqrt{d_k} 让初始的概率分布「更接近均匀」——这反而是更好的概率初始化。

十八、下一步

到这里,我们已经把 attention 机制原理这一块的核心讲完了。

下一篇(第 16 篇)开始进入【Part 3:Transformer 架构】——讨论完整的 Transformer 块如何把 attention、FFN、residual、LayerNorm 串起来。

我们会从「2017 原始 Transformer」讲起,逐步看到「现代 LLaMA-style Transformer」演化的每一个改动是为什么——Pre-LN vs Post-LN、SwiGLU vs ReLU、RMSNorm vs LayerNorm 等。

如果你已经掌握了:

  • 第 11 篇的「attention 是什么」直觉
  • 第 12 篇的 Bahdanau 加性注意力
  • 第 13 篇的 Q/K/V 三件套
  • 第 14 篇的 self-attention 概念
  • 第 15 篇的 dk\sqrt{d_k} 缩放原理

那你已经有了进入 Transformer 架构层的所有理论基础——下一篇就把这些拼起来。


十九、参考文献

下面按相关度排序列出本篇直接引用与延伸阅读,每条附一句话提示其在本篇中的角色。

阅读建议:1、2、3、12 是核心,其余是延伸。

  1. Vaswani, A. et al. "Attention Is All You Need." NeurIPS 2017. §3.2.1 给出 √d_k 的最早动机。
  2. Xiong, R. et al. "On Layer Normalization in the Transformer Architecture." ICML 2020. 形式化分析 √d_k 与 Pre-LN 的关系。
  3. Luong, M.-T. et al. "Effective Approaches to Attention-based Neural Machine Translation." EMNLP 2015. dot-product attention 的经典工作,没有 √d_k——展示了不缩放的问题。
  4. Bahdanau, D. et al. "Neural Machine Translation by Jointly Learning to Align and Translate." ICLR 2015. 加性 attention 没有 √d_k 问题,因为 tanh bound。
  5. Jacot, A., Gabriel, F., Hongler, C. "Neural Tangent Kernel: Convergence and Generalization in Neural Networks." NeurIPS 2018. NTK 理论的源头。
  6. Yang, G., Hu, E. J. "Tensor Programs IV: Feature Learning in Infinite-Width Neural Networks." ICML 2021. muP 的理论基础。
  7. Yang, G. et al. "Tensor Programs V: Tuning Large Neural Networks via Zero-Shot Hyperparameter Transfer." NeurIPS 2021. muP 的实操版,与 √d_k 互补。
  8. Hoffmann, J. et al. "Training Compute-Optimal Large Language Models." NeurIPS 2022. Chinchilla scaling laws,隐性依赖架构稳定性。
  9. Su, J. et al. "RoFormer: Enhanced Transformer with Rotary Position Embedding." Neurocomputing 2024. RoPE 不破坏 √d_k 假设。
  10. Press, O., Smith, N. A., Lewis, M. "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation." ICLR 2022. ALiBi,与 √d_k 叠加使用。
  11. Xiao, G. et al. "Efficient Streaming Language Models with Attention Sinks." ICLR 2024. attention sink 现象。
  12. Dao, T. et al. "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." NeurIPS 2022. 工程实现里 √d_k 的位置。
  13. Henry, A. et al. "Query-Key Normalization for Transformers." EMNLP Findings 2020. QK-norm 提案。
  14. Shazeer, N. "Fast Transformer Decoding: One Write-Head is All You Need." arXiv 2019. MQA。
  15. Ainslie, J. et al. "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints." EMNLP 2023. GQA。
  16. Martin, C. H., Mahoney, M. W. "Implicit Self-Regularization in Deep Neural Networks." JMLR 2021. 深度网络中间表示的 heavy tail 现象。
  17. Jordan, K. et al. "Muon: An Optimizer for the Hidden Layers of Neural Networks." 2024 blog/preprint. Muon 的提出。

← 上一篇:14|Self-Attention | 下一篇:16|Multi-Head Attention