Transformer 与 Self-Attention (整理版)

17 阅读4分钟

Transformer 与 Self-Attention (整理版)

目标:用“概念 → 公式 → 代码 → 练习”的方式,把 Transformer 的核心机制讲清楚。

目录


1. Transformer 是什么

Transformer 是一种序列建模架构(2017 年提出),它把“序列之间的依赖”主要交给注意力机制来建模,而不是依赖 RNN 的时间步递推。

它之所以重要,核心在两点:

  • 并行性:训练时可以对整段序列并行计算注意力(比 RNN 更容易吃满 GPU/TPU)。
  • 长程依赖:任意两个 token 之间的交互路径更短(自注意力是一次“全连接式”交互)。

一句话直觉:

对于序列中的每个位置,让模型学会“我应该关注哪些位置,以及关注多少”。


2. 编码器/解码器整体结构

经典 Transformer(seq2seq)包含两大块:

  • Encoder(编码器):把输入序列编码为一组上下文表示。
  • Decoder(解码器):在生成第 tt 个输出 token 时,只能看见 tt 之前已生成的内容,并结合 Encoder 输出进行交互(cross-attention)。

每一块通常由 NN 层堆叠(论文里常见 N=6N=6)。各层结构相同,但参数不共享。

一个 Encoder Layer 通常包含:

  1. Multi-Head Self-Attention
  2. Add & Norm
  3. Feed-Forward Network (FFN)
  4. Add & Norm

一个 Decoder Layer 通常包含:

  1. Masked Multi-Head Self-Attention(遮住未来)
  2. Add & Norm
  3. Multi-Head Cross-Attention(Q 来自 decoder,K/V 来自 encoder)
  4. Add & Norm
  5. FFN
  6. Add & Norm

3. Self-Attention(自注意力)

3.1 Q/K/V 的含义(非常实用的直觉)

  • Query(Q):我“想找什么信息”。
  • Key(K):我“是什么信息的索引/标签”。
  • Value(V):我“真正携带的内容”。

同一个输入 XX 会被线性映射出 Q,K,VQ,K,V

Q=XWQ,K=XWK,V=XWVQ = XW_Q,\quad K = XW_K,\quad V = XW_V

其中 XRL×dmodelX \in \mathbb{R}^{L\times d_{model}}(长度 LL,隐藏维 dmodeld_{model})。

3.2 Scaled Dot-Product Attention 公式

注意力权重来自相似度(点积)并做缩放:

Attention(Q,K,V)=softmax(QKTdk)V\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

为什么要除以 dk\sqrt{d_k}:当维度较大时点积数值更容易变大,softmax 会更“尖”,梯度更不稳定;缩放能让训练更稳。

3.3 输出在做什么

对每个位置 ii

  • 先算它对所有位置 jj 的相关性分数 sijs_{ij}
  • softmax 得到权重 aija_{ij}
  • 用权重对所有 value 做加权求和,得到该位置的新表示

4. Multi-Head Attention(多头注意力)

单头注意力只能在一个“子空间”里做匹配。多头注意力的做法是:

  1. hh 组不同的线性映射得到 Qi,Ki,ViQ_i,K_i,V_i
  2. 每个头独立算 attention 得到 ZiZ_i
  3. 把各头拼接后再做一次线性变换
MHA(X)=Concat(Z1,,Zh)WO\mathrm{MHA}(X)=\mathrm{Concat}(Z_1,\dots,Z_h)W_O

直觉:不同的头可以分别学“指代关系”“语法依赖”“长距离对齐”等不同模式。


5. 位置编码(Positional Encoding)

自注意力本身对输入顺序不敏感(你把 token 乱序,注意力计算形式不变)。因此需要显式注入位置信息。

常见做法:把位置向量 PE(pos)PE(pos) 与词向量相加:

X=X+PEX' = X + PE

论文中使用的正弦/余弦位置编码:

PE(pos,2i)=sin(pos100002i/dmodel),PE(pos,2i+1)=cos(pos100002i/dmodel)PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right),\quad PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)

补充:很多实现也会用可学习位置编码(learnable embeddings),同样有效。


6. Add & Norm(残差 + LayerNorm)

Transformer 里几乎每个子层都采用:

Y=LayerNorm(X+Sublayer(X))Y = \mathrm{LayerNorm}(X + \mathrm{Sublayer}(X))
  • 残差连接:让信息与梯度更容易流动,深层训练更稳定。
  • LayerNorm:对单样本特征维做归一化,NLP 中通常比 BatchNorm 更合适。

7. FFN(前馈网络)

FFN 是逐位置(position-wise)的两层 MLP:

FFN(x)=σ(xW1+b1)W2+b2\mathrm{FFN}(x)=\sigma(xW_1+b_1)W_2+b_2

它不在 token 间交互(交互在 attention 里做),但能增强非线性表达能力。


8. Mask:Padding Mask 与 Causal Mask

8.1 Padding Mask

批处理时序列会 padding 到同一长度。padding token 不应该被关注,因此要把这些位置的 attention logits 设为 -\infty(实现里通常是一个足够小的负数)。

8.2 Causal Mask(Decoder 的“不能看未来”)

自回归生成时,第 tt 个位置不能看见 tt 之后的位置,所以要加上一个上三角 mask。


9. 用 PyTorch 手写一次 Self-Attention

下面用一个小矩阵例子把公式跑通(重点是理解矩阵形状与步骤)。

9.1 准备输入

import torch
from torch.nn.functional import softmax

x = torch.tensor(
  [
    [1, 0, 1, 0],  # token 1 embedding
    [0, 2, 0, 2],  # token 2 embedding
    [1, 1, 1, 1],  # token 3 embedding
  ],
  dtype=torch.float32,
)

print(x)

输出:

tensor([[1., 0., 1., 0.],
    [0., 2., 0., 2.],
    [1., 1., 1., 1.]])

9.2 构造 Q/K/V 映射矩阵

说明:真实模型里 WQ,WK,WVW_Q,W_K,W_V 是可训练参数,这里为了可复现,用手写的小矩阵。

w_key = torch.tensor(
  [
    [0, 0, 1],
    [1, 1, 0],
    [0, 1, 0],
    [1, 1, 0],
  ],
  dtype=torch.float32,
)
w_query = torch.tensor(
  [
    [1, 0, 1],
    [1, 0, 0],
    [0, 0, 1],
    [0, 1, 1],
  ],
  dtype=torch.float32,
)
w_value = torch.tensor(
  [
    [0, 2, 0],
    [0, 3, 0],
    [1, 0, 3],
    [1, 1, 0],
  ],
  dtype=torch.float32,
)

print("w_key\n", w_key)
print("w_query\n", w_query)
print("w_value\n", w_value)

9.3 计算 K/Q/V

keys = x @ w_key
queries = x @ w_query
values = x @ w_value

print("keys\n", keys)
print("queries\n", queries)
print("values\n", values)

9.4 计算注意力分数(logits)

这里用最简单的 QKTQK^T(完整版本还要除以 dk\sqrt{d_k}):

attn_logits = queries @ keys.T
print(attn_logits)

9.5 softmax 得到权重

attn_weights = softmax(attn_logits, dim=-1)
print(attn_weights)

为了更直观看清“加权求和”,我们也可以做一个近似版本(教学用):

attn_weights_simple = torch.tensor(
  [
    [0.0, 0.5, 0.5],
    [0.0, 1.0, 0.0],
    [0.0, 0.9, 0.1],
  ],
  dtype=torch.float32,
)

9.6 加权求和得到输出

标准矩阵写法是:

output = attn_weights @ values
print(output)

如果你想看“每个 value 被乘了多少”,可以像下面这样拆开(便于教学观察):

weighted_values = values[:, None] * attn_weights_simple.T[:, :, None]
print(weighted_values)
print("sum over tokens ->", weighted_values.sum(dim=0))

10. 优缺点与常见坑

优点

  • 效果强:尤其在大数据与大模型规模下。
  • 并行友好:训练吞吐高。
  • 长距离依赖更容易学到。

常见坑

  • mask 忘了加:padding token 参与注意力会污染表示;decoder 不加 causal mask 会“偷看答案”。
  • shape 搞混:批次维、头数维、序列长度维容易写错。
  • softmax 维度写错:通常要对最后一维(key 维/序列维)做 softmax。

11. 小练习

  1. 把第 9 节的示例改成带缩放:将 attn_logits 替换为 attn_logits / (dk ** 0.5),其中 dk = keys.size(-1)
  2. 写一个 causal mask(上三角),把未来位置 logits 置为一个很小的负数(如 -1e9),观察输出变化。
  3. values 改大一倍,看看输出是否也线性变大(应该会)。

12. 延伸阅读

  • 《Attention Is All You Need》:Transformer 原论文(arXiv)
  • The Illustrated Transformer(图解 Transformer,直觉非常好)
  • PyTorch 官方 torch.nn.MultiheadAttention 文档与源码(理解工程实现细节)