本文深入讲解现代大语言模型的核心优化技术,包括KV Cache自回归加速、Multi-Query Attention(MQA)、Grouped-Query Attention(GQA)、稀疏注意力(Sparse Attention)和混合专家模型(Mixture of Experts, MoE)。通过数学原理、架构对比和PyTorch代码实现,帮助读者理解GPT-4、LLaMA、Mixtral等顶级模型的技术细节,掌握LLM推理加速与显存优化的工程实践。
一、为什么需要优化Transformer?
1.1 原始Transformer的性能瓶颈
graph TB
subgraph 问题[三大瓶颈]
P1["🐌 推理速度慢<br/>自回归逐词生成<br/>大量重复计算"]
P2["💾 显存占用高<br/>KV矩阵随序列长度增长<br/>多头存储冗余"]
P3["📏 序列长度受限<br/>O(n²)复杂度<br/>长文本处理困难"]
end
subgraph 解决方案
S1["✅ KV Cache<br/>缓存已计算的KV"]
S2["✅ MQA/GQA<br/>共享KV降低显存"]
S3["✅ Sparse Attention<br/>稀疏注意力模式"]
end
P1 --> S1
P2 --> S2
P3 --> S3
style P1 fill:#ffcdd2
style P2 fill:#ffccbc
style P3 fill:#ffab91
style S1 fill:#a5d6a7
style S2 fill:#81c784
style S3 fill:#66bb6a
1.2 现代LLM采用的优化技术
| 模型 | KV Cache | MQA/GQA | Sparse Attn | MoE | 上下文长度 |
|---|---|---|---|---|---|
| GPT-3 | ✅ | ❌ | ❌ | ❌ | 2K |
| LLaMA | ✅ | ❌ | ❌ | ❌ | 4K |
| LLaMA2 | ✅ | ✅ GQA | ❌ | ❌ | 4K |
| GPT-4 | ✅ | ✅ | 部分 | 推测✅ | 32K/128K |
| Mixtral 8x7B | ✅ | ✅ GQA | ❌ | ✅ | 32K |
| Claude 3 | ✅ | ✅ | ✅ | ? | 200K |
二、KV Cache:自回归加速的核心技术
2.1 自回归生成的重复计算问题
场景:GPT模型生成"我爱学习AI"
sequenceDiagram
participant Input
participant Model
participant Output
Note over Input,Output: Step 1: 生成"我"
Input->>Model: [START]
Model->>Output: "我"
Note over Input,Output: Step 2: 生成"爱"
Input->>Model: [START, 我]
Note right of Model: ❌ 重新计算"我"的KV
Model->>Output: "爱"
Note over Input,Output: Step 3: 生成"学习"
Input->>Model: [START, 我, 爱]
Note right of Model: ❌ 重新计算"我""爱"的KV
Model->>Output: "学习"
Note over Input,Output: Step 4: 生成"AI"
Input->>Model: [START, 我, 爱, 学习]
Note right of Model: ❌ 重新计算所有历史KV
Model->>Output: "AI"
问题分析:
- 生成第1个词:计算1次KV
- 生成第2个词:计算2次KV(1次重复)
- 生成第3个词:计算3次KV(2次重复)
- 生成第n个词:计算n次KV(n-1次重复)
总计算量:
2.2 KV Cache的工作原理
核心思想:缓存已经计算过的Key和Value矩阵,新token只需计算自己的KV。
graph TB
subgraph 无Cache[Without KV Cache]
S1["Step 1<br/>计算: [START]"]
S2["Step 2<br/>计算: [START, 我]<br/>❌ 重复计算START"]
S3["Step 3<br/>计算: [START, 我, 爱]<br/>❌ 重复计算START,我"]
end
subgraph 有Cache[With KV Cache]
C1["Step 1<br/>计算&缓存: [START]"]
C2["Step 2<br/>✅ 读取: [START]<br/>计算&缓存: [我]"]
C3["Step 3<br/>✅ 读取: [START, 我]<br/>计算&缓存: [爱]"]
end
S1 --> S2 --> S3
C1 --> C2 --> C3
style S2 fill:#ffcdd2
style S3 fill:#ffcdd2
style C2 fill:#a5d6a7
style C3 fill:#a5d6a7
加速效果:
- 无Cache: 计算
- 有Cache: 计算
- 加速比: 生成100个token,加速约50倍!
2.3 KV Cache数学原理
标准Attention:
在第步:
- : 当前token的Query (新计算)
- : 所有历史token的Key (1到t-1从缓存读取,t新计算)
- : 所有历史token的Value (同上)
缓存更新:
# Pseudo-code
cache_K = [] # 初始化KV缓存
cache_V = []
for t in range(max_len):
# 1. 计算当前token的KV
k_t = compute_key(x_t)
v_t = compute_value(x_t)
# 2. 追加到缓存
cache_K.append(k_t)
cache_V.append(v_t)
# 3. 使用全部缓存计算注意力
q_t = compute_query(x_t)
attention = softmax(q_t @ cache_K.T) @ cache_V
2.4 PyTorch实现
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MultiHeadAttentionWithCache(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.W_Q = nn.Linear(d_model, d_model)
self.W_K = nn.Linear(d_model, d_model)
self.W_V = nn.Linear(d_model, d_model)
self.W_O = nn.Linear(d_model, d_model)
def forward(self, x, cache=None, use_cache=False):
"""
参数:
x: [batch_size, seq_len, d_model]
cache: {'key': [batch, n_heads, past_len, d_k],
'value': [batch, n_heads, past_len, d_k]}
use_cache: 是否返回更新后的cache
"""
batch_size, seq_len, _ = x.size()
# 1. 计算当前输入的QKV
Q = self.W_Q(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
K = self.W_K(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
V = self.W_V(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
# 2. 如果有cache,拼接历史KV
if cache is not None:
K = torch.cat([cache['key'], K], dim=2) # 拼接到seq_len维度
V = torch.cat([cache['value'], V], dim=2)
# 3. 计算注意力
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
attn_weights = F.softmax(scores, dim=-1)
attn_output = torch.matmul(attn_weights, V)
# 4. 合并多头
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, seq_len, self.d_model)
output = self.W_O(attn_output)
# 5. 更新cache
if use_cache:
new_cache = {'key': K, 'value': V}
return output, new_cache
return output
# 使用示例:模拟自回归生成
d_model = 512
n_heads = 8
max_len = 10
mha = MultiHeadAttentionWithCache(d_model, n_heads)
# 初始化
cache = None
all_outputs = []
for t in range(max_len):
# 当前token (实际中是上一步的输出)
current_token = torch.randn(1, 1, d_model) # [batch=1, seq_len=1, d_model]
# 前向传播 with cache
output, cache = mha(current_token, cache=cache, use_cache=True)
all_outputs.append(output)
print(f"Step {t+1}:")
print(f" Cache K shape: {cache['key'].shape}")
print(f" Cache V shape: {cache['value'].shape}")
# 输出示例:
# Step 1:
# Cache K shape: torch.Size([1, 8, 1, 64])
# Cache V shape: torch.Size([1, 8, 1, 64])
# Step 2:
# Cache K shape: torch.Size([1, 8, 2, 64]) ← 长度递增
# Cache V shape: torch.Size([1, 8, 2, 64])
# ...
2.5 KV Cache的显存成本
分析:对于单个样本
示例:LLaMA2-7B
- n_layers = 32
- n_heads = 32
- seq_len = 4096
- d_k = 128
- dtype = float16 (2 bytes)
单个序列就需要2GB显存! 这就是为什么需要MQA/GQA优化。
三、Multi-Query Attention(MQA):共享KV的激进方案
3.1 MQA的动机
问题:在多头注意力中,每个头都有独立的KV矩阵,造成显存冗余。
graph TB
subgraph 标准MHA[Multi-Head Attention]
Q1["Q1"] --> H1["Head 1"]
K1["K1"] --> H1
V1["V1"] --> H1
Q2["Q2"] --> H2["Head 2"]
K2["K2"] --> H2
V2["V2"] --> H2
Qn["Qn"] --> Hn["Head n"]
Kn["Kn"] --> Hn
Vn["Vn"] --> Hn
end
subgraph MQA[Multi-Query Attention]
Q1m["Q1"] --> H1m["Head 1"]
SharedKV["共享 K, V"] --> H1m
SharedKV --> H2m["Head 2"]
SharedKV --> Hnm["Head n"]
Q2m["Q2"] --> H2m
Qnm["Qn"] --> Hnm
end
style SharedKV fill:#a5d6a7
style K1 fill:#ffcdd2
style K2 fill:#ffcdd2
style Kn fill:#ffcdd2
核心思想:所有注意力头共享同一组Key和Value,只有Query独立。
3.2 MQA数学公式
标准MHA:
MQA:
注意: 在所有头之间共享。
3.3 显存节省计算
参数量对比:
| 配置 | MHA | MQA | 节省 |
|---|---|---|---|
| Q权重 | 0 | ||
| K权重 | |||
| V权重 |
示例(h=32):
- MHA KV缓存: 2.1 GB
- MQA KV缓存: 2.1/32 = 66 MB (节省96.9%!)
3.4 PyTorch实现
class MultiQueryAttention(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
# 每个头独立的Query
self.W_Q = nn.Linear(d_model, d_model)
# 共享的Key和Value
self.W_K = nn.Linear(d_model, self.d_k) # 注意维度!
self.W_V = nn.Linear(d_model, self.d_k)
self.W_O = nn.Linear(d_model, d_model)
def forward(self, x):
batch_size, seq_len, _ = x.size()
# 1. 计算多头Query
Q = self.W_Q(x).view(batch_size, seq_len, self.n_heads, self.d_k)
Q = Q.transpose(1, 2) # [batch, n_heads, seq_len, d_k]
# 2. 计算共享的K和V
K = self.W_K(x) # [batch, seq_len, d_k]
V = self.W_V(x) # [batch, seq_len, d_k]
# 扩展到所有头(通过broadcast)
K = K.unsqueeze(1) # [batch, 1, seq_len, d_k]
V = V.unsqueeze(1)
# 3. 计算注意力
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
attn_weights = F.softmax(scores, dim=-1)
attn_output = torch.matmul(attn_weights, V)
# [batch, n_heads, seq_len, d_k]
# 4. 合并多头
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, seq_len, self.d_model)
output = self.W_O(attn_output)
return output
# 对比参数量
d_model = 512
n_heads = 8
mha = MultiHeadAttention(d_model, n_heads)
mqa = MultiQueryAttention(d_model, n_heads)
print(f"MHA 参数量: {sum(p.numel() for p in mha.parameters())}")
print(f"MQA 参数量: {sum(p.numel() for p in mqa.parameters())}")
# MHA 参数量: 1,050,624
# MQA 参数量: 820,224 (节省22%)
3.5 MQA的缺点
graph LR
Pro["✅ 优点"] --> P1["显存占用大幅降低"]
Pro --> P2["推理速度显著提升"]
Con["❌ 缺点"] --> C1["表达能力下降"]
Con --> C2["精度略有损失"]
Con --> C3["多头冗余度太低"]
style Pro fill:#a5d6a7
style Con fill:#ffcdd2
实验数据(PaLM论文):
- 推理速度: 提升1.5-2x
- 模型质量: 下降约3-5%
四、Grouped-Query Attention(GQA):MHA与MQA的平衡
4.1 GQA的设计哲学
核心思想:将多个Query头分组,每组共享一对KV。
graph TB
subgraph MHA[Multi-Head: h个独立KV]
MHA_Heads["Head1 Head2 ... Head-h<br/>K1,V1 K2,V2 ... Kh,Vh"]
end
subgraph GQA[Grouped-Query: g组共享KV]
GQA_Group1["组1: Head1,2,3,4<br/>共享 K1,V1"]
GQA_Group2["组2: Head5,6,7,8<br/>共享 K2,V2"]
end
subgraph MQA[Multi-Query: 1组共享KV]
MQA_All["所有Head<br/>共享 K,V"]
end
MHA -.折中方案.-> GQA
GQA -.极端情况.-> MQA
style MHA fill:#ffccbc
style GQA fill:#fff9c4
style MQA fill:#a5d6a7
4.2 GQA配置
数学关系:
- Query头数: (如32)
- KV组数: (如4或8)
- 每组Query数:
常见配置:
| 模型 | Query头数 | KV组数 | 每组头数 | 显存节省 |
|---|---|---|---|---|
| LLaMA2-7B | 32 | 8 | 4 | 75% |
| LLaMA2-13B | 40 | 5 | 8 | 87.5% |
| LLaMA2-70B | 64 | 8 | 8 | 87.5% |
| Mixtral 8x7B | 32 | 8 | 4 | 75% |
4.3 GQA架构图
graph TB
X["输入 X"] --> Linear["线性变换"]
Linear --> Q["Query<br/>[h个头]"]
Linear --> K["Key<br/>[g组]"]
Linear --> V["Value<br/>[g组]"]
subgraph 组1
Q1["Q头1-4"] --> Attn1["注意力计算"]
K1["K1"] --> Attn1
V1["V1"] --> Attn1
end
subgraph 组2
Q2["Q头5-8"] --> Attn2["注意力计算"]
K2["K2"] --> Attn2
V2["V2"] --> Attn2
end
Attn1 --> Concat["拼接"]
Attn2 --> Concat
Concat --> Output["输出"]
style Q fill:#fff9c4
style K fill:#a5d6a7
style V fill:#81c784
style Output fill:#c5e1a5
4.4 PyTorch实现
class GroupedQueryAttention(nn.Module):
def __init__(self, d_model, n_heads, n_kv_groups):
"""
参数:
d_model: 模型维度(如4096)
n_heads: Query头数(如32)
n_kv_groups: KV组数(如8)
"""
super().__init__()
assert n_heads % n_kv_groups == 0, "n_heads必须能被n_kv_groups整除"
self.d_model = d_model
self.n_heads = n_heads
self.n_kv_groups = n_kv_groups
self.n_heads_per_group = n_heads // n_kv_groups
self.d_k = d_model // n_heads
# Query: 每个头独立
self.W_Q = nn.Linear(d_model, d_model)
# Key & Value: 每组一个
self.W_K = nn.Linear(d_model, n_kv_groups * self.d_k)
self.W_V = nn.Linear(d_model, n_kv_groups * self.d_k)
self.W_O = nn.Linear(d_model, d_model)
def forward(self, x):
batch_size, seq_len, _ = x.size()
# 1. 计算Q (所有头)
Q = self.W_Q(x).view(batch_size, seq_len, self.n_heads, self.d_k)
Q = Q.transpose(1, 2) # [batch, n_heads, seq_len, d_k]
# 2. 计算K, V (每组一个)
K = self.W_K(x).view(batch_size, seq_len, self.n_kv_groups, self.d_k)
V = self.W_V(x).view(batch_size, seq_len, self.n_kv_groups, self.d_k)
K = K.transpose(1, 2) # [batch, n_kv_groups, seq_len, d_k]
V = V.transpose(1, 2)
# 3. 将KV复制到每组内的所有头
K = K.repeat_interleave(self.n_heads_per_group, dim=1)
V = V.repeat_interleave(self.n_heads_per_group, dim=1)
# 现在 K, V: [batch, n_heads, seq_len, d_k]
# 4. 计算注意力
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
attn_weights = F.softmax(scores, dim=-1)
attn_output = torch.matmul(attn_weights, V)
# 5. 合并多头
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, seq_len, self.d_model)
output = self.W_O(attn_output)
return output
# 使用示例
d_model = 4096
n_heads = 32
n_kv_groups = 8 # LLaMA2-7B配置
gqa = GroupedQueryAttention(d_model, n_heads, n_kv_groups)
x = torch.randn(2, 10, d_model)
output = gqa(x)
print(f"输入: {x.shape}") # torch.Size([2, 10, 4096])
print(f"输出: {output.shape}") # torch.Size([2, 10, 4096])
4.5 MHA vs MQA vs GQA对比
graph TB
subgraph 性能对比
Quality["模型质量<br/>(困惑度 Perplexity)"]
Speed["推理速度<br/>(tokens/sec)"]
Memory["显存占用<br/>(GB)"]
end
subgraph MHA评分
Q_MHA["最好 ⭐⭐⭐⭐⭐"]
S_MHA["最慢 ⭐⭐"]
M_MHA["最高 ⭐"]
end
subgraph GQA评分
Q_GQA["接近MHA ⭐⭐⭐⭐"]
S_GQA["较快 ⭐⭐⭐⭐"]
M_GQA["适中 ⭐⭐⭐"]
end
subgraph MQA评分
Q_MQA["略低 ⭐⭐⭐"]
S_MQA["最快 ⭐⭐⭐⭐⭐"]
M_MQA["最低 ⭐⭐⭐⭐⭐"]
end
Quality --> Q_MHA
Quality --> Q_GQA
Quality --> Q_MQA
Speed --> S_MHA
Speed --> S_GQA
Speed --> S_MQA
Memory --> M_MHA
Memory --> M_GQA
Memory --> M_MQA
style Q_GQA fill:#fff59d
style S_GQA fill:#fff59d
style M_GQA fill:#fff59d
实验数据(LLaMA2论文):
- 质量: GQA-8 几乎等同于 MHA
- 速度: GQA-8 比 MHA 快 1.3x
- 显存: GQA-8 节省 75% KV缓存
五、稀疏注意力(Sparse Attention)
5.1 长序列的注意力复杂度问题
标准Attention的瓶颈:
其中 是序列长度, 是维度。
graph LR
Seq["序列长度"] --> Comp["计算复杂度"]
L1["1K tokens"] --> C1["O(1M)"]
L2["10K tokens"] --> C2["O(100M)"]
L3["100K tokens"] --> C3["O(10B)"]
style L1 fill:#a5d6a7
style L2 fill:#fff9c4
style L3 fill:#ffcdd2
Claude 3处理200K上下文需要什么?
5.2 稀疏注意力模式
核心思想:不是所有token都需要关注所有其他token。
graph TB
subgraph Full[全注意力 O(n²)]
F["每个token<br/>关注所有token"]
end
subgraph Sparse[稀疏注意力]
S1["局部注意力<br/>Sliding Window"]
S2["全局注意力<br/>Global Tokens"]
S3["随机注意力<br/>Random Sampling"]
S4["分块注意力<br/>Blocked"]
end
Full -.优化.-> Sparse
style Full fill:#ffcdd2
style S1 fill:#a5d6a7
style S2 fill:#81c784
style S3 fill:#66bb6a
style S4 fill:#4caf50
5.3 常见稀疏注意力模式
(1) Sliding Window Attention
思想:每个token只关注前后固定窗口内的token。
graph LR
subgraph 注意力矩阵
T1["Token 1"] -.-> W1["窗口1-3"]
T2["Token 2"] -.-> W2["窗口1-4"]
T3["Token 3"] -.-> W3["窗口1-5"]
T4["Token 4"] -.-> W4["窗口2-6"]
end
style T1 fill:#fff9c4
style W1 fill:#a5d6a7
复杂度: ,其中 是窗口大小(如512)
实现:
def sliding_window_mask(seq_len, window_size):
"""
生成滑动窗口mask
"""
mask = torch.zeros(seq_len, seq_len)
for i in range(seq_len):
start = max(0, i - window_size)
end = min(seq_len, i + window_size + 1)
mask[i, start:end] = 1
return mask
# 示例
mask = sliding_window_mask(10, window_size=2)
print(mask)
# tensor([[1., 1., 1., 0., 0., ...],
# [1., 1., 1., 1., 0., ...],
# [1., 1., 1., 1., 1., ...],
# ...])
(2) Global + Local Attention(Longformer模式)
思想:少数全局token关注所有,大部分token只做局部关注。
graph TB
subgraph 全局Token
G["CLS, SEP<br/>关注所有token"]
end
subgraph 局部Token
L["普通token<br/>只关注窗口内"]
end
G -.全注意力.-> All["全部序列"]
L -.局部.-> Window["小窗口"]
style G fill:#ffeb3b
style L fill:#90caf9
实现:
def longformer_mask(seq_len, window_size, global_indices):
"""
Longformer注意力mask
global_indices: 全局token的位置(如[0, 1])
"""
# 基础:滑动窗口
mask = sliding_window_mask(seq_len, window_size)
# 全局token可以关注所有
for idx in global_indices:
mask[idx, :] = 1 # 该行全1
mask[:, idx] = 1 # 该列全1
return mask
(3) Sparse Transformer (分块注意力)
思想:将序列分块,块内全注意力,块间稀疏连接。
graph TB
subgraph Block1[块1]
B1_T1["Token 1-8"]
end
subgraph Block2[块2]
B2_T1["Token 9-16"]
end
subgraph Block3[块3]
B3_T1["Token 17-24"]
end
Block1 <-.块内全连接.-> Block1
Block2 <-.块内全连接.-> Block2
Block3 <-.块内全连接.-> Block3
Block1 -.稀疏连接.-> Block2
Block2 -.稀疏连接.-> Block3
style Block1 fill:#e3f2fd
style Block2 fill:#fff9c4
style Block3 fill:#f3e5f5
5.4 FlashAttention: IO优化而非稀疏化
特殊说明:FlashAttention不改变注意力模式,而是优化GPU内存访问。
graph LR
subgraph 标准Attention[标准实现]
Step1["1. 计算QK^T<br/>写入HBM"]
Step2["2. 读取,Softmax<br/>写回HBM"]
Step3["3. 读取,乘V<br/>写回HBM"]
end
subgraph FlashAttn[FlashAttention]
Fused["分块计算<br/>全程在SRAM<br/>减少HBM访问"]
end
Step1 --> Step2 --> Step3
style Step1 fill:#ffccbc
style Step2 fill:#ffccbc
style Step3 fill:#ffccbc
style Fused fill:#a5d6a7
加速效果:
- 训练: 快2-4x
- 长序列: 支持64K+上下文
六、混合专家模型(Mixture of Experts, MoE)
6.1 MoE的核心思想
问题:大模型参数多,但每次前向传播只需要激活部分参数。
graph TB
Input["输入Token"] --> Router["路由网络<br/>决策选择专家"]
Router -->|20%概率| E1["专家1<br/>数学推理"]
Router -->|5%概率| E2["专家2<br/>代码生成"]
Router -->|60%概率| E3["专家3<br/>通用知识"]
Router -->|10%概率| E4["专家4<br/>创意写作"]
Router -->|5%概率| En["专家N<br/>..."]
E1 --> Combine["加权组合"]
E2 --> Combine
E3 --> Combine
E4 --> Combine
En --> Combine
Combine --> Output["输出"]
style Router fill:#fff59d
style E3 fill:#a5d6a7
style Combine fill:#90caf9
关键特点:
- 稀疏激活:每个token只激活Top-K个专家(如K=2)
- 参数共享:总参数量大,但实际计算量接近小模型
- 专业化:不同专家学习不同领域知识
6.2 MoE架构
graph TB
X["输入 X"] --> SelfAttn["自注意力"]
SelfAttn --> Norm1["LayerNorm"]
Norm1 --> Router["路由网络<br/>Gating"]
subgraph MoE层
Router -->|权重w1| Expert1["FFN 专家1"]
Router -->|权重w2| Expert2["FFN 专家2"]
Router -->|权重0| Expert3["FFN 专家3<br/>未激活"]
Router -->|权重0| ExpertN["FFN 专家N<br/>未激活"]
end
Expert1 --> Sum["加权求和<br/>w1·E1 + w2·E2"]
Expert2 --> Sum
Sum --> Norm2["LayerNorm"]
Norm2 --> Output["输出"]
style Router fill:#fff59d
style Expert1 fill:#a5d6a7
style Expert2 fill:#81c784
style Expert3 fill:#e0e0e0
style ExpertN fill:#e0e0e0
6.3 路由机制
Softmax路由:
Top-K选择:
PyTorch实现:
class MoELayer(nn.Module):
def __init__(self, d_model, d_ff, num_experts, top_k=2):
super().__init__()
self.num_experts = num_experts
self.top_k = top_k
# 路由网络
self.gate = nn.Linear(d_model, num_experts)
# 专家网络(FFN)
self.experts = nn.ModuleList([
nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Linear(d_ff, d_model)
)
for _ in range(num_experts)
])
def forward(self, x):
"""
x: [batch_size, seq_len, d_model]
"""
batch_size, seq_len, d_model = x.size()
# 1. 路由打分
gate_logits = self.gate(x) # [batch, seq_len, num_experts]
# 2. 选择Top-K专家
top_k_logits, top_k_indices = torch.topk(gate_logits, self.top_k, dim=-1)
# top_k_indices: [batch, seq_len, top_k]
# 3. Softmax归一化(只在Top-K上)
top_k_gates = F.softmax(top_k_logits, dim=-1)
# [batch, seq_len, top_k]
# 4. 计算专家输出并加权求和
output = torch.zeros_like(x)
for k in range(self.top_k):
# 获取当前专家索引
expert_idx = top_k_indices[:, :, k] # [batch, seq_len]
gate_weight = top_k_gates[:, :, k] # [batch, seq_len]
# 批量处理(简化版,实际中需要更高效的实现)
for i in range(self.num_experts):
mask = (expert_idx == i) # [batch, seq_len]
if mask.any():
expert_output = self.experts[i](x)
output += expert_output * gate_weight.unsqueeze(-1) * mask.unsqueeze(-1)
return output
# 使用示例
d_model = 512
d_ff = 2048
num_experts = 8
top_k = 2
moe = MoELayer(d_model, d_ff, num_experts, top_k)
x = torch.randn(2, 10, d_model)
output = moe(x)
print(f"输入: {x.shape}") # torch.Size([2, 10, 512])
print(f"输出: {output.shape}") # torch.Size([2, 10, 512])
6.4 实际案例:Mixtral 8x7B
架构特点:
- 8个专家,每个7B参数
- Top-2路由:每个token激活2个专家
- 总参数: 47B (8×7B,但共享attention)
- 激活参数: 13B (相当于13B模型的计算量)
graph TB
Model["Mixtral 8x7B"] --> Params["总参数: 47B"]
Model --> Active["激活参数: 13B"]
Model --> Speed["推理速度 ≈ 13B模型"]
Model --> Quality["性能接近 70B模型"]
style Model fill:#fff59d
style Speed fill:#a5d6a7
style Quality fill:#81c784
性能数据:
- 数学推理: 优于LLaMA2-70B
- 代码生成: 接近GPT-3.5
- 推理速度: 比70B快5x+
6.5 MoE的挑战
| 挑战 | 说明 | 解决方案 |
|---|---|---|
| 负载均衡 | 某些专家被过度使用 | 添加辅助损失函数 |
| 通信开销 | 分布式训练时专家在不同GPU | 专家并行策略 |
| 泛化性 | 专家过度专业化 | 正则化技术 |
负载均衡损失:
其中 CV 是变异系数,鼓励专家使用均匀。
七、技术对比与选择指南
7.1 综合对比表
| 技术 | 加速比 | 显存节省 | 质量损失 | 实现难度 | 适用场景 |
|---|---|---|---|---|---|
| KV Cache | 50x+ | 0% | 0% | ⭐ | 所有自回归模型(必备) |
| MQA | 2x | 96% | 3-5% | ⭐⭐ | 极致推理速度场景 |
| GQA | 1.3x | 75% | <1% | ⭐⭐ | 推荐,平衡方案 |
| Sparse Attn | 10x+ | 50%+ | 0-5% | ⭐⭐⭐⭐ | 超长文本(100K+) |
| MoE | 5x | 70% | 0% | ⭐⭐⭐⭐⭐ | 超大模型,计算受限 |
7.2 选择决策树
graph TD
Start{需求是什么?} --> Q1{序列长度?}
Q1 -->|<4K| Short[标准场景]
Q1 -->|4K-32K| Medium[中长文本]
Q1 -->|>32K| Long[超长文本]
Short --> Q2{显存限制?}
Q2 -->|宽松| Use_MHA[使用标准MHA<br/>+ KV Cache]
Q2 -->|紧张| Use_GQA[使用GQA<br/>+ KV Cache]
Medium --> Q3{质量要求?}
Q3 -->|最高| MHA_Long[MHA + KV Cache]
Q3 -->|平衡| GQA_Long[GQA + Sliding Window]
Long --> Sparse[Sparse Attention<br/>必选方案]
Start --> Q4{是否超大模型?}
Q4 -->|>100B| Consider_MoE[考虑MoE架构]
style Use_GQA fill:#fff59d
style GQA_Long fill:#fff59d
style Sparse fill:#a5d6a7
style Consider_MoE fill:#81c784
7.3 工业界实践
OpenAI GPT系列:
- GPT-3: MHA + KV Cache
- GPT-3.5/4: 推测 MQA/GQA + Sparse + MoE
Meta LLaMA系列:
- LLaMA: MHA + KV Cache
- LLaMA2: GQA-8 + KV Cache (黄金组合)
- LLaMA3: GQA + 更长上下文
Google PaLM/Gemini:
- PaLM: MQA + KV Cache
- PaLM2: MQA改进版
Anthropic Claude:
- Claude 1/2: 推测 GQA + Sparse
- Claude 3: Sparse Attention (200K上下文)
八、实战:构建一个优化的Transformer
完整代码
class OptimizedTransformerBlock(nn.Module):
"""
集成GQA + KV Cache的优化Transformer Block
"""
def __init__(self, d_model, n_heads, n_kv_groups, d_ff, dropout=0.1):
super().__init__()
# GQA
self.gqa = GroupedQueryAttention(d_model, n_heads, n_kv_groups)
# FFN
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Linear(d_ff, d_model)
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, cache=None, use_cache=False):
# Self-attention with cache
attn_out, new_cache = self.gqa(x, cache=cache, use_cache=use_cache)
x = self.norm1(x + self.dropout(attn_out))
# FFN
ffn_out = self.ffn(x)
x = self.norm2(x + self.dropout(ffn_out))
if use_cache:
return x, new_cache
return x
# LLaMA2-7B配置
d_model = 4096
n_heads = 32
n_kv_groups = 8 # GQA-8
d_ff = 11008
n_layers = 32
# 构建完整模型
class OptimizedLLM(nn.Module):
def __init__(self, vocab_size, d_model, n_heads, n_kv_groups, d_ff, n_layers):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.layers = nn.ModuleList([
OptimizedTransformerBlock(d_model, n_heads, n_kv_groups, d_ff)
for _ in range(n_layers)
])
self.lm_head = nn.Linear(d_model, vocab_size)
def forward(self, input_ids, caches=None, use_cache=False):
x = self.embedding(input_ids)
new_caches = []
for i, layer in enumerate(self.layers):
cache = caches[i] if caches else None
if use_cache:
x, new_cache = layer(x, cache=cache, use_cache=True)
new_caches.append(new_cache)
else:
x = layer(x)
logits = self.lm_head(x)
if use_cache:
return logits, new_caches
return logits
# 使用示例
vocab_size = 32000
model = OptimizedLLM(vocab_size, d_model, n_heads, n_kv_groups, d_ff, n_layers)
print(f"模型参数量: {sum(p.numel() for p in model.parameters())/1e9:.2f}B")
# 输出: 模型参数量: 6.74B (接近LLaMA2-7B)
九、总结与展望
9.1 核心技术总结
mindmap
root((现代LLM优化))
推理加速
KV Cache
缓存历史KV
O(n²)→O(n)
Flash Attention
IO优化
SRAM计算
显存优化
MQA
共享KV
节省96%
GQA
分组共享
节省75%
长文本
Sparse Attention
滑动窗口
全局+局部
RoPE
相对位置编码
超大模型
MoE
稀疏激活
专家路由
模型并行
专家并行
张量并行
9.2 未来趋势
1. 更长的上下文
- 目标: 100万token上下文
- 技术: 混合注意力模式、分层记忆
2. 更高效的架构
- 线性Attention (RWKV, RetNet)
- 状态空间模型 (Mamba)
3. 动态计算
- 早停机制 (Early Exit)
- 自适应计算 (Adaptive Computation)
4. 硬件协同优化
- 定制芯片(TPU, Groq)
- 混合精度(FP8, INT4)
十、练习与资源
练习题
1. 计算KV Cache节省
# 给定LLaMA2-13B配置,计算生成1000个token的KV Cache大小
# n_layers=40, n_heads=40, d_k=128, seq_len=1000
2. 实现Sliding Window Mask
def create_sliding_window_mask(seq_len, window_size):
# TODO: 实现并可视化
pass
3. 对比GQA不同配置
# 实验GQA-4 vs GQA-8 vs MHA的性能和显存