Transformer 架构深度解析:从原理到公式推导
本文将带你从零开始,彻底理解 Transformer 的工作原理和核心公式。无论你是初学者还是希望深入理解细节的开发者,都能从中获益。
一、开篇:Transformer 为何如此重要?
2017 年,Google 团队发表了论文《Attention Is All You Need》,提出了 Transformer 架构。这篇论文彻底改变了自然语言处理(NLP)领域,并逐渐影响到计算机视觉、语音等几乎所有 AI 领域。
Transformer 的"战绩":
- GPT 系列:ChatGPT 的底层架构
- BERT:开创预训练语言模型时代
- ViT:将 Transformer 应用到图像识别
- Stable Diffusion:AI 绘画的核心组件
可以说,没有 Transformer,就没有今天的 AI 大模型时代。
二、Transformer 解决了什么问题?
2.1 传统序列模型的困境
在 Transformer 之前,处理序列数据(如文本、语音)主要依赖 RNN(循环神经网络)和 LSTM。
RNN 的工作方式:
"我 爱 中 国"
我 → [RNN] → h₁
↓
爱 → [RNN] → h₂
↓
中 → [RNN] → h₃
↓
国 → [RNN] → h₄
问题 1:无法并行计算
RNN 必须按顺序处理每个词,后一个词必须等前一个词处理完。这导致:
- 训练速度慢
- 无法充分利用 GPU 的并行能力
问题 2:长距离依赖困难
当句子很长时,早期词的信息在传递过程中会逐渐"遗忘"。比如:
"那个我在北京上大学时认识的朋友,现在在那个城市工作。"
要理解"那个城市"指的是"北京",信息需要跨越很多词传递,RNN 很难做到。
2.2 Transformer 的解决方案
| 问题 | RNN 的困境 | Transformer 的解决 |
|---|---|---|
| 并行计算 | 必须顺序处理 | 所有位置同时计算 |
| 长距离依赖 | 信息逐渐衰减 | 注意力机制直接关联任意两个位置 |
| 计算效率 | O(n) 时间,无法并行 | O(1) 层数,高度并行 |
核心思想:用注意力机制(Attention)替代循环结构。
三、注意力机制:Transformer 的核心
3.1 什么是注意力?
想象你在阅读一句话:
"小明把苹果放在桌上,然后吃掉了它。"
当你理解"吃掉了它"时,大脑会自动"注意"到"苹果"这个词,而不是"桌上"或"小明"。
这就是注意力机制:让模型学会"该关注哪些信息"。
3.2 注意力的数学直觉
注意力本质上是一种加权求和:
输出 = 权重₁ × 值₁ + 权重₂ × 值₂ + 权重₃ × 值₃ + ...
权重越大,表示对应的信息越重要。关键问题是:如何计算这些权重?
3.3 Query、Key、Value:注意力的三要素
Transformer 引入了三个核心概念:
| 概念 | 含义 | 类比 |
|---|---|---|
| Query (Q) | 查询:我想找什么? | 你在搜索引擎输入的关键词 |
| Key (K) | 键:这个东西是什么? | 网页的标题/描述 |
| Value (V) | 值:这个东西的内容 | 网页的实际内容 |
搜索引擎类比:
- 你输入搜索关键词(Query)
- 搜索引擎用关键词和网页标题(Key)做匹配
- 匹配度高的网页(Value)排在前面
- 你看到的是加权后的结果
四、自注意力机制详解与公式推导
4.1 自注意力(Self-Attention)是什么?
"自注意力"中的"自"表示:序列中的每个位置都和序列中的所有位置(包括自己)计算注意力。
输入序列:[我, 爱, 中, 国]
我 爱 中 国
┌──────────────────────┐
我 → │ 0.4 0.3 0.1 0.2 │ → 我的新表示
爱 → │ 0.2 0.5 0.2 0.1 │ → 爱的新表示
中 → │ 0.1 0.2 0.4 0.3 │ → 中的新表示
国 → │ 0.1 0.1 0.3 0.5 │ → 国的新表示
└──────────────────────┘
注意力权重矩阵
每个词都能"看到"其他所有词,并根据相关性分配权重。
4.2 公式推导:从输入到输出
让我们一步步推导自注意力的完整公式。
Step 1:准备输入
假设输入序列有 n 个词,每个词用 d 维向量表示:
输入矩阵 X ∈ ℝⁿˣᵈ
例如:4个词,每个词64维
X = [x₁] n=4
[x₂]
[x₃] d=64
[x₄]
Step 2:生成 Q、K、V
通过三个可学习的权重矩阵,将输入转换为 Query、Key、Value:
Q = X · Wq (Wq ∈ ℝᵈˣᵈᵏ)
K = X · Wk (Wk ∈ ℝᵈˣᵈᵏ)
V = X · Wv (Wv ∈ ℝᵈˣᵈᵛ)
其中:
- Wq、Wk、Wv 是可学习的参数矩阵
- dₖ 是 Key/Query 的维度
- dᵥ 是 Value 的维度(通常 dₖ = dᵥ = d)
为什么需要三个不同的矩阵?
这样 Q、K、V 可以从不同角度编码信息:
- Q:这个词"想要查找"什么
- K:这个词"对外呈现"什么
- V:这个词"实际包含"什么
Step 3:计算注意力分数
用 Query 和 Key 的点积衡量相关性:
分数矩阵 S = Q · Kᵀ
S ∈ ℝⁿˣⁿ
S[i][j] 表示第 i 个词对第 j 个词的"关注程度"。
点积的直觉:两个向量的点积越大,说明它们方向越一致,即越相似。
Step 4:缩放(Scaling)
S_scaled = S / √dₖ
为什么要除以 √dₖ? 这是关键的细节!
假设 Q 和 K 的每个元素都是均值为 0、方差为 1 的独立随机变量。那么点积:
q · k = Σᵢ qᵢ × kᵢ
这个和的方差 = dₖ(dₖ 个独立随机变量相加)
当 dₖ 很大时(比如 64 或 512),点积的值会非常大。这会导致 softmax 后的分布变得极端(接近 one-hot),梯度消失。
除以 √dₖ 可以让方差回到 1,保持数值稳定。
Step 5:Softmax 归一化
注意力权重 A = softmax(S_scaled)
对每一行做 softmax,确保权重和为 1:
A[i][j] = exp(S_scaled[i][j]) / Σₖ exp(S_scaled[i][k])
Step 6:加权求和得到输出
输出 = A · V
每个位置的输出是所有 Value 的加权和,权重就是注意力分数。
4.3 完整公式
将上述步骤整合,得到著名的缩放点积注意力公式:
┌────────────────────────────┐
│ QKᵀ │
Attention(Q,K,V) = │ softmax(────) · V │
│ √dₖ │
└────────────────────────────┘
用更规范的数学符号表示:
4.4 计算示例
假设我们有一个简化的例子:
输入:2个词,每个词3维
X = [1, 0, 1] ← 词1
[0, 1, 1] ← 词2
假设 Wq = Wk = Wv = I(单位矩阵,简化计算)
则 Q = K = V = X
Step 1: 计算 QKᵀ
QKᵀ = [1,0,1] [1, 0] = [1×1+0×0+1×1, 1×0+0×1+1×1] = [2, 1]
[0,1,1] [0, 1] [0×1+1×0+1×1, 0×0+1×1+1×1] [1, 2]
[1, 1]
Step 2: 缩放(dₖ=3)
S_scaled = [2, 1] / √3 ≈ [1.15, 0.58]
[1, 2] [0.58, 1.15]
Step 3: Softmax(按行)
第1行:exp(1.15)/(exp(1.15)+exp(0.58)) ≈ 0.64, 0.36
第2行:exp(0.58)/(exp(0.58)+exp(1.15)) ≈ 0.36, 0.64
A ≈ [0.64, 0.36]
[0.36, 0.64]
Step 4: 计算输出
输出 = A · V = [0.64, 0.36] × [1, 0, 1]
[0.36, 0.64] [0, 1, 1]
词1的新表示 = 0.64×[1,0,1] + 0.36×[0,1,1] = [0.64, 0.36, 1.0]
词2的新表示 = 0.36×[1,0,1] + 0.64×[0,1,1] = [0.36, 0.64, 1.0]
可以看到,每个词的新表示融合了其他所有词的信息。
五、多头注意力:并行的多视角关注
5.1 为什么需要多头注意力?
单个注意力头只能学习一种"关注模式"。但在理解语言时,我们需要同时关注多种关系:
| 关注角度 | 例子 |
|---|---|
| 语法结构 | 主语-谓语-宾语 |
| 指代关系 | "它"指代什么 |
| 语义相似 | 同义词、近义词 |
| 位置关系 | 相邻词的关联 |
多头注意力让模型同时学习多种注意力模式。
5.2 多头注意力的结构
┌─────────────┐
┌────►│ 注意力头 1 │────┐
│ └─────────────┘ │
│ ┌─────────────┐ │
输入 X ───────┼────►│ 注意力头 2 │────┼───► Concat ──► 线性变换 ──► 输出
│ └─────────────┘ │
│ ┌─────────────┐ │
└────►│ 注意力头 h │────┘
└─────────────┘
5.3 公式推导
Step 1:分头计算
对于第 i 个头:
headᵢ = Attention(Q·Wqⁱ, K·Wkⁱ, V·Wvⁱ)
其中:
- Wqⁱ, Wkⁱ ∈ ℝᵈˣ⁽ᵈ/ʰ⁾
- Wvⁱ ∈ ℝᵈˣ⁽ᵈ/ʰ⁾
- h 是头的数量
关键设计:每个头的维度是 d/h,所以 h 个头拼接后维度还是 d。
Step 2:拼接所有头
MultiHead(Q, K, V) = Concat(head₁, head₂, ..., headₕ) · Wᵒ
其中 Wᵒ ∈ ℝᵈˣᵈ 是输出投影矩阵。
5.4 多头注意力的完整公式
5.5 一个直观的理解
把多头注意力想象成多个专家同时看同一个问题:
- 专家 1 关注语法结构
- 专家 2 关注指代关系
- 专家 3 关注情感倾向
- ...
最后把所有专家的意见综合起来,得到更全面的理解。
六、Transformer 的完整架构
6.1 整体结构
输入 输出
↓ ↓
[词嵌入 + 位置编码] [词嵌入 + 位置编码]
↓ ↓
┌─────────────┐ ┌─────────────┐
│ Encoder │ │ Decoder │
│ ×N 层 │ ─────────────────►│ ×N 层 │
└─────────────┘ └─────────────┘
↓
[线性层 + Softmax]
↓
预测输出
6.2 编码器(Encoder)结构
每个编码器层包含:
输入
↓
┌────────────┐
│ 多头自注意力 │◄────┐
└────────────┘ │ 残差连接
↓ │
Add & Norm ─────┘
↓
┌────────────┐
│ 前馈神经网络 │◄────┐
└────────────┘ │ 残差连接
↓ │
Add & Norm ─────┘
↓
输出
残差连接(Add):解决深层网络训练困难的问题(和 ResNet 思想一致)
层归一化(Norm):稳定训练,加速收敛
6.3 解码器(Decoder)结构
解码器比编码器多了一个交叉注意力层:
输入(已生成的部分)
↓
┌─────────────────┐
│ 带掩码的多头自注意力 │◄────┐
└─────────────────┘ │
↓ │
Add & Norm ──────────┘
↓
┌─────────────────┐
│ 交叉注意力 │◄────┐ ←── 接收编码器的输出
│ (Encoder-Decoder)│ │
└─────────────────┘ │
↓ │
Add & Norm ──────────┘
↓
┌────────────┐
│ 前馈神经网络 │◄────┐
└────────────┘ │
↓ │
Add & Norm ─────┘
↓
输出
带掩码的自注意力(Masked Self-Attention):
在生成时,模型只能看到已经生成的词,不能"偷看"未来的词。掩码实现这一限制:
掩码矩阵(上三角为-∞):
词1 词2 词3 词4
词1 [ 0 -∞ -∞ -∞ ]
词2 [ 0 0 -∞ -∞ ]
词3 [ 0 0 0 -∞ ]
词4 [ 0 0 0 0 ]
加到注意力分数上后,-∞ 经过 softmax 变成 0,即完全不关注未来的词。
七、位置编码:让模型知道顺序
7.1 为什么需要位置编码?
注意力机制是排列不变的:
Attention("我爱中国") = Attention("国中爱我") ?
如果不加处理,模型无法区分词的顺序!但语言中顺序至关重要。
7.2 位置编码的设计
Transformer 使用正弦/余弦位置编码:
PE(pos, 2i) = sin(pos / 10000^(2i/d))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d))
其中:
- pos:词在序列中的位置(0, 1, 2, ...)
- i:维度的索引(0, 1, 2, ..., d/2-1)
- d:编码维度
为什么用正弦/余弦?
- 有界性:值始终在 [-1, 1] 范围内
- 周期性:可以表达相对位置关系
- 可外推:理论上可以处理任意长度的序列
关键性质:对于固定的偏移量 k,PE(pos+k) 可以表示为 PE(pos) 的线性函数。这让模型容易学习相对位置关系。
7.3 位置编码的可视化
位置
↓
0 ━━━━━━━━━━━━━━━━━━━━
1 ━━━━━━━━━━━━━━━━━━━∙
2 ━━━━━━━━━━━━━━━━━━∙∙
3 ━━━━━━━━━━━━━━━━━∙∙∙
...
低频 ←────────────→ 高频
维度方向
- 低维度变化慢(表示粗粒度位置)
- 高维度变化快(表示细粒度位置)
八、前馈神经网络(FFN)
每个 Transformer 层还包含一个前馈神经网络:
FFN(x) = max(0, xW₁ + b₁)W₂ + b₂
或者用更现代的 GELU 激活:
FFN(x) = GELU(xW₁ + b₁)W₂ + b₂
结构
输入(d 维)
↓
[线性层] → d → 4d(扩展)
↓
[激活函数 ReLU/GELU]
↓
[线性层] → 4d → d(收缩)
↓
输出(d 维)
FFN 的作用:
- 增加模型容量
- 引入非线性
- 每个位置独立处理(可并行)
为什么中间层要扩大 4 倍?
这是经验设计。扩大维度增加了模型的表达能力,是参数量和效果的权衡。
九、训练与推理
9.1 训练过程
任务:给定输入序列,预测下一个词
损失函数:交叉熵
Loss = -Σ log P(正确词 | 上下文)
优化技巧:
- 学习率预热(Warmup):开始时用小学习率,逐渐增大,然后再衰减
- Dropout:随机丢弃部分神经元,防止过拟合
- 标签平滑(Label Smoothing):软化目标分布,提高泛化能力
9.2 推理过程(生成)
自回归生成:
输入:"今天天气"
步骤1:模型预测下一个词 → "很"
步骤2:输入变成 "今天天气很",预测 → "好"
步骤3:输入变成 "今天天气很好",预测 → [结束符]
输出:"今天天气很好"
每一步都要重新计算,这就是为什么生成很慢。
KV 缓存(KV Cache):
为了加速,可以缓存已计算的 Key 和 Value,避免重复计算:
不用缓存:每步计算 O(n²),总共 O(n³)
使用缓存:每步计算 O(n),总共 O(n²)
十、Transformer 的变体与发展
10.1 主流架构
| 架构类型 | 代表模型 | 特点 |
|---|---|---|
| Encoder-only | BERT | 双向理解,适合分类、问答 |
| Decoder-only | GPT | 单向生成,适合文本生成 |
| Encoder-Decoder | T5, BART | 完整架构,适合翻译、摘要 |
10.2 效率优化
| 优化方法 | 思路 | 代表工作 |
|---|---|---|
| 稀疏注意力 | 只关注部分位置 | Longformer, BigBird |
| 线性注意力 | 将复杂度从 O(n²) 降到 O(n) | Linear Transformer |
| Flash Attention | 优化内存访问模式 | Flash Attention |
| 分组查询注意力 | 多个头共享 K/V | GQA |
10.3 架构改进
- RMSNorm:替代 LayerNorm,更高效
- 旋转位置编码(RoPE):替代绝对位置编码,外推能力更强
- SwiGLU 激活:替代 ReLU/GELU,效果更好
十一、总结
让我们回顾 Transformer 的核心要点:
| 组件 | 作用 | 创新点 |
|---|---|---|
| 自注意力 | 建立序列内任意位置的关系 | 并行计算、长距离建模 |
| 多头注意力 | 多视角关注不同模式 | 增加表达能力 |
| 位置编码 | 引入位置信息 | 正弦编码,可外推 |
| 残差连接 | 训练深层网络 | 梯度直通 |
| 层归一化 | 稳定训练 | 加速收敛 |
Transformer 成功的关键:
- 并行性:摆脱 RNN 的顺序依赖
- 长距离建模:注意力机制直接连接任意位置
- 可扩展性:架构简洁,易于堆叠更多层
核心公式速记:
# 自注意力
Attention(Q,K,V) = softmax(QKᵀ/√dₖ) × V
# 多头注意力
MultiHead(Q,K,V) = Concat(head₁,...,headₕ)Wᵒ
where headᵢ = Attention(QWᵢQ, KWᵢK, VWᵢV)
延伸阅读
如果你想进一步深入:
- 论文原文:Attention Is All You Need(Vaswani et al., 2017)
- 经典解读:The Illustrated Transformer(Jay Alammar)
- 代码实践:The Annotated Transformer(Harvard NLP)
- 进阶内容:FlashAttention、Rotary Position Embedding 等优化技术
后记:Transformer 的设计之美在于其简洁性。"Attention Is All You Need"这个标题本身就是一种宣言:用一个统一的机制(注意力)替代复杂的循环结构、门控机制,反而取得了更好的效果。这告诉我们,有时候化繁为简才是真正的创新。