第01章:为什么是Attention?——从RNN的梯度瓶颈到Self-Attention的常数量路径

0 阅读8分钟

第01章:为什么是Attention?——从RNN的梯度瓶颈到Self-Attention的常数量路径

论文链接:Attention Is All You Need (Vaswani et al., NIPS 2017)

核心困惑

为什么Transformer要完全抛弃RNN和CNN,仅仅依靠Attention机制?

这不是一个小改进,而是架构上的彻底革命。2017年之前,序列建模的标配是RNN(或其变体LSTM/GRU)。Transformer直接把它们全扔了,只用Attention。这个决定背后的数学动机是什么?

前置知识补给站

1. 序列建模的本质

序列建模要解决的核心问题:给定输入序列(x1,x2,...,xn),如何建模序列中任意两个位置之间的依赖关系?

例如在机器翻译中:

  • 输入:”The cat sat on the mat”
  • 输出:”猫坐在垫子上”

“猫”这个词的翻译需要依赖”The cat”,而”坐在”需要依赖”sat on”。这些依赖关系可能跨越很长的距离。

2. 梯度反向传播的链式法则

在深度神经网络中,梯度通过链式法则反向传播:

∂L∂x1=∂L∂xn⋅∂xn∂xn−1⋅...⋅∂x2∂x1

如果路径很长(从位置1到位置n),梯度需要连乘很多次。这就是梯度消失/爆炸的根源。

3. 计算复杂度的表示

  • O(n):线性复杂度,处理n个元素需要n步
  • O(n²):平方复杂度,处理n个元素需要n²步
  • O(1):常数复杂度,无论n多大,都只需要固定步数

论文精读:RNN的三大瓶颈

瓶颈1:顺序计算的诅咒

原论文Section 1:

“Recurrent models typically factor computation along the symbol positions of the input and output sequences. Aligning the positions to steps in computation time, they generate a sequence of hidden states , as a function of the previous hidden state and the input for position . This inherently sequential nature precludes parallelization within training examples, which becomes critical at longer sequence lengths, as memory constraints limit batching across examples.”

翻译成人话:RNN必须按顺序计算。要算,必须先算出。这意味着:

  • 位置1的计算完成后,才能开始位置2
  • 位置2的计算完成后,才能开始位置3

数学表达:

这个递归关系导致无法并行化。在GPU上,这是致命的效率问题。

瓶颈2:长距离依赖的梯度衰减

第一性原理推导:BPTT的梯度连乘

假设一个简单的RNN:

反向传播时,从位置到位置的梯度为:

其中:

关键问题:这是一个连乘。如果,梯度爆炸;如果,梯度消失。

数值示例:

  • 假设(略小于1)
  • 经过10步:
  • 经过50步:
  • 经过100步:

梯度几乎完全消失了。这意味着位置1的信息很难传递到位置100。

瓶颈3:最大路径长度是O(n)

原论文Section 4, Table 1:

在RNN中,位置1的信息要传递到位置n,必须经过步:

这个路径长度是O(n)。路径越长,信息衰减越严重。

LSTM的”缓解但未解决”

LSTM通过门控机制缓解了梯度消失问题:

关键改进:细胞状态的更新是加法而非乘法:

梯度反向传播时:

如果遗忘门,梯度可以几乎无损地传播。

注:严格来说,也依赖于(通过),因此完整梯度包含更多项。但这一项是直接的乘法路径,不受激活函数导数的影响,这是LSTM缓解梯度消失的关键。

但LSTM没有解决的问题:

  1. 顺序计算:仍然需要按顺序计算
  2. 路径长度:位置1到位置n的路径长度仍然是O(n)
  3. 门控的局限:遗忘门需要”学习”何时保留信息,这本身就很难

Self-Attention的革命性突破

突破1:并行计算

Self-Attention的核心公式:

关键特性:所有位置的attention可以同时计算。

  • 计算:一次矩阵乘法,所有位置的相似度同时得出
  • 计算softmax:逐行操作,可以并行
  • 计算加权和:一次矩阵乘法,所有位置的输出同时得出

顺序操作数:O(1)(原论文Table 1)

突破2:常数路径长度

在Self-Attention中,任意两个位置之间的信息传递是直接的:

位置1可以直接attend到位置n,不需要经过中间的位置2, 3, …, n-1。

最大路径长度:O(1)(原论文Table 1)

这意味着:

  • 梯度反向传播时,不需要连乘n次
  • 长距离依赖可以直接建模,不会衰减

突破3:动态权重

RNN的权重矩阵是固定的,对所有位置都一样。

Self-Attention的权重是动态计算的:

每对位置的权重都是根据它们的内容(和)动态计算的。

消融实验解读:Table 1的复杂度对比

原论文Section 4, Table 1:

Layer TypeComplexity per LayerSequential OperationsMaximum Path Length
Self-AttentionO(n^2 \cdot d)O(1)O(1)
RecurrentO(n \cdot d^2)O(n)O(n)
ConvolutionalO(k \cdot n \cdot d^2)O(1)O(\log_k(n))
Self-Attention (restricted)O(r \cdot n \cdot d)O(1)O(n/r)

解读:

  1. 复杂度对比:
  • Self-Attention:,当时比RNN的更快

  • 在实践中,序列长度通常小于模型维度(例如)

  1. 顺序操作:
  • Self-Attention:,完全并行

  • RNN:,必须顺序计算

  1. 最大路径长度:
  • Self-Attention:,任意两个位置直接连接

  • RNN:,需要经过所有中间位置

  • Convolutional:,需要堆叠多层

为什么Self-Attention胜出:

  • 在序列长度的情况下(大多数NLP任务),Self-Attention的计算复杂度更低
  • 完全并行化,充分利用GPU
  • 常数路径长度,彻底解决长距离依赖问题

2026年的批判性视角

1. 复杂度的代价

原论文在2017年处理的序列长度通常是512-1024。但在2026年:

  • GPT-4:128K tokens
  • Claude 3:200K tokens
  • Kimi:200K+ tokens

当时,的复杂度变成了瓶颈。这促使了一系列优化方案:

  • Sparse Attention(Longformer, BigBird)
  • Sliding Window Attention(Mistral)
  • Flash Attention(优化内存访问)
  • Attention Residuals(Kimi的方案,见第11章)

2. 并行化的隐藏成本

Self-Attention虽然可以并行计算,但需要存储完整的attention矩阵()。

内存占用:

  • Attention矩阵:
  • 当时:

这在2017年不是问题(时只需要1MB),但在2026年成为了主要瓶颈。

3. 位置编码的外推性

原论文提到sinusoidal位置编码”may allow the model to extrapolate to sequence lengths longer than the ones encountered during training”(Section 3.5)。但后续研究(Press et al., “Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation”, ICLR 2022)发现,sinusoidal编码在长度外推上的表现不如ALiBi等专门设计的方案。这促使了:

  • RoPE(LLaMA, GPT-NeoX)
  • ALiBi(BLOOM)
  • 动态位置编码(各种变体)

4. 原论文没有讨论的问题

  • 推理效率:原论文只关注训练,没有讨论KV Cache等推理优化
  • 长文本能力:原论文在WMT翻译任务上验证,序列长度较短
  • 多模态扩展:原论文只处理文本,但Self-Attention的思想后来被扩展到视觉、语音等领域

面试追问清单

⭐ 基础必会

  1. 为什么RNN会有梯度消失问题?用数学公式证明。
  • 提示:从BPTT的梯度连乘推导

  1. LSTM如何缓解梯度消失?为什么说是”缓解”而非”解决”?
  • 提示:细胞状态的加法更新 vs 顺序计算的本质

  1. Self-Attention的最大路径长度为什么是O(1)?
  • 提示:从计算图的角度理解

⭐⭐ 进阶思考

  1. 在什么情况下RNN的复杂度比Self-Attention的更优?
  • 提示:当时

  1. 为什么原论文说”Self-Attention可以完全替代RNN”,但后续研究又提出了各种混合架构(如Transformer-XL)?
  • 提示:长文本、推理效率、归纳偏置

  1. 如果让你设计一个新的序列建模架构,你会如何平衡并行性、复杂度和长距离依赖建模能力?
  • 提示:这是一个开放性问题,考察架构设计的权衡思维

⭐⭐⭐ 专家领域

  1. 证明:在Self-Attention中,任意两个位置之间的梯度传播路径长度是O(1)。
  • 提示:从反向传播的计算图出发,分析梯度如何从输出传播到输入

  1. 原论文Table 1中的”Self-Attention (restricted)“是什么?为什么它的最大路径长度是?
  • 提示:这是局部attention的变体,每个位置只attend到半径内的位置

  1. 如何从信息论的角度理解Self-Attention相比RNN的优势?
  • 提示:互信息、信息瓶颈理论


下一章预告:第02章将展开Transformer的完整架构图,理解Encoder、Decoder以及三种Attention(Self-Attention、Masked Self-Attention、Cross-Attention)的数据流。

论文原文传送门: