揭秘GPT-4与LLaMA背后的加速黑科技:KV Cache、MQA、GQA、稀疏注意力与MoE全解析

0 阅读4分钟

本文深入讲解现代大语言模型的核心优化技术,包括KV Cache自回归加速、Multi-Query Attention(MQA)、Grouped-Query Attention(GQA)、稀疏注意力(Sparse Attention)和混合专家模型(Mixture of Experts, MoE)。通过数学原理、架构对比和PyTorch代码实现,帮助读者理解GPT-4、LLaMA、Mixtral等顶级模型的技术细节,掌握LLM推理加速与显存优化的工程实践。


一、为什么需要优化Transformer?

1.1 原始Transformer的性能瓶颈

graph TB
    subgraph 问题[三大瓶颈]
        P1["🐌 推理速度慢<br/>自回归逐词生成<br/>大量重复计算"]
        P2["💾 显存占用高<br/>KV矩阵随序列长度增长<br/>多头存储冗余"]
        P3["📏 序列长度受限<br/>O(n²)复杂度<br/>长文本处理困难"]
    end
    
    subgraph 解决方案
        S1["✅ KV Cache<br/>缓存已计算的KV"]
        S2["✅ MQA/GQA<br/>共享KV降低显存"]
        S3["✅ Sparse Attention<br/>稀疏注意力模式"]
    end
    
    P1 --> S1
    P2 --> S2
    P3 --> S3
    
    style P1 fill:#ffcdd2
    style P2 fill:#ffccbc
    style P3 fill:#ffab91
    style S1 fill:#a5d6a7
    style S2 fill:#81c784
    style S3 fill:#66bb6a

1.2 现代LLM采用的优化技术

模型KV CacheMQA/GQASparse AttnMoE上下文长度
GPT-32K
LLaMA4K
LLaMA2✅ GQA4K
GPT-4部分推测✅32K/128K
Mixtral 8x7B✅ GQA32K
Claude 3?200K

二、KV Cache:自回归加速的核心技术

2.1 自回归生成的重复计算问题

场景:GPT模型生成"我爱学习AI"

sequenceDiagram
    participant Input
    participant Model
    participant Output
    
    Note over Input,Output: Step 1: 生成"我"
    Input->>Model: [START]
    Model->>Output: "我"
    
    Note over Input,Output: Step 2: 生成"爱"
    Input->>Model: [START, 我]
    Note right of Model: ❌ 重新计算"我"的KV
    Model->>Output: "爱"
    
    Note over Input,Output: Step 3: 生成"学习"
    Input->>Model: [START, 我, 爱]
    Note right of Model: ❌ 重新计算"我""爱"的KV
    Model->>Output: "学习"
    
    Note over Input,Output: Step 4: 生成"AI"
    Input->>Model: [START, 我, 爱, 学习]
    Note right of Model: ❌ 重新计算所有历史KV
    Model->>Output: "AI"

问题分析:

  • 生成第1个词:计算1次KV
  • 生成第2个词:计算2次KV(1次重复)
  • 生成第3个词:计算3次KV(2次重复)
  • 生成第n个词:计算n次KV(n-1次重复)

总计算量: 1+2+3+...+n=n(n+1)2=O(n2)1 + 2 + 3 + ... + n = \frac{n(n+1)}{2} = O(n^2)

2.2 KV Cache的工作原理

核心思想:缓存已经计算过的Key和Value矩阵,新token只需计算自己的KV。

graph TB
    subgraph 无Cache[Without KV Cache]
        S1["Step 1<br/>计算: [START]"]
        S2["Step 2<br/>计算: [START, 我]<br/>❌ 重复计算START"]
        S3["Step 3<br/>计算: [START, 我, 爱]<br/>❌ 重复计算START,我"]
    end
    
    subgraph 有Cache[With KV Cache]
        C1["Step 1<br/>计算&缓存: [START]"]
        C2["Step 2<br/>✅ 读取: [START]<br/>计算&缓存: [我]"]
        C3["Step 3<br/>✅ 读取: [START, 我]<br/>计算&缓存: [爱]"]
    end
    
    S1 --> S2 --> S3
    C1 --> C2 --> C3
    
    style S2 fill:#ffcdd2
    style S3 fill:#ffcdd2
    style C2 fill:#a5d6a7
    style C3 fill:#a5d6a7

加速效果:

  • 无Cache: O(n2)O(n^2) 计算
  • 有Cache: O(n)O(n) 计算
  • 加速比: 生成100个token,加速约50倍!

2.3 KV Cache数学原理

标准Attention:

Attention(Qt,K1:t,V1:t)=softmax(QtK1:tTdk)V1:t\text{Attention}(Q_t, K_{1:t}, V_{1:t}) = \text{softmax}\left(\frac{Q_t K_{1:t}^T}{\sqrt{d_k}}\right)V_{1:t}

在第tt步:

  • QtQ_t: 当前token的Query (新计算)
  • K1:tK_{1:t}: 所有历史token的Key (1到t-1从缓存读取,t新计算)
  • V1:tV_{1:t}: 所有历史token的Value (同上)

缓存更新:

# Pseudo-code
cache_K = []  # 初始化KV缓存
cache_V = []

for t in range(max_len):
    # 1. 计算当前token的KV
    k_t = compute_key(x_t)
    v_t = compute_value(x_t)
    
    # 2. 追加到缓存
    cache_K.append(k_t)
    cache_V.append(v_t)
    
    # 3. 使用全部缓存计算注意力
    q_t = compute_query(x_t)
    attention = softmax(q_t @ cache_K.T) @ cache_V

2.4 PyTorch实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttentionWithCache(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)
    
    def forward(self, x, cache=None, use_cache=False):
        """
        参数:
            x: [batch_size, seq_len, d_model]
            cache: {'key': [batch, n_heads, past_len, d_k],
                   'value': [batch, n_heads, past_len, d_k]}
            use_cache: 是否返回更新后的cache
        """
        batch_size, seq_len, _ = x.size()
        
        # 1. 计算当前输入的QKV
        Q = self.W_Q(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_K(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_V(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        
        # 2. 如果有cache,拼接历史KV
        if cache is not None:
            K = torch.cat([cache['key'], K], dim=2)    # 拼接到seq_len维度
            V = torch.cat([cache['value'], V], dim=2)
        
        # 3. 计算注意力
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)
        
        # 4. 合并多头
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.d_model)
        output = self.W_O(attn_output)
        
        # 5. 更新cache
        if use_cache:
            new_cache = {'key': K, 'value': V}
            return output, new_cache
        return output


# 使用示例:模拟自回归生成
d_model = 512
n_heads = 8
max_len = 10

mha = MultiHeadAttentionWithCache(d_model, n_heads)

# 初始化
cache = None
all_outputs = []

for t in range(max_len):
    # 当前token (实际中是上一步的输出)
    current_token = torch.randn(1, 1, d_model)  # [batch=1, seq_len=1, d_model]
    
    # 前向传播 with cache
    output, cache = mha(current_token, cache=cache, use_cache=True)
    all_outputs.append(output)
    
    print(f"Step {t+1}:")
    print(f"  Cache K shape: {cache['key'].shape}")
    print(f"  Cache V shape: {cache['value'].shape}")

# 输出示例:
# Step 1:
#   Cache K shape: torch.Size([1, 8, 1, 64])
#   Cache V shape: torch.Size([1, 8, 1, 64])
# Step 2:
#   Cache K shape: torch.Size([1, 8, 2, 64])  ← 长度递增
#   Cache V shape: torch.Size([1, 8, 2, 64])
# ...

2.5 KV Cache的显存成本

分析:对于单个样本

KV Cache Size=2×n_layers×n_heads×seq_len×d_k×sizeof(dtype)\text{KV Cache Size} = 2 \times \text{n\_layers} \times \text{n\_heads} \times \text{seq\_len} \times \text{d\_k} \times \text{sizeof(dtype)}

示例:LLaMA2-7B

  • n_layers = 32
  • n_heads = 32
  • seq_len = 4096
  • d_k = 128
  • dtype = float16 (2 bytes)
KV Cache=2×32×32×4096×128×2=2.1GB\text{KV Cache} = 2 \times 32 \times 32 \times 4096 \times 128 \times 2 = 2.1 \text{GB}

单个序列就需要2GB显存! 这就是为什么需要MQA/GQA优化。


三、Multi-Query Attention(MQA):共享KV的激进方案

3.1 MQA的动机

问题:在多头注意力中,每个头都有独立的KV矩阵,造成显存冗余。

graph TB
    subgraph 标准MHA[Multi-Head Attention]
        Q1["Q1"] --> H1["Head 1"]
        K1["K1"] --> H1
        V1["V1"] --> H1
        
        Q2["Q2"] --> H2["Head 2"]
        K2["K2"] --> H2
        V2["V2"] --> H2
        
        Qn["Qn"] --> Hn["Head n"]
        Kn["Kn"] --> Hn
        Vn["Vn"] --> Hn
    end
    
    subgraph MQA[Multi-Query Attention]
        Q1m["Q1"] --> H1m["Head 1"]
        SharedKV["共享 K, V"] --> H1m
        SharedKV --> H2m["Head 2"]
        SharedKV --> Hnm["Head n"]
        Q2m["Q2"] --> H2m
        Qnm["Qn"] --> Hnm
    end
    
    style SharedKV fill:#a5d6a7
    style K1 fill:#ffcdd2
    style K2 fill:#ffcdd2
    style Kn fill:#ffcdd2

核心思想:所有注意力头共享同一组Key和Value,只有Query独立。

3.2 MQA数学公式

标准MHA:

headi=Attention(QWiQ,KWiK,VWiV)\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

MQA:

headi=Attention(QWiQ,KWK,VWV)\text{head}_i = \text{Attention}(QW_i^Q, KW^K, VW^V)

注意:WK,WVW^K, W^V 在所有头之间共享。

3.3 显存节省计算

参数量对比:

配置MHAMQA节省
Q权重h×dmodel×dkh \times d_{model} \times d_kh×dmodel×dkh \times d_{model} \times d_k0
K权重h×dmodel×dkh \times d_{model} \times d_kdmodel×dkd_{model} \times d_k(h1)/h×100%(h-1)/h \times 100\%
V权重h×dmodel×dkh \times d_{model} \times d_kdmodel×dkd_{model} \times d_k(h1)/h×100%(h-1)/h \times 100\%

示例(h=32):

  • MHA KV缓存: 2.1 GB
  • MQA KV缓存: 2.1/32 = 66 MB (节省96.9%!)

3.4 PyTorch实现

class MultiQueryAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        # 每个头独立的Query
        self.W_Q = nn.Linear(d_model, d_model)
        
        # 共享的Key和Value
        self.W_K = nn.Linear(d_model, self.d_k)  # 注意维度!
        self.W_V = nn.Linear(d_model, self.d_k)
        
        self.W_O = nn.Linear(d_model, d_model)
    
    def forward(self, x):
        batch_size, seq_len, _ = x.size()
        
        # 1. 计算多头Query
        Q = self.W_Q(x).view(batch_size, seq_len, self.n_heads, self.d_k)
        Q = Q.transpose(1, 2)  # [batch, n_heads, seq_len, d_k]
        
        # 2. 计算共享的K和V
        K = self.W_K(x)  # [batch, seq_len, d_k]
        V = self.W_V(x)  # [batch, seq_len, d_k]
        
        # 扩展到所有头(通过broadcast)
        K = K.unsqueeze(1)  # [batch, 1, seq_len, d_k]
        V = V.unsqueeze(1)
        
        # 3. 计算注意力
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)
        # [batch, n_heads, seq_len, d_k]
        
        # 4. 合并多头
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.d_model)
        output = self.W_O(attn_output)
        
        return output


# 对比参数量
d_model = 512
n_heads = 8

mha = MultiHeadAttention(d_model, n_heads)
mqa = MultiQueryAttention(d_model, n_heads)

print(f"MHA 参数量: {sum(p.numel() for p in mha.parameters())}")
print(f"MQA 参数量: {sum(p.numel() for p in mqa.parameters())}")
# MHA 参数量: 1,050,624
# MQA 参数量: 820,224 (节省22%)

3.5 MQA的缺点

graph LR
    Pro["✅ 优点"] --> P1["显存占用大幅降低"]
    Pro --> P2["推理速度显著提升"]
    
    Con["❌ 缺点"] --> C1["表达能力下降"]
    Con --> C2["精度略有损失"]
    Con --> C3["多头冗余度太低"]
    
    style Pro fill:#a5d6a7
    style Con fill:#ffcdd2

实验数据(PaLM论文):

  • 推理速度: 提升1.5-2x
  • 模型质量: 下降约3-5%

四、Grouped-Query Attention(GQA):MHA与MQA的平衡

4.1 GQA的设计哲学

核心思想:将多个Query头分组,每组共享一对KV。

graph TB
    subgraph MHA[Multi-Head: h个独立KV]
        MHA_Heads["Head1 Head2 ... Head-h<br/>K1,V1 K2,V2 ... Kh,Vh"]
    end
    
    subgraph GQA[Grouped-Query: g组共享KV]
        GQA_Group1["组1: Head1,2,3,4<br/>共享 K1,V1"]
        GQA_Group2["组2: Head5,6,7,8<br/>共享 K2,V2"]
    end
    
    subgraph MQA[Multi-Query: 1组共享KV]
        MQA_All["所有Head<br/>共享 K,V"]
    end
    
    MHA -.折中方案.-> GQA
    GQA -.极端情况.-> MQA
    
    style MHA fill:#ffccbc
    style GQA fill:#fff9c4
    style MQA fill:#a5d6a7

4.2 GQA配置

数学关系:

  • Query头数: hh (如32)
  • KV组数: gg (如4或8)
  • 每组Query数: h/gh/g

常见配置:

模型Query头数KV组数每组头数显存节省
LLaMA2-7B328475%
LLaMA2-13B405887.5%
LLaMA2-70B648887.5%
Mixtral 8x7B328475%

4.3 GQA架构图

graph TB
    X["输入 X"] --> Linear["线性变换"]
    
    Linear --> Q["Query<br/>[h个头]"]
    Linear --> K["Key<br/>[g组]"]
    Linear --> V["Value<br/>[g组]"]
    
    subgraph 组1
        Q1["Q头1-4"] --> Attn1["注意力计算"]
        K1["K1"] --> Attn1
        V1["V1"] --> Attn1
    end
    
    subgraph 组2
        Q2["Q头5-8"] --> Attn2["注意力计算"]
        K2["K2"] --> Attn2
        V2["V2"] --> Attn2
    end
    
    Attn1 --> Concat["拼接"]
    Attn2 --> Concat
    Concat --> Output["输出"]
    
    style Q fill:#fff9c4
    style K fill:#a5d6a7
    style V fill:#81c784
    style Output fill:#c5e1a5

4.4 PyTorch实现

class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, n_heads, n_kv_groups):
        """
        参数:
            d_model: 模型维度(如4096)
            n_heads: Query头数(如32)
            n_kv_groups: KV组数(如8)
        """
        super().__init__()
        assert n_heads % n_kv_groups == 0, "n_heads必须能被n_kv_groups整除"
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_kv_groups = n_kv_groups
        self.n_heads_per_group = n_heads // n_kv_groups
        self.d_k = d_model // n_heads
        
        # Query: 每个头独立
        self.W_Q = nn.Linear(d_model, d_model)
        
        # Key & Value: 每组一个
        self.W_K = nn.Linear(d_model, n_kv_groups * self.d_k)
        self.W_V = nn.Linear(d_model, n_kv_groups * self.d_k)
        
        self.W_O = nn.Linear(d_model, d_model)
    
    def forward(self, x):
        batch_size, seq_len, _ = x.size()
        
        # 1. 计算Q (所有头)
        Q = self.W_Q(x).view(batch_size, seq_len, self.n_heads, self.d_k)
        Q = Q.transpose(1, 2)  # [batch, n_heads, seq_len, d_k]
        
        # 2. 计算K, V (每组一个)
        K = self.W_K(x).view(batch_size, seq_len, self.n_kv_groups, self.d_k)
        V = self.W_V(x).view(batch_size, seq_len, self.n_kv_groups, self.d_k)
        K = K.transpose(1, 2)  # [batch, n_kv_groups, seq_len, d_k]
        V = V.transpose(1, 2)
        
        # 3. 将KV复制到每组内的所有头
        K = K.repeat_interleave(self.n_heads_per_group, dim=1)
        V = V.repeat_interleave(self.n_heads_per_group, dim=1)
        # 现在 K, V: [batch, n_heads, seq_len, d_k]
        
        # 4. 计算注意力
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)
        
        # 5. 合并多头
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.d_model)
        output = self.W_O(attn_output)
        
        return output


# 使用示例
d_model = 4096
n_heads = 32
n_kv_groups = 8  # LLaMA2-7B配置

gqa = GroupedQueryAttention(d_model, n_heads, n_kv_groups)

x = torch.randn(2, 10, d_model)
output = gqa(x)

print(f"输入: {x.shape}")      # torch.Size([2, 10, 4096])
print(f"输出: {output.shape}")  # torch.Size([2, 10, 4096])

4.5 MHA vs MQA vs GQA对比

graph TB
    subgraph 性能对比
        Quality["模型质量<br/>(困惑度 Perplexity)"]
        Speed["推理速度<br/>(tokens/sec)"]
        Memory["显存占用<br/>(GB)"]
    end
    
    subgraph MHA评分
        Q_MHA["最好 ⭐⭐⭐⭐⭐"]
        S_MHA["最慢 ⭐⭐"]
        M_MHA["最高 ⭐"]
    end
    
    subgraph GQA评分
        Q_GQA["接近MHA ⭐⭐⭐⭐"]
        S_GQA["较快 ⭐⭐⭐⭐"]
        M_GQA["适中 ⭐⭐⭐"]
    end
    
    subgraph MQA评分
        Q_MQA["略低 ⭐⭐⭐"]
        S_MQA["最快 ⭐⭐⭐⭐⭐"]
        M_MQA["最低 ⭐⭐⭐⭐⭐"]
    end
    
    Quality --> Q_MHA
    Quality --> Q_GQA
    Quality --> Q_MQA
    
    Speed --> S_MHA
    Speed --> S_GQA
    Speed --> S_MQA
    
    Memory --> M_MHA
    Memory --> M_GQA
    Memory --> M_MQA
    
    style Q_GQA fill:#fff59d
    style S_GQA fill:#fff59d
    style M_GQA fill:#fff59d

实验数据(LLaMA2论文):

  • 质量: GQA-8 几乎等同于 MHA
  • 速度: GQA-8 比 MHA 快 1.3x
  • 显存: GQA-8 节省 75% KV缓存

五、稀疏注意力(Sparse Attention)

5.1 长序列的注意力复杂度问题

标准Attention的瓶颈:

复杂度=O(n2d)\text{复杂度} = O(n^2 d)

其中 nn 是序列长度,dd 是维度。

graph LR
    Seq["序列长度"] --> Comp["计算复杂度"]
    
    L1["1K tokens"] --> C1["O(1M)"]
    L2["10K tokens"] --> C2["O(100M)"]
    L3["100K tokens"] --> C3["O(10B)"]
    
    style L1 fill:#a5d6a7
    style L2 fill:#fff9c4
    style L3 fill:#ffcdd2

Claude 3处理200K上下文需要什么?

200K2=40billion operations per layer!200K^2 = 40 \text{billion operations per layer!}

5.2 稀疏注意力模式

核心思想:不是所有token都需要关注所有其他token。

graph TB
    subgraph Full[全注意力 O(n²)]
        F["每个token<br/>关注所有token"]
    end
    
    subgraph Sparse[稀疏注意力]
        S1["局部注意力<br/>Sliding Window"]
        S2["全局注意力<br/>Global Tokens"]
        S3["随机注意力<br/>Random Sampling"]
        S4["分块注意力<br/>Blocked"]
    end
    
    Full -.优化.-> Sparse
    
    style Full fill:#ffcdd2
    style S1 fill:#a5d6a7
    style S2 fill:#81c784
    style S3 fill:#66bb6a
    style S4 fill:#4caf50

5.3 常见稀疏注意力模式

(1) Sliding Window Attention

思想:每个token只关注前后固定窗口内的token。

graph LR
    subgraph 注意力矩阵
        T1["Token 1"] -.-> W1["窗口1-3"]
        T2["Token 2"] -.-> W2["窗口1-4"]
        T3["Token 3"] -.-> W3["窗口1-5"]
        T4["Token 4"] -.-> W4["窗口2-6"]
    end
    
    style T1 fill:#fff9c4
    style W1 fill:#a5d6a7

复杂度: O(n×w)O(n \times w),其中 ww 是窗口大小(如512)

实现:

def sliding_window_mask(seq_len, window_size):
    """
    生成滑动窗口mask
    """
    mask = torch.zeros(seq_len, seq_len)
    for i in range(seq_len):
        start = max(0, i - window_size)
        end = min(seq_len, i + window_size + 1)
        mask[i, start:end] = 1
    return mask

# 示例
mask = sliding_window_mask(10, window_size=2)
print(mask)
# tensor([[1., 1., 1., 0., 0., ...],
#         [1., 1., 1., 1., 0., ...],
#         [1., 1., 1., 1., 1., ...],
#         ...])

(2) Global + Local Attention(Longformer模式)

思想:少数全局token关注所有,大部分token只做局部关注。

graph TB
    subgraph 全局Token
        G["CLS, SEP<br/>关注所有token"]
    end
    
    subgraph 局部Token
        L["普通token<br/>只关注窗口内"]
    end
    
    G -.全注意力.-> All["全部序列"]
    L -.局部.-> Window["小窗口"]
    
    style G fill:#ffeb3b
    style L fill:#90caf9

实现:

def longformer_mask(seq_len, window_size, global_indices):
    """
    Longformer注意力mask
    global_indices: 全局token的位置(如[0, 1])
    """
    # 基础:滑动窗口
    mask = sliding_window_mask(seq_len, window_size)
    
    # 全局token可以关注所有
    for idx in global_indices:
        mask[idx, :] = 1   # 该行全1
        mask[:, idx] = 1   # 该列全1
    
    return mask

(3) Sparse Transformer (分块注意力)

思想:将序列分块,块内全注意力,块间稀疏连接。

graph TB
    subgraph Block1[块1]
        B1_T1["Token 1-8"]
    end
    subgraph Block2[块2]
        B2_T1["Token 9-16"]
    end
    subgraph Block3[块3]
        B3_T1["Token 17-24"]
    end
    
    Block1 <-.块内全连接.-> Block1
    Block2 <-.块内全连接.-> Block2
    Block3 <-.块内全连接.-> Block3
    
    Block1 -.稀疏连接.-> Block2
    Block2 -.稀疏连接.-> Block3
    
    style Block1 fill:#e3f2fd
    style Block2 fill:#fff9c4
    style Block3 fill:#f3e5f5

5.4 FlashAttention: IO优化而非稀疏化

特殊说明:FlashAttention不改变注意力模式,而是优化GPU内存访问。

graph LR
    subgraph 标准Attention[标准实现]
        Step1["1. 计算QK^T<br/>写入HBM"]
        Step2["2. 读取,Softmax<br/>写回HBM"]
        Step3["3. 读取,乘V<br/>写回HBM"]
    end
    
    subgraph FlashAttn[FlashAttention]
        Fused["分块计算<br/>全程在SRAM<br/>减少HBM访问"]
    end
    
    Step1 --> Step2 --> Step3
    
    style Step1 fill:#ffccbc
    style Step2 fill:#ffccbc
    style Step3 fill:#ffccbc
    style Fused fill:#a5d6a7

加速效果:

  • 训练: 快2-4x
  • 长序列: 支持64K+上下文

六、混合专家模型(Mixture of Experts, MoE)

6.1 MoE的核心思想

问题:大模型参数多,但每次前向传播只需要激活部分参数。

graph TB
    Input["输入Token"] --> Router["路由网络<br/>决策选择专家"]
    
    Router -->|20%概率| E1["专家1<br/>数学推理"]
    Router -->|5%概率| E2["专家2<br/>代码生成"]
    Router -->|60%概率| E3["专家3<br/>通用知识"]
    Router -->|10%概率| E4["专家4<br/>创意写作"]
    Router -->|5%概率| En["专家N<br/>..."]
    
    E1 --> Combine["加权组合"]
    E2 --> Combine
    E3 --> Combine
    E4 --> Combine
    En --> Combine
    
    Combine --> Output["输出"]
    
    style Router fill:#fff59d
    style E3 fill:#a5d6a7
    style Combine fill:#90caf9

关键特点:

  1. 稀疏激活:每个token只激活Top-K个专家(如K=2)
  2. 参数共享:总参数量大,但实际计算量接近小模型
  3. 专业化:不同专家学习不同领域知识

6.2 MoE架构

graph TB
    X["输入 X"] --> SelfAttn["自注意力"]
    SelfAttn --> Norm1["LayerNorm"]
    
    Norm1 --> Router["路由网络<br/>Gating"]
    
    subgraph MoE层
        Router -->|权重w1| Expert1["FFN 专家1"]
        Router -->|权重w2| Expert2["FFN 专家2"]
        Router -->|权重0| Expert3["FFN 专家3<br/>未激活"]
        Router -->|权重0| ExpertN["FFN 专家N<br/>未激活"]
    end
    
    Expert1 --> Sum["加权求和<br/>w1·E1 + w2·E2"]
    Expert2 --> Sum
    
    Sum --> Norm2["LayerNorm"]
    Norm2 --> Output["输出"]
    
    style Router fill:#fff59d
    style Expert1 fill:#a5d6a7
    style Expert2 fill:#81c784
    style Expert3 fill:#e0e0e0
    style ExpertN fill:#e0e0e0

6.3 路由机制

Softmax路由:

G(x)=Softmax(xWg)G(x) = \text{Softmax}(x \cdot W_g)

Top-K选择:

Output=iTopK(G(x))G(x)iEi(x)\text{Output} = \sum_{i \in \text{TopK}(G(x))} G(x)_i \cdot E_i(x)

PyTorch实现:

class MoELayer(nn.Module):
    def __init__(self, d_model, d_ff, num_experts, top_k=2):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        
        # 路由网络
        self.gate = nn.Linear(d_model, num_experts)
        
        # 专家网络(FFN)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ff),
                nn.ReLU(),
                nn.Linear(d_ff, d_model)
            )
            for _ in range(num_experts)
        ])
    
    def forward(self, x):
        """
        x: [batch_size, seq_len, d_model]
        """
        batch_size, seq_len, d_model = x.size()
        
        # 1. 路由打分
        gate_logits = self.gate(x)  # [batch, seq_len, num_experts]
        
        # 2. 选择Top-K专家
        top_k_logits, top_k_indices = torch.topk(gate_logits, self.top_k, dim=-1)
        # top_k_indices: [batch, seq_len, top_k]
        
        # 3. Softmax归一化(只在Top-K上)
        top_k_gates = F.softmax(top_k_logits, dim=-1)
        # [batch, seq_len, top_k]
        
        # 4. 计算专家输出并加权求和
        output = torch.zeros_like(x)
        
        for k in range(self.top_k):
            # 获取当前专家索引
            expert_idx = top_k_indices[:, :, k]  # [batch, seq_len]
            gate_weight = top_k_gates[:, :, k]   # [batch, seq_len]
            
            # 批量处理(简化版,实际中需要更高效的实现)
            for i in range(self.num_experts):
                mask = (expert_idx == i)  # [batch, seq_len]
                if mask.any():
                    expert_output = self.experts[i](x)
                    output += expert_output * gate_weight.unsqueeze(-1) * mask.unsqueeze(-1)
        
        return output


# 使用示例
d_model = 512
d_ff = 2048
num_experts = 8
top_k = 2

moe = MoELayer(d_model, d_ff, num_experts, top_k)

x = torch.randn(2, 10, d_model)
output = moe(x)

print(f"输入: {x.shape}")      # torch.Size([2, 10, 512])
print(f"输出: {output.shape}")  # torch.Size([2, 10, 512])

6.4 实际案例:Mixtral 8x7B

架构特点:

  • 8个专家,每个7B参数
  • Top-2路由:每个token激活2个专家
  • 总参数: 47B (8×7B,但共享attention)
  • 激活参数: 13B (相当于13B模型的计算量)
graph TB
    Model["Mixtral 8x7B"] --> Params["总参数: 47B"]
    Model --> Active["激活参数: 13B"]
    Model --> Speed["推理速度 ≈ 13B模型"]
    Model --> Quality["性能接近 70B模型"]
    
    style Model fill:#fff59d
    style Speed fill:#a5d6a7
    style Quality fill:#81c784

性能数据:

  • 数学推理: 优于LLaMA2-70B
  • 代码生成: 接近GPT-3.5
  • 推理速度: 比70B快5x+

6.5 MoE的挑战

挑战说明解决方案
负载均衡某些专家被过度使用添加辅助损失函数
通信开销分布式训练时专家在不同GPU专家并行策略
泛化性专家过度专业化正则化技术

负载均衡损失:

Lbalance=αCV(expert_usage)L_{balance} = \alpha \cdot \text{CV}(\text{expert\_usage})

其中 CV 是变异系数,鼓励专家使用均匀。


七、技术对比与选择指南

7.1 综合对比表

技术加速比显存节省质量损失实现难度适用场景
KV Cache50x+0%0%所有自回归模型(必备)
MQA2x96%3-5%⭐⭐极致推理速度场景
GQA1.3x75%<1%⭐⭐推荐,平衡方案
Sparse Attn10x+50%+0-5%⭐⭐⭐⭐超长文本(100K+)
MoE5x70%0%⭐⭐⭐⭐⭐超大模型,计算受限

7.2 选择决策树

graph TD
    Start{需求是什么?} --> Q1{序列长度?}
    
    Q1 -->|<4K| Short[标准场景]
    Q1 -->|4K-32K| Medium[中长文本]
    Q1 -->|>32K| Long[超长文本]
    
    Short --> Q2{显存限制?}
    Q2 -->|宽松| Use_MHA[使用标准MHA<br/>+ KV Cache]
    Q2 -->|紧张| Use_GQA[使用GQA<br/>+ KV Cache]
    
    Medium --> Q3{质量要求?}
    Q3 -->|最高| MHA_Long[MHA + KV Cache]
    Q3 -->|平衡| GQA_Long[GQA + Sliding Window]
    
    Long --> Sparse[Sparse Attention<br/>必选方案]
    
    Start --> Q4{是否超大模型?}
    Q4 -->|>100B| Consider_MoE[考虑MoE架构]
    
    style Use_GQA fill:#fff59d
    style GQA_Long fill:#fff59d
    style Sparse fill:#a5d6a7
    style Consider_MoE fill:#81c784

7.3 工业界实践

OpenAI GPT系列:

  • GPT-3: MHA + KV Cache
  • GPT-3.5/4: 推测 MQA/GQA + Sparse + MoE

Meta LLaMA系列:

  • LLaMA: MHA + KV Cache
  • LLaMA2: GQA-8 + KV Cache (黄金组合)
  • LLaMA3: GQA + 更长上下文

Google PaLM/Gemini:

  • PaLM: MQA + KV Cache
  • PaLM2: MQA改进版

Anthropic Claude:

  • Claude 1/2: 推测 GQA + Sparse
  • Claude 3: Sparse Attention (200K上下文)

八、实战:构建一个优化的Transformer

完整代码

class OptimizedTransformerBlock(nn.Module):
    """
    集成GQA + KV Cache的优化Transformer Block
    """
    def __init__(self, d_model, n_heads, n_kv_groups, d_ff, dropout=0.1):
        super().__init__()
        
        # GQA
        self.gqa = GroupedQueryAttention(d_model, n_heads, n_kv_groups)
        
        # FFN
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, cache=None, use_cache=False):
        # Self-attention with cache
        attn_out, new_cache = self.gqa(x, cache=cache, use_cache=use_cache)
        x = self.norm1(x + self.dropout(attn_out))
        
        # FFN
        ffn_out = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_out))
        
        if use_cache:
            return x, new_cache
        return x


# LLaMA2-7B配置
d_model = 4096
n_heads = 32
n_kv_groups = 8  # GQA-8
d_ff = 11008
n_layers = 32

# 构建完整模型
class OptimizedLLM(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, n_kv_groups, d_ff, n_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList([
            OptimizedTransformerBlock(d_model, n_heads, n_kv_groups, d_ff)
            for _ in range(n_layers)
        ])
        self.lm_head = nn.Linear(d_model, vocab_size)
    
    def forward(self, input_ids, caches=None, use_cache=False):
        x = self.embedding(input_ids)
        
        new_caches = []
        for i, layer in enumerate(self.layers):
            cache = caches[i] if caches else None
            if use_cache:
                x, new_cache = layer(x, cache=cache, use_cache=True)
                new_caches.append(new_cache)
            else:
                x = layer(x)
        
        logits = self.lm_head(x)
        
        if use_cache:
            return logits, new_caches
        return logits


# 使用示例
vocab_size = 32000
model = OptimizedLLM(vocab_size, d_model, n_heads, n_kv_groups, d_ff, n_layers)

print(f"模型参数量: {sum(p.numel() for p in model.parameters())/1e9:.2f}B")
# 输出: 模型参数量: 6.74B (接近LLaMA2-7B)

九、总结与展望

9.1 核心技术总结

mindmap
  root((现代LLM优化))
    推理加速
      KV Cache
        缓存历史KV
        O(n²)→O(n)
      Flash Attention
        IO优化
        SRAM计算
    显存优化
      MQA
        共享KV
        节省96%
      GQA
        分组共享
        节省75%
    长文本
      Sparse Attention
        滑动窗口
        全局+局部
      RoPE
        相对位置编码
    超大模型
      MoE
        稀疏激活
        专家路由
      模型并行
        专家并行
        张量并行

9.2 未来趋势

1. 更长的上下文

  • 目标: 100万token上下文
  • 技术: 混合注意力模式、分层记忆

2. 更高效的架构

  • 线性Attention (RWKV, RetNet)
  • 状态空间模型 (Mamba)

3. 动态计算

  • 早停机制 (Early Exit)
  • 自适应计算 (Adaptive Computation)

4. 硬件协同优化

  • 定制芯片(TPU, Groq)
  • 混合精度(FP8, INT4)

十、练习与资源

练习题

1. 计算KV Cache节省

# 给定LLaMA2-13B配置,计算生成1000个token的KV Cache大小
# n_layers=40, n_heads=40, d_k=128, seq_len=1000

2. 实现Sliding Window Mask

def create_sliding_window_mask(seq_len, window_size):
    # TODO: 实现并可视化
    pass

3. 对比GQA不同配置

# 实验GQA-4 vs GQA-8 vs MHA的性能和显存

推荐资源

  1. 📄 论文:

  2. 💻 代码: