从零训练大模型之模型升级版搭建及训练(中):实现FlashAttention, GQA,RoPE, RMSNorm, SwiGLU

179 阅读10分钟

前言

我们在《从零训练大模型之模型搭建》这篇文章中,按照《Attention Is All You Need》的内容进行了模型实现,也训练出了一个pre-train模型。也在上一篇文章《从零训练大模型之模型升级版搭建及训练(上)》文章中,我们指出了Llama、Mistral 这些强大的开源大模型使用的新技术,但是该如何编写代码呢?

别担心,这篇文章就是为升级我们之前的模型而准备的!我们将带领大家,亲手将一个“教科书式”的 Transformer Decoder-only 模型,一步步升级为集成了 FlashAttention、GQA、RoPE、SwiGLU、RMSNorm 等前沿技术的“准现代”LLM 架构。

我们的起点:一个“经典但过时”的Transformer

在开始升级之前,让我们先看看我们的“旧装备”——一个根据《Attention Is All You Need》论文内容用PyTorch标准API构建的Decoder-only模型。它包含了以下几个经典组件:

  • nn.Embedding + Sinusoidal Positional Encoding:标准的词嵌入加上绝对位置编码,通过给每个位置一个固定的“坐标”来注入位置信息。
  • nn.MultiheadAttention:PyTorch 官方实现的多头自注意力,is_causal=True 确保了它只能看到过去的信息。
  • nn.LayerNorm + nn.ReLU Feed-Forward:标准的层归一化和基于 ReLU 的前馈网络。
  • Pre-Norm 架构:先进行归一化,再送入子层(注意力或前馈网络),有助于训练稳定。

这个模型是我们这段时间学习Transformer的起点和成果,但它在效率和性能上与Llama等现代模型存在代差。为什么呢?

  • 计算/内存瓶颈:标准注意力的计算和内存复杂度是序列长度的平方(O(N²)),在处理长文本时会迅速成为瓶颈。
  • 位置编码的局限:绝对位置编码在处理超过训练长度的序列时,泛化能力较差。
  • 组件效率:LayerNorm 和 ReLU 虽然经典,但已有计算更简单、效果可能更好的替代品。

我们的目标,就是用现代化的组件替换这些经典部件,让模型“鸟枪换炮”。

逐个击破:五大核心技术升级

现在,让我们开始对模型的五大核心部分进行现代化改造。

1. RMSNorm:更快、更简单的归一化

为什么需要它?

LayerNorm 是 Transformer 的标配,但它的计算涉及减去均值和除以标准差,略显复杂。RMSNorm (Root Mean Square Normalization) 提出了一种简化方案:只通过均方根对输入进行缩放,去掉了减去均值的步骤。

  • 公式对比:
    • LayerNorm: y=xE[x]Var[x]+ϵγ+βy = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta
    • RMSNorm: y=xE[x2]+ϵγy = \frac{x}{\sqrt{\mathrm{E}[x^2] + \epsilon}} \cdot \gamma

优势: 计算量更小,实验证明在同样效果下速度更快。Llama 系列模型就全面采用了 RMSNorm。

代码实现如下:

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # gamma 参数
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # 计算均方根 (rsqrt是平方根倒数,更高效)
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

我们用这个 RMSNorm 类替换掉模型中所有的 nn.LayerNorm。

2. RoPE:更优雅的旋转式位置编码

为什么需要它?

传统的正弦位置编码是“绝对”的,它给每个 token 一个固定的位置ID。而 RoPE (Rotary Positional Embedding) 是一种“相对”位置编码。它的核心思想非常精妙:位置信息不应该通过“加法”注入,而应该通过“旋转”注入。

想象一下,Query 和 Key 向量是二维平面上的点。RoPE 根据它们的位置 m 和 n,将它们分别旋转一个角度。当计算它们的点积(注意力分数)时,结果只与它们的相对位置 m-n 有关,而与绝对位置无关。这使得模型能更好地理解词与词之间的相对距离,并且在处理超长序列时具有更好的泛化性。

代码实现:

RoPE 的实现分为两步:预计算旋转角度和应用旋转。

# 1. 预计算旋转矩阵(用 cos 和 sin 表示)
class RotaryEmbedding(nn.Module):
    def __init__(self, dim: int, max_seq_len: int, base: int = 10000, device: Optional[torch.device] = None):
        super().__init__()
        # 计算不同维度的旋转频率
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # 预计算所有位置的 cos 和 sin 值
        t = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        
        # 缓存起来,避免重复计算
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)

    def forward(self, x: torch.Tensor):
        seq_len = x.shape[-2]
        return self.cos_cached[:, :, :seq_len, ...], self.sin_cached[:, :, :seq_len, ...]

# 2. 应用旋转的辅助函数
def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_emb(q, k, cos, sin):
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

在模型的前向传播中,我们会先生成 cos 和 sin 值,然后传递给每个注意力层,应用到 Query 和 Key 上。

3. SwiGLU:更智能的前馈网络

为什么需要它?

标准的前馈网络(FFN)通常是 Linear -> ReLU -> Linear。SwiGLU 是一种改进的 FFN 结构,它引入了门控机制。

  • 公式:SwiGLU(x, W, V, W2) = (Swish(x @ W) * (x @ V)) @ W2
    • Swish(x) = x * sigmoid(x),在 PyTorch 中是 F.silu。
    • x @ W 的结果通过 Swish 激活,而 x @ V 的结果充当一个“门”,决定了前一部分信息有多少可以通过。

优势: 这种门控机制让网络可以动态地控制信息流,实验表明它能带来显著的性能提升。PaLM、Llama 和 Mistral 等模型都采用了类似 SwiGLU 的结构。

代码实现:

class SwiGLUFeedForward(nn.Module):
    def __init__(self, embed_dim: int, ff_dim: int, dropout: float = 0.0):
        super().__init__()
        self.w1 = nn.Linear(embed_dim, ff_dim, bias=False) # 门控
        self.w2 = nn.Linear(ff_dim, embed_dim, bias=False) # 输出
        self.w3 = nn.Linear(embed_dim, ff_dim, bias=False) # up projection
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor):
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))

4. GQA & FlashAttention:为注意力机制装上核动力引擎

这是我们本次升级中最核心、最关键的一步!

GQA (Grouped-Query Attention)

  • 为什么需要它?
    • 多头注意力(MHA) 中,每个 Query 头都有自己独立的 Key 和 Value 头。
    • 在推理时,模型需要缓存所有 token 的 Key 和 Value(即 KV Cache)来加速生成。对于 MHA,KV Cache 的大小与头的数量成正比,非常消耗显存。
    • 多查询注意力(MQA) 提出让所有的 Query 头共享同一组 Key 和 Value 头,极大减少了 KV Cache,但可能导致性能下降。
    • 分组查询注意力(GQA) 是两者的完美折中:它将 Query 头分组,组内的 Query 头共享同一组 Key 和 Value 头。
    • 例如,如果有 8 个 Query 头和 2 个 KV 头,那么每 4 个 Query 头会共享一组 KV。这既大幅减少了 KV Cache,又保持了足够高的模型质量。Mistral 7B 就是 GQA 的著名应用案例。

FlashAttention

  • 为什么需要它?
    • 标准注意力的最大问题是需要计算并存储一个 (seq_len, seq_len) 大小的注意力矩阵。当序列很长时(如 8k、16k),这个矩阵会变得异常巨大,远远超出 GPU SRAM(片上高速缓存)的容量,导致频繁地与 HBM(高带宽显存)进行数据读写,而这个过程非常缓慢,成为性能瓶颈。
    • FlashAttention 是一种 I/O 感知的注意力算法。它巧妙地将计算过程分块(Tiling),在 GPU 的 SRAM 中完成每个小块的注意力计算,而无需将整个巨大的注意力矩阵写入 HBM。
    • 结果:在不改变数学计算结果的前提下,实现了数量级的加速和显存节省。可以说,没有 FlashAttention,就没有今天的大模型长文本时代。

代码实现:

我们将 GQA 和 FlashAttention 结合在一个模块中。

# 引入 flash_attn
try:
    from flash_attn import flash_attn_func
    FLASH_ATTENTION_AVAILABLE = True
except ImportError:
    FLASH_ATTENTION_AVAILABLE = False
    from torch.nn.functional import scaled_dot_product_attention

class FlashGQA(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, num_kv_heads: int, dropout: float = 0.0):
        super().__init__()
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = embed_dim // num_heads
        self.num_q_per_kv = num_heads // num_kv_heads

        self.wq = nn.Linear(embed_dim, embed_dim, bias=False)
        self.wk = nn.Linear(embed_dim, self.num_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(embed_dim, self.num_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(embed_dim, embed_dim, bias=False)
        self.dropout = dropout

    def forward(self, x, freqs_cos, freqs_sin, key_padding_mask=None):
        bsz, seq_len, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        # Reshape to (bs, seq_len, num_heads, head_dim)
        xq = xq.view(bsz, seq_len, self.num_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.num_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.num_kv_heads, self.head_dim)

        # 应用 RoPE
        xq = xq.transpose(1, 2)
        xk = xk.transpose(1, 2)
        xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
        xq = xq.transpose(1, 2)
        xk = xk.transpose(1, 2)

        # FlashAttention
        if FLASH_ATTENTION_AVAILABLE:
            # GQA: 重复 K/V 头以匹配 Q 头
            keys_repeated = xk.repeat_interleave(self.num_q_per_kv, dim=2)
            values_repeated = xv.repeat_interleave(self.num_q_per_kv, dim=2)
            attn_output = flash_attn_func(
                q=xq, k=keys_repeated, v=values_repeated, 
                dropout_p=self.dropout if self.training else 0.0, causal=True
            )
        else: # Fallback to PyTorch's implementation
            # ... (备用方案代码,详见完整代码)
            # 同样需要重复K/V头
            xq = xq.transpose(1, 2)
            xk = xk.transpose(1, 2).repeat_interleave(self.num_q_per_kv, dim=1)
            xv = xv.transpose(1, 2).repeat_interleave(self.num_q_per_kv, dim=1)
            attn_output = scaled_dot_product_attention(
                xq, xk, xv, is_causal=True, 
                dropout_p=self.dropout if self.training else 0.0
            ).transpose(1, 2)

        attn_output = attn_output.view(bsz, seq_len, -1)
        return self.wo(attn_output)

终极合体:组装我们的现代化 LLM

现在,我们已经打造好了所有先进的组件。是时候把它们组装起来,构成我们最终的、现代化的 DecoderOnlyTransformer 了!

升级后的 DecoderBlock:

class DecoderBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, num_kv_heads, ff_dim, dropout):
        super().__init__()
        # 使用新组件
        self.self_attention = FlashGQA(embed_dim, num_heads, num_kv_heads, dropout)
        self.feed_forward = SwiGLUFeedForward(embed_dim, ff_dim, dropout)
        self.norm1 = RMSNorm(embed_dim)
        self.norm2 = RMSNorm(embed_dim)

    def forward(self, x, freqs_cos, freqs_sin, key_padding_mask=None):
        # 结构不变,但内部调用已全部升级
        residual = x
        h = self.norm1(x)
        h = self.self_attention(h, freqs_cos, freqs_sin, key_padding_mask)
        x = residual + h

        residual = x
        h = self.norm2(x)
        h = self.feed_forward(h)
        x = residual + h
        
        return x

最终的 DecoderOnlyTransformer:

主模型的变化在于:

  1. 初始化:不再需要 PositionalEncoding,取而代之的是 RotaryEmbedding。
  2. forward 流程:在进入 DecoderBlock 循环前,先计算好 RoPE 需要的 freqs_cos 和 freqs_sin,然后将它们传递给每一层。
  3. 最终输出:在输出到 output_layer 之前,增加一个 RMSNorm,这是 Llama 等模型的常见做法。

使用指南

将如上内容替换到MiniLlmsModel文件中,即可完成升级。若想获取完整的模型代码,可以参考附录的GitHub连接。

如何使用

在我们的训练脚本(如 train_pretrain.py)中,我们只需要做两件小事:

  1. 从 MiniLlmsModel_optimized 导入 DecoderOnlyTransformer。
  2. 在实例化模型时,传入新的 num_kv_heads 参数。
from MiniLlmsModel_optimized import DecoderOnlyTransformer

# 定义 GQA 参数
NUM_KV_HEADS = 4 # 例如,8个Query头,4个KV头
FF_DIM = int(2/3 * 4 * EMBED_DIM) # SwiGLU 推荐维度

model = DecoderOnlyTransformer(
    # ... 其他参数 ...
    num_heads=NUM_HEADS,
    num_kv_heads=NUM_KV_HEADS, # 传入新参数
    # ...
)

就是这么简单!我们的训练流程现在已经由一个现代化的高效模型驱动了。

总结与展望

Congratulations!通过这次旅程,我们不仅亲手实现了一个融合了多种 SOTA 技术的大模型架构,更重要的是,我们理解了这些技术背后的动机和原理。

让我们回顾一下我们的成就:

  • 用 RMSNorm 替换 LayerNorm,为计算减负。
  • 用 RoPE 替换绝对位置编码,赋予模型更好的长文本理解能力。
  • 用 SwiGLU 替换标准 FFN,让信息流动更智能。
  • 用 GQA 和 FlashAttention 彻底重构了注意力层,打破了性能和内存的双重瓶颈。

我们现在手中的代码,虽然参数量不大,但其架构已经与 Llama、Mistral 等业界顶尖模型看齐。以此为基础,我们可以进一步探索量化、模型并行、更复杂的训练策略等更高级的主题。

大模型的学习之路漫长而有趣。希望这篇文章能成为各位攀登这座高峰时,一个坚实的垫脚石。


关注我的公众号不走丢

附录

GitHub链接:github.com/JimmysAIPG/…