Transformer 与 Self-Attention (整理版)
目标:用“概念 → 公式 → 代码 → 练习”的方式,把 Transformer 的核心机制讲清楚。
目录
- [Transformer 与 Self-Attention (整理版)](#transformer-与-self-attention- 整理版)
1. Transformer 是什么
Transformer 是一种序列建模架构(2017 年提出),它把“序列之间的依赖”主要交给注意力机制来建模,而不是依赖 RNN 的时间步递推。
它之所以重要,核心在两点:
- 并行性:训练时可以对整段序列并行计算注意力(比 RNN 更容易吃满 GPU/TPU)。
- 长程依赖:任意两个 token 之间的交互路径更短(自注意力是一次“全连接式”交互)。
一句话直觉:
对于序列中的每个位置,让模型学会“我应该关注哪些位置,以及关注多少”。
2. 编码器/解码器整体结构
经典 Transformer(seq2seq)包含两大块:
- Encoder(编码器):把输入序列编码为一组上下文表示。
- Decoder(解码器):在生成第 个输出 token 时,只能看见 之前已生成的内容,并结合 Encoder 输出进行交互(cross-attention)。
每一块通常由 层堆叠(论文里常见 )。各层结构相同,但参数不共享。
一个 Encoder Layer 通常包含:
- Multi-Head Self-Attention
- Add & Norm
- Feed-Forward Network (FFN)
- Add & Norm
一个 Decoder Layer 通常包含:
- Masked Multi-Head Self-Attention(遮住未来)
- Add & Norm
- Multi-Head Cross-Attention(Q 来自 decoder,K/V 来自 encoder)
- Add & Norm
- FFN
- Add & Norm
3. Self-Attention(自注意力)
3.1 Q/K/V 的含义(非常实用的直觉)
- Query(Q):我“想找什么信息”。
- Key(K):我“是什么信息的索引/标签”。
- Value(V):我“真正携带的内容”。
同一个输入 会被线性映射出 :
其中 (长度 ,隐藏维 )。
3.2 Scaled Dot-Product Attention 公式
注意力权重来自相似度(点积)并做缩放:
为什么要除以 :当维度较大时点积数值更容易变大,softmax 会更“尖”,梯度更不稳定;缩放能让训练更稳。
3.3 输出在做什么
对每个位置 :
- 先算它对所有位置 的相关性分数
- softmax 得到权重
- 用权重对所有 value 做加权求和,得到该位置的新表示
4. Multi-Head Attention(多头注意力)
单头注意力只能在一个“子空间”里做匹配。多头注意力的做法是:
- 用 组不同的线性映射得到
- 每个头独立算 attention 得到
- 把各头拼接后再做一次线性变换
直觉:不同的头可以分别学“指代关系”“语法依赖”“长距离对齐”等不同模式。
5. 位置编码(Positional Encoding)
自注意力本身对输入顺序不敏感(你把 token 乱序,注意力计算形式不变)。因此需要显式注入位置信息。
常见做法:把位置向量 与词向量相加:
论文中使用的正弦/余弦位置编码:
补充:很多实现也会用可学习位置编码(learnable embeddings),同样有效。
6. Add & Norm(残差 + LayerNorm)
Transformer 里几乎每个子层都采用:
- 残差连接:让信息与梯度更容易流动,深层训练更稳定。
- LayerNorm:对单样本特征维做归一化,NLP 中通常比 BatchNorm 更合适。
7. FFN(前馈网络)
FFN 是逐位置(position-wise)的两层 MLP:
它不在 token 间交互(交互在 attention 里做),但能增强非线性表达能力。
8. Mask:Padding Mask 与 Causal Mask
8.1 Padding Mask
批处理时序列会 padding 到同一长度。padding token 不应该被关注,因此要把这些位置的 attention logits 设为 (实现里通常是一个足够小的负数)。
8.2 Causal Mask(Decoder 的“不能看未来”)
自回归生成时,第 个位置不能看见 之后的位置,所以要加上一个上三角 mask。
9. 用 PyTorch 手写一次 Self-Attention
下面用一个小矩阵例子把公式跑通(重点是理解矩阵形状与步骤)。
9.1 准备输入
import torch
from torch.nn.functional import softmax
x = torch.tensor(
[
[1, 0, 1, 0], # token 1 embedding
[0, 2, 0, 2], # token 2 embedding
[1, 1, 1, 1], # token 3 embedding
],
dtype=torch.float32,
)
print(x)
输出:
tensor([[1., 0., 1., 0.],
[0., 2., 0., 2.],
[1., 1., 1., 1.]])
9.2 构造 Q/K/V 映射矩阵
说明:真实模型里 是可训练参数,这里为了可复现,用手写的小矩阵。
w_key = torch.tensor(
[
[0, 0, 1],
[1, 1, 0],
[0, 1, 0],
[1, 1, 0],
],
dtype=torch.float32,
)
w_query = torch.tensor(
[
[1, 0, 1],
[1, 0, 0],
[0, 0, 1],
[0, 1, 1],
],
dtype=torch.float32,
)
w_value = torch.tensor(
[
[0, 2, 0],
[0, 3, 0],
[1, 0, 3],
[1, 1, 0],
],
dtype=torch.float32,
)
print("w_key\n", w_key)
print("w_query\n", w_query)
print("w_value\n", w_value)
9.3 计算 K/Q/V
keys = x @ w_key
queries = x @ w_query
values = x @ w_value
print("keys\n", keys)
print("queries\n", queries)
print("values\n", values)
9.4 计算注意力分数(logits)
这里用最简单的 (完整版本还要除以 ):
attn_logits = queries @ keys.T
print(attn_logits)
9.5 softmax 得到权重
attn_weights = softmax(attn_logits, dim=-1)
print(attn_weights)
为了更直观看清“加权求和”,我们也可以做一个近似版本(教学用):
attn_weights_simple = torch.tensor(
[
[0.0, 0.5, 0.5],
[0.0, 1.0, 0.0],
[0.0, 0.9, 0.1],
],
dtype=torch.float32,
)
9.6 加权求和得到输出
标准矩阵写法是:
output = attn_weights @ values
print(output)
如果你想看“每个 value 被乘了多少”,可以像下面这样拆开(便于教学观察):
weighted_values = values[:, None] * attn_weights_simple.T[:, :, None]
print(weighted_values)
print("sum over tokens ->", weighted_values.sum(dim=0))
10. 优缺点与常见坑
优点
- 效果强:尤其在大数据与大模型规模下。
- 并行友好:训练吞吐高。
- 长距离依赖更容易学到。
常见坑
- mask 忘了加:padding token 参与注意力会污染表示;decoder 不加 causal mask 会“偷看答案”。
- shape 搞混:批次维、头数维、序列长度维容易写错。
- softmax 维度写错:通常要对最后一维(key 维/序列维)做 softmax。
11. 小练习
- 把第 9 节的示例改成带缩放:将
attn_logits替换为attn_logits / (dk ** 0.5),其中dk = keys.size(-1)。 - 写一个 causal mask(上三角),把未来位置 logits 置为一个很小的负数(如
-1e9),观察输出变化。 - 把
values改大一倍,看看输出是否也线性变大(应该会)。
12. 延伸阅读
- 《Attention Is All You Need》:Transformer 原论文(arXiv)
- The Illustrated Transformer(图解 Transformer,直觉非常好)
- PyTorch 官方
torch.nn.MultiheadAttention文档与源码(理解工程实现细节)