从自回归生成说起
在前面的章节中,我们学习了大模型的核心原理:给定前面的Token序列,预测下一个Token。但是,当我们实际使用大模型进行文本生成时,会遇到一个严重的性能问题。
自回归生成的过程
假设我们要让模型生成一句话:"今天天气真好"(5个Token)
第1步:输入提示词"今天"
- 输入序列:["今天"](1个Token)
- 模型计算注意力,输出:["天气"]
第2步:继续生成
- 输入序列:["今天", "天气"](2个Token)
- 模型重新计算这2个Token的注意力,输出:["真"]
第3步:继续生成
- 输入序列:["今天", "天气", "真"](3个Token)
- 模型重新计算这3个Token的注意力,输出:["好"]
注意到问题了吗?每次生成新Token时,模型都要重新计算前面所有Token的注意力!
重复计算的代价
让我们用数学来量化这个问题。
注意力计算回顾
在注意力机制中,对于每个Token,我们需要计算:
重复计算示例
假设我们要生成长度为100的文本,每个生成步骤的计算量:
| 步骤 | 序列长度 | 需要计算的Token数 | 累计计算量 |
|---|---|---|---|
| 1 | 1 | 1 | 1 |
| 2 | 2 | 2 | 1+2=3 |
| 3 | 3 | 3 | 1+2+3=6 |
| ... | ... | ... | ... |
| 100 | 100 | 100 | 1+2+...+100=5050 |
总计算量: 次Token的注意力计算
但实际上,真正需要的计算量只有100次!因为:
- 第1个Token的K、V计算一次就够了
- 第2个Token的K、V计算一次就够了
- ...
- 第100个Token的K、V计算一次就够了
问题的根源:前面Token的K和V在每一步都被重新计算,但它们的值根本不会改变!
KV Cache:缓存已计算的K和V
核心思想
KV Cache的思想非常简单:
既然每个Token的K和V只需要计算一次,那就把它们缓存起来,下次直接使用!
具体来说:
-
第1步:生成第1个Token
- 计算:
- 缓存:保存
- 输出:新Token
-
第2步:生成第2个Token
- 计算:只计算新Token的
- 缓存:保存 ,现在缓存中有
- 使用缓存:直接读取 ,无需重新计算
- 输出:新Token
-
第3步:生成第3个Token
- 计算:只计算新Token的
- 缓存:保存 ,现在缓存中有
- 使用缓存:直接读取 ,无需重新计算
- 输出:新Token
性能提升
使用KV Cache后,生成100个Token的计算量:
| 步骤 | 需要新计算的KV | 从缓存读取的KV | 总计算量 |
|---|---|---|---|
| 1 | 1 | 0 | 1 |
| 2 | 1 | 1 | 2 |
| 3 | 1 | 2 | 3 |
| ... | ... | ... | ... |
| 100 | 1 | 99 | 100 |
总计算量:100次(从5050次降到100次,加速50倍!)
KV Cache的数据结构
缓存的形状
对于一个多头注意力层:
- 输入序列长度:(已生成的Token数)
- 模型维度:(例如:4096)
- 注意力头数:(例如:32)
- 每个头的维度:(例如:128)
- 层数:(例如:32层)
每一层的KV Cache形状:
全模型的KV Cache形状(所有层):
内存占用计算
假设使用FP16精度(每个数2字节),模型参数:
- 层
- 头
- 序列长度
单个样本的KV Cache内存:
Batch推理的内存(batch_size=32):
这就是为什么大模型推理需要大显存的原因之一!
实际例子:不同模型的KV Cache
| 模型 | 层数 | d_model | 头数 | 序列长度 | 单样本KV Cache | Batch=32 |
|---|---|---|---|---|---|---|
| GPT-2 Small | 12 | 768 | 12 | 1024 | 36 MB | 1.1 GB |
| LLaMA-7B | 32 | 4096 | 32 | 2048 | 1 GB | 32 GB |
| LLaMA-13B | 40 | 5120 | 40 | 2048 | 1.6 GB | 51 GB |
| LLaMA-65B | 80 | 8192 | 64 | 2048 | 5.1 GB | 163 GB |
| GPT-3 175B | 96 | 12288 | 96 | 2048 | 9.2 GB | 294 GB |
可以看到,对于超大模型,KV Cache可能比模型权重本身还要占用更多显存!
KV Cache的实现细节
伪代码实现
class MultiHeadAttentionWithKVCache:
def __init__(self, d_model, num_heads):
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
# 权重矩阵
self.W_Q = Parameter(torch.randn(d_model, d_model))
self.W_K = Parameter(torch.randn(d_model, d_model))
self.W_V = Parameter(torch.randn(d_model, d_model))
self.W_O = Parameter(torch.randn(d_model, d_model))
# KV Cache(初始为空)
self.k_cache = [] # List of cached K tensors
self.v_cache = [] # List of cached V tensors
def forward(self, x, use_cache=True):
"""
x: 输入Token的embedding,形状 (batch_size, 1, d_model)
注意:推理时每次只输入1个新Token
"""
batch_size = x.shape[0]
# 计算新Token的Q、K、V
Q_new = x @ self.W_Q # (batch_size, 1, d_model)
K_new = x @ self.W_K # (batch_size, 1, d_model)
V_new = x @ self.W_V # (batch_size, 1, d_model)
# 重塑为多头形状
Q_new = Q_new.view(batch_size, 1, self.num_heads, self.d_k)
K_new = K_new.view(batch_size, 1, self.num_heads, self.d_k)
V_new = V_new.view(batch_size, 1, self.num_heads, self.d_k)
if use_cache:
# 将新的K、V添加到缓存
self.k_cache.append(K_new)
self.v_cache.append(V_new)
# 拼接所有历史K、V
K = torch.cat(self.k_cache, dim=1) # (batch, seq_len, heads, d_k)
V = torch.cat(self.v_cache, dim=1)
else:
K = K_new
V = V_new
# 计算注意力
# Q: (batch, 1, heads, d_k) - 只有新Token的Query
# K: (batch, seq_len, heads, d_k) - 所有Token的Key(包括历史)
scores = torch.einsum('bqhd,bkhd->bhqk', Q_new, K) / math.sqrt(self.d_k)
# scores: (batch, heads, 1, seq_len)
attn_weights = F.softmax(scores, dim=-1)
# 加权求和
output = torch.einsum('bhqk,bkhd->bqhd', attn_weights, V)
# output: (batch, 1, heads, d_k)
# 重塑并投影
output = output.reshape(batch_size, 1, self.d_model)
output = output @ self.W_O
return output
def clear_cache(self):
"""清空KV Cache,开始新的生成任务"""
self.k_cache = []
self.v_cache = []
关键点解析
-
只计算新Token的K和V:
K_new = x @ self.W_K # x的形状是(batch, 1, d_model),只有1个Token -
从缓存读取历史K、V:
K = torch.cat(self.k_cache, dim=1) # 拼接所有历史Token的K -
注意力计算使用完整的K、V:
# Q只有1个Token(新Token) # K、V有n个Token(所有历史Token + 新Token) scores = torch.einsum('bqhd,bkhd->bhqk', Q_new, K)
KV Cache与位置编码的关系
这里有一个非常重要的问题:当我们使用KV Cache时,位置编码怎么办?
绝对位置编码的问题
回顾一下绝对位置编码(Sinusoidal或Learned):
其中 是第 个位置的位置编码。
问题:当使用KV Cache时,每次只输入1个新Token,但这个Token的绝对位置在不断变化!
举例:
- 第1步:输入Token的位置是0,PE[0]
- 第2步:输入Token的位置是1,PE[1]
- 第3步:输入Token的位置是2,PE[2]
- ...
看起来没问题?但实际上有个隐藏的问题:
缓存的K、V已经包含了位置编码信息:
- 第1个Token的K、V计算时使用了PE[0]
- 第2个Token的K、V计算时使用了PE[1]
- ...
所以,绝对位置编码在KV Cache场景下是兼容的,但需要注意:
- 必须传入正确的位置索引(当前Token是第几个)
- 位置编码表必须足够长(支持最大序列长度)
位置编码表(Position Embedding Table)
在实际实现中,位置编码通常预先计算并存储在一个位置编码表中:
class PositionalEncoding:
def __init__(self, d_model, max_seq_len=5000):
# 预先计算所有位置的编码
self.pe_table = torch.zeros(max_seq_len, d_model)
position = torch.arange(0, max_seq_len).unsqueeze(1) # (max_seq_len, 1)
div_term = torch.exp(
torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
)
# 偶数维度用sin
self.pe_table[:, 0::2] = torch.sin(position * div_term)
# 奇数维度用cos
self.pe_table[:, 1::2] = torch.cos(position * div_term)
def get_position_encoding(self, position):
"""
position: 当前Token的位置索引(标量)
返回: 该位置的位置编码向量 (d_model,)
"""
return self.pe_table[position]
使用KV Cache时的流程:
# 第1步:生成第1个Token(位置0)
x_0 = token_embedding(token_0) + pe_table[0] # 加上位置0的编码
output_0 = attention(x_0)
# 第2步:生成第2个Token(位置1)
x_1 = token_embedding(token_1) + pe_table[1] # 加上位置1的编码
output_1 = attention(x_1) # 使用KV Cache,读取位置0的K、V
# 第3步:生成第3个Token(位置2)
x_2 = token_embedding(token_2) + pe_table[2] # 加上位置2的编码
output_2 = attention(x_2) # 使用KV Cache,读取位置0、1的K、V
RoPE与KV Cache
RoPE(Rotary Position Embedding)是一种更现代的位置编码方式,它在计算注意力时动态地将位置信息旋转到Q和K中。
RoPE的优势:
- 相对位置敏感:注意力分数只依赖于Token之间的相对距离
- 无需位置编码表:位置信息通过旋转矩阵动态计算
- 与KV Cache完美兼容:缓存的K已经包含了正确的位置信息
RoPE在KV Cache中的应用:
def apply_rotary_pos_emb(q, k, position):
"""
应用旋转位置编码
q, k: (batch, seq_len, heads, d_k)
position: 当前Token的绝对位置
"""
# 计算旋转角度
theta = position / (10000 ** (torch.arange(0, d_k, 2) / d_k))
# 构造旋转矩阵
cos = torch.cos(theta)
sin = torch.sin(theta)
# 旋转Q和K
q_rot = apply_rotation(q, cos, sin)
k_rot = apply_rotation(k, cos, sin)
return q_rot, k_rot
# 使用KV Cache时
Q_new = apply_rotary_pos_emb(Q_new, position=current_position)
K_new = apply_rotary_pos_emb(K_new, position=current_position)
# 缓存已旋转的K
k_cache.append(K_new)
关键点:
- 每个Token的K在计算时就已经包含了其位置信息(通过旋转)
- 缓存的K不需要再次旋转
- 新Token的Q旋转时使用其当前位置
- 注意力计算时,Q和K的相对位置关系自动体现在旋转角度的差值中
位置编码与KV Cache的总结
| 位置编码类型 | 与KV Cache的兼容性 | 注意事项 |
|---|---|---|
| 绝对位置编码(Sinusoidal) | ✅ 兼容 | 需要预先计算位置编码表,传入正确的位置索引 |
| 绝对位置编码(Learned) | ✅ 兼容 | 同上,位置编码表是可学习参数 |
| RoPE | ✅ 完美兼容 | 缓存的K已包含位置信息,无需额外处理 |
| ALiBi | ✅ 完美兼容 | 位置偏置在计算注意力时动态添加 |
KV Cache的变体与优化
1. Multi-Query Attention (MQA)
问题:标准多头注意力中,每个头都有自己的K和V,导致KV Cache很大。
解决方案:所有头共享一组K和V。
优势:
- KV Cache大小减少 倍(头数)
- 例如:32头变成1组K、V,内存占用减少32倍
劣势:
- 表达能力下降(所有头看到相同的K、V)
2. Grouped-Query Attention (GQA)
折中方案:将头分成若干组,每组共享K和V。
例如:32个头分成4组,每组8个头共享一组K、V。
优势:
- KV Cache减少 倍(是组数)
- 保留了一定的多头表达能力
实际应用:
- LLaMA-2:使用GQA(8组)
- Mistral:使用GQA(8组)
3. Paged Attention
问题:KV Cache是连续内存块,当序列很长时,可能无法分配足够大的连续内存。
解决方案:将KV Cache分成固定大小的"页",类似操作系统的虚拟内存。
# 传统KV Cache:连续内存
k_cache = torch.zeros(batch, seq_len, heads, d_k) # 需要连续的seq_len空间
# Paged Attention:分页存储
page_size = 16 # 每页存储16个Token的K/V
num_pages = seq_len // page_size
k_cache_pages = [
torch.zeros(batch, page_size, heads, d_k) for _ in range(num_pages)
]
优势:
- 内存碎片友好
- 支持动态序列长度
- 减少内存浪费(不需要预先分配最大长度)
实际应用:
- vLLM:使用Paged Attention实现高效的批量推理
实际生成示例
让我们通过一个完整的例子来理解KV Cache的工作流程。
任务:生成文本 "今天天气真好"
初始状态:
- Prompt(用户输入):无(从头开始生成)
- KV Cache:空
Step 1:生成"今天"
输入:<BOS>(开始标记)
位置编码:PE[0]
计算:Q_0, K_0, V_0
KV Cache:K_0, V_0
输出:"今天"
Step 2:生成"天气"
输入:"今天"
位置编码:PE[1]
计算:Q_1, K_1, V_1(只计算新Token)
KV Cache:[K_0, K_1], [V_0, V_1](添加新的K、V)
注意力:Q_1 attend to [K_0, K_1]
输出:"天气"
Step 3:生成"真"
输入:"天气"
位置编码:PE[2]
计算:Q_2, K_2, V_2
KV Cache:[K_0, K_1, K_2], [V_0, V_1, V_2]
注意力:Q_2 attend to [K_0, K_1, K_2]
输出:"真"
Step 4:生成"好"
输入:"真"
位置编码:PE[3]
计算:Q_3, K_3, V_3
KV Cache:[K_0, K_1, K_2, K_3], [V_0, V_1, V_2, V_3]
注意力:Q_3 attend to [K_0, K_1, K_2, K_3]
输出:"好"
性能对比
不使用KV Cache(每步重新计算):
- Step 1:计算1个Token的KV → 1次
- Step 2:计算2个Token的KV → 2次
- Step 3:计算3个Token的KV → 3次
- Step 4:计算4个Token的KV → 4次
- 总计:1+2+3+4 = 10次KV计算
使用KV Cache:
- Step 1:计算K_0, V_0 → 1次
- Step 2:计算K_1, V_1 → 1次
- Step 3:计算K_2, V_2 → 1次
- Step 4:计算K_3, V_3 → 1次
- 总计:4次KV计算(加速2.5倍)
对于更长的序列(例如2048个Token),加速比接近1024倍!
KV Cache的管理策略
1. 固定长度截断
当序列超过最大长度时,丢弃最早的Token:
max_cache_len = 2048
if len(k_cache) >= max_cache_len:
# 移除最早的Token
k_cache.pop(0)
v_cache.pop(0)
# 添加新Token
k_cache.append(K_new)
v_cache.append(V_new)
优势:简单,内存可控 劣势:可能丢失重要的历史信息
2. 滑动窗口
只保留最近的N个Token:
window_size = 512
if len(k_cache) >= window_size:
k_cache = k_cache[-window_size:]
v_cache = v_cache[-window_size:]
优势:专注于局部上下文 劣势:无法建模长距离依赖
3. 重要性采样
根据注意力权重,保留重要的Token:
def prune_cache_by_attention(k_cache, v_cache, attention_weights, keep_ratio=0.5):
# 计算每个Token的平均注意力分数
importance = attention_weights.mean(dim=(0, 1)) # (seq_len,)
# 选择重要性最高的Token
num_keep = int(len(k_cache) * keep_ratio)
keep_indices = torch.topk(importance, num_keep).indices
# 只保留重要的Token
k_cache = [k_cache[i] for i in keep_indices]
v_cache = [v_cache[i] for i in keep_indices]
return k_cache, v_cache
优势:保留关键信息 劣势:计算复杂度高
4. H2O(Heavy-Hitter Oracle)
最新的研究表明,大多数注意力权重集中在少数"重要"Token上:
- Heavy Hitters:注意力权重最高的Token(例如标点符号、关键词)
- Recent Tokens:最近生成的Token
策略:只缓存Heavy Hitters + Recent Tokens
def h2o_cache_management(k_cache, v_cache, attention_weights,
heavy_ratio=0.1, recent_ratio=0.1):
seq_len = len(k_cache)
# 计算累积注意力分数
cumulative_attention = attention_weights.sum(dim=(0, 1, 2)) # (seq_len,)
# 选择Heavy Hitters
num_heavy = int(seq_len * heavy_ratio)
heavy_indices = torch.topk(cumulative_attention, num_heavy).indices
# 选择Recent Tokens
num_recent = int(seq_len * recent_ratio)
recent_indices = torch.arange(seq_len - num_recent, seq_len)
# 合并索引
keep_indices = torch.cat([heavy_indices, recent_indices]).unique()
# 只保留选中的Token
k_cache = [k_cache[i] for i in keep_indices]
v_cache = [v_cache[i] for i in keep_indices]
return k_cache, v_cache
优势:
- 大幅减少缓存大小(可减少90%)
- 几乎不损失性能
小结
KV Cache的核心思想
- 问题:自回归生成时,每步都重新计算前面所有Token的K和V,导致大量重复计算
- 解决方案:缓存已计算的K和V,每步只计算新Token的K和V
- 性能提升:对于长度N的序列,从 降到 ,加速可达N/2倍
位置编码与KV Cache
- 绝对位置编码:需要预先计算位置编码表,确保每个Token使用正确的位置索引
- RoPE:通过旋转矩阵动态编码位置,与KV Cache完美兼容
- ALiBi:通过注意力偏置编码位置,与KV Cache完美兼容
内存优化技术
- Multi-Query Attention (MQA):所有头共享K和V,减少KV Cache 倍
- Grouped-Query Attention (GQA):头分组共享K和V,平衡性能和内存
- Paged Attention:分页存储KV Cache,减少内存碎片
- H2O:只缓存重要Token,减少90%缓存大小
实际应用
- OpenAI GPT:使用KV Cache + 绝对位置编码
- Meta LLaMA:使用KV Cache + RoPE + GQA
- vLLM:使用KV Cache + Paged Attention,实现高效批量推理
- DeepSeek:使用KV Cache + MLA(Multi-head Latent Attention),进一步压缩KV Cache
KV Cache是大模型推理加速的基石技术,几乎所有现代推理系统都依赖它来实现实时交互。理解KV Cache的原理,对于优化大模型部署和推理性能至关重要。