第01章:为什么是Attention?——从RNN的梯度瓶颈到Self-Attention的常数量路径
论文链接:Attention Is All You Need (Vaswani et al., NIPS 2017)
核心困惑
为什么Transformer要完全抛弃RNN和CNN,仅仅依靠Attention机制?
这不是一个小改进,而是架构上的彻底革命。2017年之前,序列建模的标配是RNN(或其变体LSTM/GRU)。Transformer直接把它们全扔了,只用Attention。这个决定背后的数学动机是什么?
前置知识补给站
1. 序列建模的本质
序列建模要解决的核心问题:给定输入序列(x1,x2,...,xn),如何建模序列中任意两个位置之间的依赖关系?
例如在机器翻译中:
- 输入:”The cat sat on the mat”
- 输出:”猫坐在垫子上”
“猫”这个词的翻译需要依赖”The cat”,而”坐在”需要依赖”sat on”。这些依赖关系可能跨越很长的距离。
2. 梯度反向传播的链式法则
在深度神经网络中,梯度通过链式法则反向传播:
∂L∂x1=∂L∂xn⋅∂xn∂xn−1⋅...⋅∂x2∂x1
如果路径很长(从位置1到位置n),梯度需要连乘很多次。这就是梯度消失/爆炸的根源。
3. 计算复杂度的表示
- O(n):线性复杂度,处理n个元素需要n步
- O(n²):平方复杂度,处理n个元素需要n²步
- O(1):常数复杂度,无论n多大,都只需要固定步数
论文精读:RNN的三大瓶颈
瓶颈1:顺序计算的诅咒
原论文Section 1:
“Recurrent models typically factor computation along the symbol positions of the input and output sequences. Aligning the positions to steps in computation time, they generate a sequence of hidden states , as a function of the previous hidden state and the input for position . This inherently sequential nature precludes parallelization within training examples, which becomes critical at longer sequence lengths, as memory constraints limit batching across examples.”
翻译成人话:RNN必须按顺序计算。要算,必须先算出。这意味着:
- 位置1的计算完成后,才能开始位置2
- 位置2的计算完成后,才能开始位置3
- …
数学表达:
这个递归关系导致无法并行化。在GPU上,这是致命的效率问题。
瓶颈2:长距离依赖的梯度衰减
第一性原理推导:BPTT的梯度连乘
假设一个简单的RNN:
反向传播时,从位置到位置的梯度为:
其中:
关键问题:这是一个连乘。如果,梯度爆炸;如果,梯度消失。
数值示例:
- 假设(略小于1)
- 经过10步:
- 经过50步:
- 经过100步:
梯度几乎完全消失了。这意味着位置1的信息很难传递到位置100。
瓶颈3:最大路径长度是O(n)
原论文Section 4, Table 1:
在RNN中,位置1的信息要传递到位置n,必须经过步:
这个路径长度是O(n)。路径越长,信息衰减越严重。
LSTM的”缓解但未解决”
LSTM通过门控机制缓解了梯度消失问题:
关键改进:细胞状态的更新是加法而非乘法:
梯度反向传播时:
如果遗忘门,梯度可以几乎无损地传播。
注:严格来说,也依赖于(通过),因此完整梯度包含更多项。但这一项是直接的乘法路径,不受激活函数导数的影响,这是LSTM缓解梯度消失的关键。
但LSTM没有解决的问题:
- 顺序计算:仍然需要按顺序计算
- 路径长度:位置1到位置n的路径长度仍然是O(n)
- 门控的局限:遗忘门需要”学习”何时保留信息,这本身就很难
Self-Attention的革命性突破
突破1:并行计算
Self-Attention的核心公式:
关键特性:所有位置的attention可以同时计算。
- 计算:一次矩阵乘法,所有位置的相似度同时得出
- 计算softmax:逐行操作,可以并行
- 计算加权和:一次矩阵乘法,所有位置的输出同时得出
顺序操作数:O(1)(原论文Table 1)
突破2:常数路径长度
在Self-Attention中,任意两个位置之间的信息传递是直接的:
位置1可以直接attend到位置n,不需要经过中间的位置2, 3, …, n-1。
最大路径长度:O(1)(原论文Table 1)
这意味着:
- 梯度反向传播时,不需要连乘n次
- 长距离依赖可以直接建模,不会衰减
突破3:动态权重
RNN的权重矩阵是固定的,对所有位置都一样。
Self-Attention的权重是动态计算的:
每对位置的权重都是根据它们的内容(和)动态计算的。
消融实验解读:Table 1的复杂度对比
原论文Section 4, Table 1:
| Layer Type | Complexity per Layer | Sequential Operations | Maximum Path Length |
|---|---|---|---|
| Self-Attention | O(n^2 \cdot d) | O(1) | O(1) |
| Recurrent | O(n \cdot d^2) | O(n) | O(n) |
| Convolutional | O(k \cdot n \cdot d^2) | O(1) | O(\log_k(n)) |
| Self-Attention (restricted) | O(r \cdot n \cdot d) | O(1) | O(n/r) |
解读:
- 复杂度对比:
-
Self-Attention:,当时比RNN的更快
-
在实践中,序列长度通常小于模型维度(例如)
- 顺序操作:
-
Self-Attention:,完全并行
-
RNN:,必须顺序计算
- 最大路径长度:
-
Self-Attention:,任意两个位置直接连接
-
RNN:,需要经过所有中间位置
-
Convolutional:,需要堆叠多层
为什么Self-Attention胜出:
- 在序列长度的情况下(大多数NLP任务),Self-Attention的计算复杂度更低
- 完全并行化,充分利用GPU
- 常数路径长度,彻底解决长距离依赖问题
2026年的批判性视角
1. 复杂度的代价
原论文在2017年处理的序列长度通常是512-1024。但在2026年:
- GPT-4:128K tokens
- Claude 3:200K tokens
- Kimi:200K+ tokens
当时,的复杂度变成了瓶颈。这促使了一系列优化方案:
- Sparse Attention(Longformer, BigBird)
- Sliding Window Attention(Mistral)
- Flash Attention(优化内存访问)
- Attention Residuals(Kimi的方案,见第11章)
2. 并行化的隐藏成本
Self-Attention虽然可以并行计算,但需要存储完整的attention矩阵()。
内存占用:
- Attention矩阵:
- 当时:
这在2017年不是问题(时只需要1MB),但在2026年成为了主要瓶颈。
3. 位置编码的外推性
原论文提到sinusoidal位置编码”may allow the model to extrapolate to sequence lengths longer than the ones encountered during training”(Section 3.5)。但后续研究(Press et al., “Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation”, ICLR 2022)发现,sinusoidal编码在长度外推上的表现不如ALiBi等专门设计的方案。这促使了:
- RoPE(LLaMA, GPT-NeoX)
- ALiBi(BLOOM)
- 动态位置编码(各种变体)
4. 原论文没有讨论的问题
- 推理效率:原论文只关注训练,没有讨论KV Cache等推理优化
- 长文本能力:原论文在WMT翻译任务上验证,序列长度较短
- 多模态扩展:原论文只处理文本,但Self-Attention的思想后来被扩展到视觉、语音等领域
面试追问清单
⭐ 基础必会
- 为什么RNN会有梯度消失问题?用数学公式证明。
-
提示:从BPTT的梯度连乘推导
- LSTM如何缓解梯度消失?为什么说是”缓解”而非”解决”?
-
提示:细胞状态的加法更新 vs 顺序计算的本质
- Self-Attention的最大路径长度为什么是O(1)?
-
提示:从计算图的角度理解
⭐⭐ 进阶思考
- 在什么情况下RNN的复杂度比Self-Attention的更优?
-
提示:当时
- 为什么原论文说”Self-Attention可以完全替代RNN”,但后续研究又提出了各种混合架构(如Transformer-XL)?
-
提示:长文本、推理效率、归纳偏置
- 如果让你设计一个新的序列建模架构,你会如何平衡并行性、复杂度和长距离依赖建模能力?
-
提示:这是一个开放性问题,考察架构设计的权衡思维
⭐⭐⭐ 专家领域
- 证明:在Self-Attention中,任意两个位置之间的梯度传播路径长度是O(1)。
-
提示:从反向传播的计算图出发,分析梯度如何从输出传播到输入
- 原论文Table 1中的”Self-Attention (restricted)“是什么?为什么它的最大路径长度是?
-
提示:这是局部attention的变体,每个位置只attend到半径内的位置
- 如何从信息论的角度理解Self-Attention相比RNN的优势?
-
提示:互信息、信息瓶颈理论
下一章预告:第02章将展开Transformer的完整架构图,理解Encoder、Decoder以及三种Attention(Self-Attention、Masked Self-Attention、Cross-Attention)的数据流。
论文原文传送门:
- Transformer原论文:arxiv.org/abs/1706.03…
- 官方代码:github.com/tensorflow/…