前言
我们在《从零训练大模型之模型搭建》这篇文章中,按照《Attention Is All You Need》的内容进行了模型实现,也训练出了一个pre-train模型。也在上一篇文章《从零训练大模型之模型升级版搭建及训练(上)》文章中,我们指出了Llama、Mistral 这些强大的开源大模型使用的新技术,但是该如何编写代码呢?
别担心,这篇文章就是为升级我们之前的模型而准备的!我们将带领大家,亲手将一个“教科书式”的 Transformer Decoder-only 模型,一步步升级为集成了 FlashAttention、GQA、RoPE、SwiGLU、RMSNorm 等前沿技术的“准现代”LLM 架构。
我们的起点:一个“经典但过时”的Transformer
在开始升级之前,让我们先看看我们的“旧装备”——一个根据《Attention Is All You Need》论文内容用PyTorch标准API构建的Decoder-only模型。它包含了以下几个经典组件:
- nn.Embedding + Sinusoidal Positional Encoding:标准的词嵌入加上绝对位置编码,通过给每个位置一个固定的“坐标”来注入位置信息。
- nn.MultiheadAttention:PyTorch 官方实现的多头自注意力,is_causal=True 确保了它只能看到过去的信息。
- nn.LayerNorm + nn.ReLU Feed-Forward:标准的层归一化和基于 ReLU 的前馈网络。
- Pre-Norm 架构:先进行归一化,再送入子层(注意力或前馈网络),有助于训练稳定。
这个模型是我们这段时间学习Transformer的起点和成果,但它在效率和性能上与Llama等现代模型存在代差。为什么呢?
- 计算/内存瓶颈:标准注意力的计算和内存复杂度是序列长度的平方(O(N²)),在处理长文本时会迅速成为瓶颈。
- 位置编码的局限:绝对位置编码在处理超过训练长度的序列时,泛化能力较差。
- 组件效率:LayerNorm 和 ReLU 虽然经典,但已有计算更简单、效果可能更好的替代品。
我们的目标,就是用现代化的组件替换这些经典部件,让模型“鸟枪换炮”。
逐个击破:五大核心技术升级
现在,让我们开始对模型的五大核心部分进行现代化改造。
1. RMSNorm:更快、更简单的归一化
为什么需要它?
LayerNorm 是 Transformer 的标配,但它的计算涉及减去均值和除以标准差,略显复杂。RMSNorm (Root Mean Square Normalization) 提出了一种简化方案:只通过均方根对输入进行缩放,去掉了减去均值的步骤。
- 公式对比:
- LayerNorm:
- RMSNorm:
优势: 计算量更小,实验证明在同样效果下速度更快。Llama 系列模型就全面采用了 RMSNorm。
代码实现如下:
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
# gamma 参数
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
# 计算均方根 (rsqrt是平方根倒数,更高效)
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
我们用这个 RMSNorm 类替换掉模型中所有的 nn.LayerNorm。
2. RoPE:更优雅的旋转式位置编码
为什么需要它?
传统的正弦位置编码是“绝对”的,它给每个 token 一个固定的位置ID。而 RoPE (Rotary Positional Embedding) 是一种“相对”位置编码。它的核心思想非常精妙:位置信息不应该通过“加法”注入,而应该通过“旋转”注入。
想象一下,Query 和 Key 向量是二维平面上的点。RoPE 根据它们的位置 m 和 n,将它们分别旋转一个角度。当计算它们的点积(注意力分数)时,结果只与它们的相对位置 m-n 有关,而与绝对位置无关。这使得模型能更好地理解词与词之间的相对距离,并且在处理超长序列时具有更好的泛化性。
代码实现:
RoPE 的实现分为两步:预计算旋转角度和应用旋转。
# 1. 预计算旋转矩阵(用 cos 和 sin 表示)
class RotaryEmbedding(nn.Module):
def __init__(self, dim: int, max_seq_len: int, base: int = 10000, device: Optional[torch.device] = None):
super().__init__()
# 计算不同维度的旋转频率
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
self.register_buffer("inv_freq", inv_freq)
# 预计算所有位置的 cos 和 sin 值
t = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
# 缓存起来,避免重复计算
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
def forward(self, x: torch.Tensor):
seq_len = x.shape[-2]
return self.cos_cached[:, :, :seq_len, ...], self.sin_cached[:, :, :seq_len, ...]
# 2. 应用旋转的辅助函数
def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_emb(q, k, cos, sin):
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
在模型的前向传播中,我们会先生成 cos 和 sin 值,然后传递给每个注意力层,应用到 Query 和 Key 上。
3. SwiGLU:更智能的前馈网络
为什么需要它?
标准的前馈网络(FFN)通常是 Linear -> ReLU -> Linear。SwiGLU 是一种改进的 FFN 结构,它引入了门控机制。
- 公式:SwiGLU(x, W, V, W2) = (Swish(x @ W) * (x @ V)) @ W2
- Swish(x) = x * sigmoid(x),在 PyTorch 中是 F.silu。
- x @ W 的结果通过 Swish 激活,而 x @ V 的结果充当一个“门”,决定了前一部分信息有多少可以通过。
优势: 这种门控机制让网络可以动态地控制信息流,实验表明它能带来显著的性能提升。PaLM、Llama 和 Mistral 等模型都采用了类似 SwiGLU 的结构。
代码实现:
class SwiGLUFeedForward(nn.Module):
def __init__(self, embed_dim: int, ff_dim: int, dropout: float = 0.0):
super().__init__()
self.w1 = nn.Linear(embed_dim, ff_dim, bias=False) # 门控
self.w2 = nn.Linear(ff_dim, embed_dim, bias=False) # 输出
self.w3 = nn.Linear(embed_dim, ff_dim, bias=False) # up projection
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor):
return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
4. GQA & FlashAttention:为注意力机制装上核动力引擎
这是我们本次升级中最核心、最关键的一步!
GQA (Grouped-Query Attention)
- 为什么需要它?
- 在多头注意力(MHA) 中,每个 Query 头都有自己独立的 Key 和 Value 头。
- 在推理时,模型需要缓存所有 token 的 Key 和 Value(即 KV Cache)来加速生成。对于 MHA,KV Cache 的大小与头的数量成正比,非常消耗显存。
- 多查询注意力(MQA) 提出让所有的 Query 头共享同一组 Key 和 Value 头,极大减少了 KV Cache,但可能导致性能下降。
- 分组查询注意力(GQA) 是两者的完美折中:它将 Query 头分组,组内的 Query 头共享同一组 Key 和 Value 头。
- 例如,如果有 8 个 Query 头和 2 个 KV 头,那么每 4 个 Query 头会共享一组 KV。这既大幅减少了 KV Cache,又保持了足够高的模型质量。Mistral 7B 就是 GQA 的著名应用案例。
FlashAttention
- 为什么需要它?
- 标准注意力的最大问题是需要计算并存储一个 (seq_len, seq_len) 大小的注意力矩阵。当序列很长时(如 8k、16k),这个矩阵会变得异常巨大,远远超出 GPU SRAM(片上高速缓存)的容量,导致频繁地与 HBM(高带宽显存)进行数据读写,而这个过程非常缓慢,成为性能瓶颈。
- FlashAttention 是一种 I/O 感知的注意力算法。它巧妙地将计算过程分块(Tiling),在 GPU 的 SRAM 中完成每个小块的注意力计算,而无需将整个巨大的注意力矩阵写入 HBM。
- 结果:在不改变数学计算结果的前提下,实现了数量级的加速和显存节省。可以说,没有 FlashAttention,就没有今天的大模型长文本时代。
代码实现:
我们将 GQA 和 FlashAttention 结合在一个模块中。
# 引入 flash_attn
try:
from flash_attn import flash_attn_func
FLASH_ATTENTION_AVAILABLE = True
except ImportError:
FLASH_ATTENTION_AVAILABLE = False
from torch.nn.functional import scaled_dot_product_attention
class FlashGQA(nn.Module):
def __init__(self, embed_dim: int, num_heads: int, num_kv_heads: int, dropout: float = 0.0):
super().__init__()
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.head_dim = embed_dim // num_heads
self.num_q_per_kv = num_heads // num_kv_heads
self.wq = nn.Linear(embed_dim, embed_dim, bias=False)
self.wk = nn.Linear(embed_dim, self.num_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(embed_dim, self.num_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(embed_dim, embed_dim, bias=False)
self.dropout = dropout
def forward(self, x, freqs_cos, freqs_sin, key_padding_mask=None):
bsz, seq_len, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
# Reshape to (bs, seq_len, num_heads, head_dim)
xq = xq.view(bsz, seq_len, self.num_heads, self.head_dim)
xk = xk.view(bsz, seq_len, self.num_kv_heads, self.head_dim)
xv = xv.view(bsz, seq_len, self.num_kv_heads, self.head_dim)
# 应用 RoPE
xq = xq.transpose(1, 2)
xk = xk.transpose(1, 2)
xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
xq = xq.transpose(1, 2)
xk = xk.transpose(1, 2)
# FlashAttention
if FLASH_ATTENTION_AVAILABLE:
# GQA: 重复 K/V 头以匹配 Q 头
keys_repeated = xk.repeat_interleave(self.num_q_per_kv, dim=2)
values_repeated = xv.repeat_interleave(self.num_q_per_kv, dim=2)
attn_output = flash_attn_func(
q=xq, k=keys_repeated, v=values_repeated,
dropout_p=self.dropout if self.training else 0.0, causal=True
)
else: # Fallback to PyTorch's implementation
# ... (备用方案代码,详见完整代码)
# 同样需要重复K/V头
xq = xq.transpose(1, 2)
xk = xk.transpose(1, 2).repeat_interleave(self.num_q_per_kv, dim=1)
xv = xv.transpose(1, 2).repeat_interleave(self.num_q_per_kv, dim=1)
attn_output = scaled_dot_product_attention(
xq, xk, xv, is_causal=True,
dropout_p=self.dropout if self.training else 0.0
).transpose(1, 2)
attn_output = attn_output.view(bsz, seq_len, -1)
return self.wo(attn_output)
终极合体:组装我们的现代化 LLM
现在,我们已经打造好了所有先进的组件。是时候把它们组装起来,构成我们最终的、现代化的 DecoderOnlyTransformer 了!
升级后的 DecoderBlock:
class DecoderBlock(nn.Module):
def __init__(self, embed_dim, num_heads, num_kv_heads, ff_dim, dropout):
super().__init__()
# 使用新组件
self.self_attention = FlashGQA(embed_dim, num_heads, num_kv_heads, dropout)
self.feed_forward = SwiGLUFeedForward(embed_dim, ff_dim, dropout)
self.norm1 = RMSNorm(embed_dim)
self.norm2 = RMSNorm(embed_dim)
def forward(self, x, freqs_cos, freqs_sin, key_padding_mask=None):
# 结构不变,但内部调用已全部升级
residual = x
h = self.norm1(x)
h = self.self_attention(h, freqs_cos, freqs_sin, key_padding_mask)
x = residual + h
residual = x
h = self.norm2(x)
h = self.feed_forward(h)
x = residual + h
return x
最终的 DecoderOnlyTransformer:
主模型的变化在于:
- 初始化:不再需要 PositionalEncoding,取而代之的是 RotaryEmbedding。
- forward 流程:在进入 DecoderBlock 循环前,先计算好 RoPE 需要的 freqs_cos 和 freqs_sin,然后将它们传递给每一层。
- 最终输出:在输出到 output_layer 之前,增加一个 RMSNorm,这是 Llama 等模型的常见做法。
使用指南
将如上内容替换到MiniLlmsModel文件中,即可完成升级。若想获取完整的模型代码,可以参考附录的GitHub连接。
如何使用
在我们的训练脚本(如 train_pretrain.py)中,我们只需要做两件小事:
- 从 MiniLlmsModel_optimized 导入 DecoderOnlyTransformer。
- 在实例化模型时,传入新的 num_kv_heads 参数。
from MiniLlmsModel_optimized import DecoderOnlyTransformer
# 定义 GQA 参数
NUM_KV_HEADS = 4 # 例如,8个Query头,4个KV头
FF_DIM = int(2/3 * 4 * EMBED_DIM) # SwiGLU 推荐维度
model = DecoderOnlyTransformer(
# ... 其他参数 ...
num_heads=NUM_HEADS,
num_kv_heads=NUM_KV_HEADS, # 传入新参数
# ...
)
就是这么简单!我们的训练流程现在已经由一个现代化的高效模型驱动了。
总结与展望
Congratulations!通过这次旅程,我们不仅亲手实现了一个融合了多种 SOTA 技术的大模型架构,更重要的是,我们理解了这些技术背后的动机和原理。
让我们回顾一下我们的成就:
- 用 RMSNorm 替换 LayerNorm,为计算减负。
- 用 RoPE 替换绝对位置编码,赋予模型更好的长文本理解能力。
- 用 SwiGLU 替换标准 FFN,让信息流动更智能。
- 用 GQA 和 FlashAttention 彻底重构了注意力层,打破了性能和内存的双重瓶颈。
我们现在手中的代码,虽然参数量不大,但其架构已经与 Llama、Mistral 等业界顶尖模型看齐。以此为基础,我们可以进一步探索量化、模型并行、更复杂的训练策略等更高级的主题。
大模型的学习之路漫长而有趣。希望这篇文章能成为各位攀登这座高峰时,一个坚实的垫脚石。
关注我的公众号不走丢
附录
GitHub链接:github.com/JimmysAIPG/…