我们在第一课讲述了Transformer的基本原理,这一课,我们用代码实现注意力机制。
注意力机制代码
nn.Module 是PyTorch模型的一个基本构建块,为模型层的创建和管理提供必要的功能。
注意力的本质是全局信息加权和。记得在第一课,我们用Q(query), K(key), V(value)来构建注意力打分。
自注意力机制
下面代码来自[1]: 一个非常简化的自注意类。
为什么叫做自注意力机制? 因为这里Q, K, V 都来自同一个输入x.
import torch.nn as nn
class SelfAttention_v1(nn.Module):
def __init__(self, d_in, d_out):
super().__init__()
self.W_query = nn.Parameter(torch.rand(d_in, d_out))
self.W_key = nn.Parameter(torch.rand(d_in, d_out))
self.W_value = nn.Parameter(torch.rand(d_in, d_out))
def forward(self, x):
keys = x @ self.W_key
queries = x @ self.W_query
values = x @ self.W_value
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
context_vec = attn_weights @ values
return context_vec
## 使用方法如下:
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))
在介绍因果注意力机制之前,我们优化这个简单的版本。
class SelfAttention_v2(nn.Module):
def __init__(self, d_in, d_out, qkv_bias=False):
super().__init__()
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
def forward(self, x):
keys = self.W_key(x)
queries = self.W_query(x)
values = self.W_value(x)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
context_vec = attn_weights @ values
return context_vec
v2 更好在:参数初始化更合理 + 支持 bias + 更通用的输入维度 + 更工程化易扩展。
因果注意力机制
因果注意力 = 只能看当前和过去,不能看未来。
由于掩码遮蔽了未来,我们看到tensor看起来这个样子。
✓ 0 0 0
✓ ✓ 0 0
✓ ✓ ✓ 0
✓ ✓ ✓ ✓
下面,我们来实现一个简单的因果注意力机制。
class CausalAttention(nn.Module):
def __init__(self, d_in, d_out, context_length,
dropout, qkv_bias=False):
super().__init__()
# 输出特征维度(即 d_k)
self.d_out = d_out
# 线性映射:X -> Q, K, V
# 对应公式:
# Q = XW_Q
# K = XW_K
# V = XW_V
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
# 注意力权重的 dropout(Transformer 标准做法)
self.dropout = nn.Dropout(dropout)
# 构造因果 Mask(上三角矩阵,不含对角线)
# 形状: (context_length, context_length)
# 作用:
# mask[i, j] = 1 表示 j > i (未来位置)
# mask[i, j] = 0 表示 j <= i(当前或过去)
#
# 后续会把 mask==1 的位置填成 -inf,
# 从而保证 softmax 后这些位置权重为 0
self.register_buffer(
'mask',
torch.triu(torch.ones(context_length, context_length), diagonal=1)
)
def forward(self, x):
"""
x 形状: (batch_size, num_tokens, d_in)
"""
b, num_tokens, d_in = x.shape # b: batch_size, T: 序列长度
# ========= 第一步:线性映射得到 Q, K, V =========
# 形状: (b, T, d_out)
keys = self.W_key(x)
queries = self.W_query(x)
values = self.W_value(x)
# ========= 第二步:计算注意力分数 QK^T =========
# keys.transpose(1, 2) 变为 (b, d_out, T)
# 乘完后 attn_scores 形状: (b, T, T)
#
# 对应公式:
# S = QK^T
# 注意这里不直接用keys.T是因为droput在矩阵里面。我们只需要转置前面两个。
attn_scores = queries @ keys.transpose(1, 2)
# ========= 第三步:加入因果 Mask =========
# 将未来位置 (j > i) 填为 -inf
# softmax 后这些位置权重会变成 0
#
# 对应公式:
# S' = S + M
attn_scores.masked_fill_(
self.mask.bool()[:num_tokens, :num_tokens],
-torch.inf
)
# ========= 第四步:缩放 + softmax =========
# 除以 sqrt(d_k) 防止数值过大
#
# 对应公式:
# A = softmax(S' / sqrt(d_k))
attn_weights = torch.softmax(
attn_scores / (keys.shape[-1] ** 0.5),
dim=-1
)
# 对注意力权重做 dropout
attn_weights = self.dropout(attn_weights)
# ========= 第五步:加权求和 =========
# Z = A V
# 输出形状: (b, T, d_out)
context_vec = attn_weights @ values
return context_vec
关于masked_fill这个函数,我们简单说明下。
x = torch.tensor([[1., 2.],
[3., 4.]])
mask = torch.tensor([[False, True],
[False, False]])
x.masked_fill_(mask, -1)
print(x)
## 输出
tensor([[ 1., -1.],
[ 3., 4.]])
发生了什么?
- mask 为 True 的位置是
(0,1) - 那个位置被改成
-1 - 其他位置不变
self.mask.bool()[:num_tokens, :num_tokens], 按顺序解释:
self.mask→ 取 tensor.bool()→ 转成布尔 tensor[:num_tokens, :num_tokens]→ 取左上角 num_tokens × num_tokens 子矩阵
多头注意力机制
Multi-Head Attention 数学表达式
设输入为:
头数为:
每个头的维度为:
1. 线性投影
对于每个 head :
其中:
2. 每个 head 的注意力
3. 拼接所有 head
4. 输出投影
其中:
最终整体表达式
import torch
import torch.nn as nn
class MultiHeadAttention(nn.Module):
def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
super().__init__()
assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"
self.d_out = d_out
self.num_heads = num_heads
# 每个 head 的维度
# 多头的核心思想:把总维度 d_out 拆成 num_heads 个子空间
self.head_dim = d_out // num_heads
# ==========================
# 为什么需要 projection?
# ==========================
# 如果直接用 X 做 attention:
# Q = K = V = X
# 那么所有 head 都在同一个特征空间里做注意力,
# 表达能力会受到限制。
#
# 加入可学习的线性映射后:
# Q = X W_Q
# K = X W_K
# V = X W_V
#
# 模型可以学习:
# - 哪些特征用于“查询”
# - 哪些特征用于“匹配”
# - 哪些特征用于“输出”
#
# 多头的关键:
# 每个 head 使用不同的 W_Q, W_K, W_V,
# 相当于在不同的线性子空间中做 attention。
#
# 这就是 multi-head 比 single-head 强的根本原因。
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
# 输出投影层
# 作用:
# 把多个 head 的结果拼接后再做一次线性变换,
# 让不同 head 之间的信息进行融合。
#
# 如果没有这一层,各个 head 是彼此独立的。
self.out_proj = nn.Linear(d_out, d_out)
self.dropout = nn.Dropout(dropout)
# 因果 mask(上三角)
self.register_buffer(
"mask",
torch.triu(torch.ones(context_length, context_length), diagonal=1)
)
def forward(self, x):
b, num_tokens, d_in = x.shape
# ==========================
# 1. 线性投影得到 Q, K, V
# ==========================
# 形状: (b, T, d_out)
keys = self.W_key(x)
queries = self.W_query(x)
values = self.W_value(x)
# ==========================
# 2. 拆成多个 head
# ==========================
# (b, T, d_out)
# -> (b, T, num_heads, head_dim)
#
# 本质:把一个大向量拆成多个小向量,
# 每个 head 在自己的子空间中做注意力。
keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
values = values.view(b, num_tokens, self.num_heads, self.head_dim)
# ==========================
# 3. 调整维度顺序
# ==========================
# (b, T, h, d)
# -> (b, h, T, d)
#
# 这样可以对每个 head 并行计算注意力
keys = keys.transpose(1, 2)
queries = queries.transpose(1, 2)
values = values.transpose(1, 2)
# ==========================
# 4. 每个 head 计算注意力
# ==========================
# (b, h, T, d) @ (b, h, d, T)
# -> (b, h, T, T)
attn_scores = queries @ keys.transpose(2, 3)
# 因果 mask(未来位置置为 -inf)
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
attn_scores.masked_fill_(mask_bool, -torch.inf)
# 缩放 + softmax
attn_weights = torch.softmax(
attn_scores / (self.head_dim ** 0.5),
dim=-1
)
attn_weights = self.dropout(attn_weights)
# ==========================
# 5. 加权求和
# ==========================
# (b, h, T, T) @ (b, h, T, d)
# -> (b, h, T, d)
context = attn_weights @ values
# ==========================
# 6. 合并多个 head
# ==========================
# (b, h, T, d)
# -> (b, T, h, d)
context = context.transpose(1, 2)
# -> (b, T, d_out)
context = context.contiguous().view(b, num_tokens, self.d_out)
# ==========================
# 7. 输出投影
# ==========================
# 融合多个 head 的信息
context = self.out_proj(context)
return context
参考:
[1].《从零构建大模型》