Transformer 架构:从原理到公式推导

10 阅读12分钟

Transformer 架构深度解析:从原理到公式推导

本文将带你从零开始,彻底理解 Transformer 的工作原理和核心公式。无论你是初学者还是希望深入理解细节的开发者,都能从中获益。


一、开篇:Transformer 为何如此重要?

2017 年,Google 团队发表了论文《Attention Is All You Need》,提出了 Transformer 架构。这篇论文彻底改变了自然语言处理(NLP)领域,并逐渐影响到计算机视觉、语音等几乎所有 AI 领域。

Transformer 的"战绩"

  • GPT 系列:ChatGPT 的底层架构
  • BERT:开创预训练语言模型时代
  • ViT:将 Transformer 应用到图像识别
  • Stable Diffusion:AI 绘画的核心组件

可以说,没有 Transformer,就没有今天的 AI 大模型时代


二、Transformer 解决了什么问题?

2.1 传统序列模型的困境

在 Transformer 之前,处理序列数据(如文本、语音)主要依赖 RNN(循环神经网络)和 LSTM。

RNN 的工作方式

"我 爱 中 国"

  我 → [RNN] → h₁
              ↓
  爱 → [RNN] → h₂
              ↓
  中 → [RNN] → h₃
              ↓
  国 → [RNN] → h₄

问题 1:无法并行计算

RNN 必须按顺序处理每个词,后一个词必须等前一个词处理完。这导致:

  • 训练速度慢
  • 无法充分利用 GPU 的并行能力

问题 2:长距离依赖困难

当句子很长时,早期词的信息在传递过程中会逐渐"遗忘"。比如:

"那个我在北京上大学时认识的朋友,现在在那个城市工作。"

要理解"那个城市"指的是"北京",信息需要跨越很多词传递,RNN 很难做到。

2.2 Transformer 的解决方案

问题RNN 的困境Transformer 的解决
并行计算必须顺序处理所有位置同时计算
长距离依赖信息逐渐衰减注意力机制直接关联任意两个位置
计算效率O(n) 时间,无法并行O(1) 层数,高度并行

核心思想:用注意力机制(Attention)替代循环结构


三、注意力机制:Transformer 的核心

3.1 什么是注意力?

想象你在阅读一句话:

"小明把苹果放在桌上,然后吃掉了它。"

当你理解"吃掉了它"时,大脑会自动"注意"到"苹果"这个词,而不是"桌上"或"小明"。

这就是注意力机制:让模型学会"该关注哪些信息"。

3.2 注意力的数学直觉

注意力本质上是一种加权求和

输出 = 权重₁ × 值₁ + 权重₂ × 值₂ + 权重₃ × 值₃ + ...

权重越大,表示对应的信息越重要。关键问题是:如何计算这些权重?

3.3 Query、Key、Value:注意力的三要素

Transformer 引入了三个核心概念:

概念含义类比
Query (Q)查询:我想找什么?你在搜索引擎输入的关键词
Key (K)键:这个东西是什么?网页的标题/描述
Value (V)值:这个东西的内容网页的实际内容

搜索引擎类比

  1. 你输入搜索关键词(Query)
  2. 搜索引擎用关键词和网页标题(Key)做匹配
  3. 匹配度高的网页(Value)排在前面
  4. 你看到的是加权后的结果

四、自注意力机制详解与公式推导

4.1 自注意力(Self-Attention)是什么?

"自注意力"中的"自"表示:序列中的每个位置都和序列中的所有位置(包括自己)计算注意力

输入序列:[我, 爱, 中, 国]

          我    爱    中    国
       ┌──────────────────────┐
  我 → │ 0.4   0.3   0.1   0.2 │ → 我的新表示
  爱 → │ 0.2   0.5   0.2   0.1 │ → 爱的新表示
  中 → │ 0.1   0.2   0.4   0.3 │ → 中的新表示
  国 → │ 0.1   0.1   0.3   0.5 │ → 国的新表示
       └──────────────────────┘
              注意力权重矩阵

每个词都能"看到"其他所有词,并根据相关性分配权重。

4.2 公式推导:从输入到输出

让我们一步步推导自注意力的完整公式。

Step 1:准备输入

假设输入序列有 n 个词,每个词用 d 维向量表示:

输入矩阵 X ∈ ℝⁿˣᵈ

例如:4个词,每个词64维
X = [x₁]   n=4
    [x₂]   
    [x₃]   d=64
    [x₄]   

Step 2:生成 Q、K、V

通过三个可学习的权重矩阵,将输入转换为 Query、Key、Value:

Q = X · Wq    (Wq ∈ ℝᵈˣᵈᵏ)
K = X · Wk    (Wk ∈ ℝᵈˣᵈᵏ)  
V = X · Wv    (Wv ∈ ℝᵈˣᵈᵛ)

其中:

  • Wq、Wk、Wv 是可学习的参数矩阵
  • dₖ 是 Key/Query 的维度
  • dᵥ 是 Value 的维度(通常 dₖ = dᵥ = d)

为什么需要三个不同的矩阵?

这样 Q、K、V 可以从不同角度编码信息:

  • Q:这个词"想要查找"什么
  • K:这个词"对外呈现"什么
  • V:这个词"实际包含"什么

Step 3:计算注意力分数

用 Query 和 Key 的点积衡量相关性:

分数矩阵 S = Q · Kᵀ

S ∈ ℝⁿˣⁿ

S[i][j] 表示第 i 个词对第 j 个词的"关注程度"。

点积的直觉:两个向量的点积越大,说明它们方向越一致,即越相似。

Step 4:缩放(Scaling)

S_scaled = S / √dₖ

为什么要除以 √dₖ? 这是关键的细节!

假设 Q 和 K 的每个元素都是均值为 0、方差为 1 的独立随机变量。那么点积:

q · k = Σᵢ qᵢ × kᵢ

这个和的方差 = dₖ(dₖ 个独立随机变量相加)

当 dₖ 很大时(比如 64 或 512),点积的值会非常大。这会导致 softmax 后的分布变得极端(接近 one-hot),梯度消失。

除以 √dₖ 可以让方差回到 1,保持数值稳定。

Step 5:Softmax 归一化

注意力权重 A = softmax(S_scaled)

对每一行做 softmax,确保权重和为 1:

A[i][j] = exp(S_scaled[i][j]) / Σₖ exp(S_scaled[i][k])

Step 6:加权求和得到输出

输出 = A · V

每个位置的输出是所有 Value 的加权和,权重就是注意力分数。

4.3 完整公式

将上述步骤整合,得到著名的缩放点积注意力公式

                    ┌────────────────────────────┐
                    │         QKᵀ                │
Attention(Q,K,V) =  │ softmax(────) · V          │
                    │         √dₖ               │
                    └────────────────────────────┘

用更规范的数学符号表示:

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

4.4 计算示例

假设我们有一个简化的例子:

输入:2个词,每个词3X = [1, 0, 1]    ← 词1
    [0, 1, 1]    ← 词2

假设 Wq = Wk = Wv = I(单位矩阵,简化计算)

则 Q = K = V = X

Step 1: 计算 QKQKᵀ = [1,0,1] [1, 0]   = [1×1+0×0+1×1, 1×0+0×1+1×1] = [2, 1]
      [0,1,1] [0, 1]     [0×1+1×0+1×1, 0×0+1×1+1×1]   [1, 2]
              [1, 1]

Step 2: 缩放(dₖ=3S_scaled = [2, 1] / √3[1.15, 0.58]
           [1, 2]         [0.58, 1.15]

Step 3: Softmax(按行)
第1行:exp(1.15)/(exp(1.15)+exp(0.58)) ≈ 0.64, 0.362行:exp(0.58)/(exp(0.58)+exp(1.15)) ≈ 0.36, 0.64

A[0.64, 0.36]
    [0.36, 0.64]

Step 4: 计算输出
输出 = A · V = [0.64, 0.36] × [1, 0, 1]
               [0.36, 0.64]   [0, 1, 1]1的新表示 = 0.64×[1,0,1] + 0.36×[0,1,1] = [0.64, 0.36, 1.0]2的新表示 = 0.36×[1,0,1] + 0.64×[0,1,1] = [0.36, 0.64, 1.0]

可以看到,每个词的新表示融合了其他所有词的信息。


五、多头注意力:并行的多视角关注

5.1 为什么需要多头注意力?

单个注意力头只能学习一种"关注模式"。但在理解语言时,我们需要同时关注多种关系:

关注角度例子
语法结构主语-谓语-宾语
指代关系"它"指代什么
语义相似同义词、近义词
位置关系相邻词的关联

多头注意力让模型同时学习多种注意力模式。

5.2 多头注意力的结构

                    ┌─────────────┐
              ┌────►│ 注意力头 1  │────┐
              │     └─────────────┘    │
              │     ┌─────────────┐    │
输入 X ───────┼────►│ 注意力头 2  │────┼───► Concat ──► 线性变换 ──► 输出
              │     └─────────────┘    │
              │     ┌─────────────┐    │
              └────►│ 注意力头 h  │────┘
                    └─────────────┘

5.3 公式推导

Step 1:分头计算

对于第 i 个头:

headᵢ = Attention(Q·Wqⁱ, K·Wkⁱ, V·Wvⁱ)

其中:

  • Wqⁱ, Wkⁱ ∈ ℝᵈˣ⁽ᵈ/ʰ⁾
  • Wvⁱ ∈ ℝᵈˣ⁽ᵈ/ʰ⁾
  • h 是头的数量

关键设计:每个头的维度是 d/h,所以 h 个头拼接后维度还是 d。

Step 2:拼接所有头

MultiHead(Q, K, V) = Concat(head₁, head₂, ..., headₕ) · Wᵒ

其中 Wᵒ ∈ ℝᵈˣᵈ 是输出投影矩阵。

5.4 多头注意力的完整公式

MultiHead(Q,K,V)=Concat(head1,...,headh)WO\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h) W^O
where headi=Attention(QWiQ,KWiK,VWiV)\text{where head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

5.5 一个直观的理解

把多头注意力想象成多个专家同时看同一个问题

  • 专家 1 关注语法结构
  • 专家 2 关注指代关系
  • 专家 3 关注情感倾向
  • ...

最后把所有专家的意见综合起来,得到更全面的理解。


六、Transformer 的完整架构

6.1 整体结构

               输入                              输出
                ↓                                 ↓
         [词嵌入 + 位置编码]                 [词嵌入 + 位置编码]
                ↓                                 ↓
         ┌─────────────┐                   ┌─────────────┐
         │   Encoder   │                   │   Decoder   │
         │  ×N 层      │ ─────────────────►│  ×N 层      │
         └─────────────┘                   └─────────────┘
                                                  ↓
                                           [线性层 + Softmax]
                                                  ↓
                                              预测输出

6.2 编码器(Encoder)结构

每个编码器层包含:

        输入
         ↓
    ┌────────────┐
    │  多头自注意力 │◄────┐
    └────────────┘     │ 残差连接
         ↓             │
       Add & Norm ─────┘
         ↓
    ┌────────────┐
    │  前馈神经网络 │◄────┐
    └────────────┘     │ 残差连接
         ↓             │
       Add & Norm ─────┘
         ↓
        输出

残差连接(Add):解决深层网络训练困难的问题(和 ResNet 思想一致)

层归一化(Norm):稳定训练,加速收敛

6.3 解码器(Decoder)结构

解码器比编码器多了一个交叉注意力层

        输入(已生成的部分)
         ↓
    ┌─────────────────┐
    │ 带掩码的多头自注意力 │◄────┐
    └─────────────────┘     │
         ↓                  │
       Add & Norm ──────────┘
         ↓
    ┌─────────────────┐
    │  交叉注意力      │◄────┐  ←── 接收编码器的输出
    │  (Encoder-Decoder)│     │
    └─────────────────┘     │
         ↓                  │
       Add & Norm ──────────┘
         ↓
    ┌────────────┐
    │  前馈神经网络 │◄────┐
    └────────────┘     │
         ↓             │
       Add & Norm ─────┘
         ↓
        输出

带掩码的自注意力(Masked Self-Attention)

在生成时,模型只能看到已经生成的词,不能"偷看"未来的词。掩码实现这一限制:

掩码矩阵(上三角为-∞):
     词12341 [ 0   -∞  -∞  -∞ ]2 [ 0    0  -∞  -∞ ]3 [ 0    0   0  -∞ ]4 [ 0    0   0   0 ]

加到注意力分数上后,-∞ 经过 softmax 变成 0,即完全不关注未来的词。

七、位置编码:让模型知道顺序

7.1 为什么需要位置编码?

注意力机制是排列不变的

Attention("我爱中国") = Attention("国中爱我")  ?

如果不加处理,模型无法区分词的顺序!但语言中顺序至关重要。

7.2 位置编码的设计

Transformer 使用正弦/余弦位置编码

PE(pos, 2i)   = sin(pos / 10000^(2i/d))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d))

其中:

  • pos:词在序列中的位置(0, 1, 2, ...)
  • i:维度的索引(0, 1, 2, ..., d/2-1)
  • d:编码维度

为什么用正弦/余弦?

  1. 有界性:值始终在 [-1, 1] 范围内
  2. 周期性:可以表达相对位置关系
  3. 可外推:理论上可以处理任意长度的序列

关键性质:对于固定的偏移量 k,PE(pos+k) 可以表示为 PE(pos) 的线性函数。这让模型容易学习相对位置关系。

7.3 位置编码的可视化

位置
  ↓
  0  ━━━━━━━━━━━━━━━━━━━━
  1  ━━━━━━━━━━━━━━━━━━━∙
  2  ━━━━━━━━━━━━━━━━━━∙∙
  3  ━━━━━━━━━━━━━━━━━∙∙∙
  ...
     低频 ←────────────→ 高频
           维度方向
  • 低维度变化慢(表示粗粒度位置)
  • 高维度变化快(表示细粒度位置)

八、前馈神经网络(FFN)

每个 Transformer 层还包含一个前馈神经网络:

FFN(x) = max(0, xW₁ + b₁)W₂ + b

或者用更现代的 GELU 激活:

FFN(x) = GELU(xW₁ + b₁)W₂ + b

结构

输入(d 维)
    ↓
[线性层] → d → 4d(扩展)
    ↓
[激活函数 ReLU/GELU][线性层]4d → d(收缩)
    ↓
输出(d 维)

FFN 的作用

  • 增加模型容量
  • 引入非线性
  • 每个位置独立处理(可并行)

为什么中间层要扩大 4 倍?

这是经验设计。扩大维度增加了模型的表达能力,是参数量和效果的权衡。


九、训练与推理

9.1 训练过程

任务:给定输入序列,预测下一个词

损失函数:交叉熵

Loss = -Σ log P(正确词 | 上下文)

优化技巧

  • 学习率预热(Warmup):开始时用小学习率,逐渐增大,然后再衰减
  • Dropout:随机丢弃部分神经元,防止过拟合
  • 标签平滑(Label Smoothing):软化目标分布,提高泛化能力

9.2 推理过程(生成)

自回归生成

输入:"今天天气"

步骤1:模型预测下一个词 → "很"
步骤2:输入变成 "今天天气很",预测 → "好"
步骤3:输入变成 "今天天气很好",预测 → [结束符]

输出:"今天天气很好"

每一步都要重新计算,这就是为什么生成很慢。

KV 缓存(KV Cache)

为了加速,可以缓存已计算的 Key 和 Value,避免重复计算:

不用缓存:每步计算 O(n²),总共 O(n³)
使用缓存:每步计算 O(n),总共 O(n²)

十、Transformer 的变体与发展

10.1 主流架构

架构类型代表模型特点
Encoder-onlyBERT双向理解,适合分类、问答
Decoder-onlyGPT单向生成,适合文本生成
Encoder-DecoderT5, BART完整架构,适合翻译、摘要

10.2 效率优化

优化方法思路代表工作
稀疏注意力只关注部分位置Longformer, BigBird
线性注意力将复杂度从 O(n²) 降到 O(n)Linear Transformer
Flash Attention优化内存访问模式Flash Attention
分组查询注意力多个头共享 K/VGQA

10.3 架构改进

  • RMSNorm:替代 LayerNorm,更高效
  • 旋转位置编码(RoPE):替代绝对位置编码,外推能力更强
  • SwiGLU 激活:替代 ReLU/GELU,效果更好

十一、总结

让我们回顾 Transformer 的核心要点:

组件作用创新点
自注意力建立序列内任意位置的关系并行计算、长距离建模
多头注意力多视角关注不同模式增加表达能力
位置编码引入位置信息正弦编码,可外推
残差连接训练深层网络梯度直通
层归一化稳定训练加速收敛

Transformer 成功的关键

  1. 并行性:摆脱 RNN 的顺序依赖
  2. 长距离建模:注意力机制直接连接任意位置
  3. 可扩展性:架构简洁,易于堆叠更多层

核心公式速记

# 自注意力
Attention(Q,K,V) = softmax(QKᵀ/√dₖ) × V

# 多头注意力  
MultiHead(Q,K,V) = Concat(head₁,...,headₕ)Wᵒ
where headᵢ = Attention(QWᵢQ, KWᵢK, VWᵢV)

延伸阅读

如果你想进一步深入:

  1. 论文原文Attention Is All You Need(Vaswani et al., 2017)
  2. 经典解读The Illustrated Transformer(Jay Alammar)
  3. 代码实践The Annotated Transformer(Harvard NLP)
  4. 进阶内容FlashAttentionRotary Position Embedding 等优化技术

后记:Transformer 的设计之美在于其简洁性。"Attention Is All You Need"这个标题本身就是一种宣言:用一个统一的机制(注意力)替代复杂的循环结构、门控机制,反而取得了更好的效果。这告诉我们,有时候化繁为简才是真正的创新。