PagedAttention、FlashAttention与投机采样:推理优化三大技术
大模型推理的三大瓶颈
在上一章中,我们学习了KV Cache如何通过缓存已计算的K和V来加速推理。但即使有了KV Cache,大模型推理仍然面临三个核心瓶颈:
瓶颈1:内存管理效率低(PagedAttention解决)
问题:传统的KV Cache需要预先分配连续内存块
# 传统方式:预分配最大长度的内存
max_seq_len = 4096
k_cache = torch.zeros(batch_size, max_seq_len, num_heads, head_dim)
v_cache = torch.zeros(batch_size, max_seq_len, num_heads, head_dim)
带来的问题:
- 即使实际只生成100个Token,也要分配4096个Token的空间
- 内存利用率低(浪费率可达60-80%)
- 无法支持动态长度的批处理
瓶颈2:显存访问效率低(FlashAttention解决)
问题:标准注意力计算需要大量的显存读写
注意力计算公式:
传统实现的显存访问:
# 步骤1:计算注意力分数(写入HBM)
scores = Q @ K.T # 形状:(seq_len, seq_len),写入显存
# 步骤2:Softmax(从HBM读取,写回HBM)
attn_weights = softmax(scores / sqrt(d_k)) # 读取scores,写回显存
# 步骤3:加权求和(从HBM读取V和attn_weights)
output = attn_weights @ V # 读取attn_weights和V
问题:中间结果(scores、attn_weights)频繁读写显存,显存带宽成为瓶颈。
对于序列长度 ,中间矩阵 占用 16MB(FP16),需要反复读写。
瓶颈3:自回归生成速度慢(投机采样解决)
问题:每次只生成1个Token,GPU利用率低
# 自回归生成:串行处理
for i in range(100):
next_token = model(input_tokens) # 生成1个Token
input_tokens.append(next_token) # 等待上一步完成
带来的问题:
- GPU大部分时间在等待内存访问
- 计算单元利用率低(<20%)
- 生成100个Token需要100次模型前向传播
本章将详细讲解这三大优化技术如何解决这些问题。
技术1:PagedAttention - 虚拟内存管理
核心思想
PagedAttention借鉴操作系统的虚拟内存思想,将KV Cache分成固定大小的"页(Page)",按需分配。
类比理解:
- 传统KV Cache:像酒店提前包下整层楼(100间房),即使只住10个人
- PagedAttention:像酒店按需分配房间,只分配实际需要的房间数
传统KV Cache的问题
假设我们要处理一个batch的请求:
| 请求ID | 实际长度 | 分配内存 | 利用率 |
|---|---|---|---|
| 请求1 | 100 Token | 4096 Token | 2.4% |
| 请求2 | 500 Token | 4096 Token | 12.2% |
| 请求3 | 50 Token | 4096 Token | 1.2% |
| 请求4 | 200 Token | 4096 Token | 4.9% |
总内存:4 × 4096 × 层数 × 头数 × 头维度 实际使用:(100+500+50+200) × 层数 × 头数 × 头维度 利用率:850 / (4×4096) = 5.2%
PagedAttention的解决方案
步骤1:将KV Cache分页
将连续的KV Cache分成固定大小的页(例如每页16个Token):
# 页的大小(每页存储多少个Token)
PAGE_SIZE = 16
# 传统KV Cache:连续内存
k_cache_traditional = torch.zeros(batch, max_seq_len, num_heads, head_dim)
# PagedAttention:分页存储
# 每一页是一个独立的tensor
k_cache_paged = [
torch.zeros(batch, PAGE_SIZE, num_heads, head_dim)
for _ in range(max_seq_len // PAGE_SIZE)
]
步骤2:维护页表(Page Table)
为每个请求维护一个页表,记录其KV Cache使用了哪些页:
class PageTable:
def __init__(self):
self.pages = [] # 存储该请求使用的物理页索引
def allocate_page(self, physical_page_id):
"""分配一个新页"""
self.pages.append(physical_page_id)
def get_physical_page(self, virtual_page_id):
"""将虚拟页号转换为物理页号"""
return self.pages[virtual_page_id]
# 每个请求有自己的页表
request_1_page_table = PageTable()
request_2_page_table = PageTable()
步骤3:按需分配页
当请求生成新Token时,动态分配页:
class PagedKVCache:
def __init__(self, num_layers, num_heads, head_dim, page_size=16):
self.page_size = page_size
self.num_layers = num_layers
self.num_heads = num_heads
self.head_dim = head_dim
# 物理内存池:所有空闲的页
self.free_pages = []
self.used_pages = {} # page_id -> tensor
# 为每个请求维护页表
self.page_tables = {} # request_id -> PageTable
def allocate_page(self):
"""从内存池中分配一个空闲页"""
if self.free_pages:
page_id = self.free_pages.pop()
else:
# 创建新页
page_id = len(self.used_pages)
self.used_pages[page_id] = torch.zeros(
self.num_layers, self.page_size, self.num_heads, self.head_dim
)
return page_id
def append_kv(self, request_id, layer_id, k_new, v_new):
"""
为请求添加新的K、V
k_new, v_new: (1, num_heads, head_dim) - 单个Token的K、V
"""
if request_id not in self.page_tables:
# 新请求:创建页表并分配第一页
self.page_tables[request_id] = PageTable()
page_id = self.allocate_page()
self.page_tables[request_id].allocate_page(page_id)
page_table = self.page_tables[request_id]
num_tokens = len(page_table.pages) * self.page_size
# 计算当前Token在哪一页
page_idx = num_tokens // self.page_size
offset_in_page = num_tokens % self.page_size
# 如果当前页已满,分配新页
if offset_in_page == 0 and num_tokens > 0:
page_id = self.allocate_page()
page_table.allocate_page(page_id)
page_idx = len(page_table.pages) - 1
# 获取物理页
physical_page_id = page_table.get_physical_page(page_idx)
physical_page = self.used_pages[physical_page_id]
# 写入K、V
physical_page[layer_id, offset_in_page] = k_new
def get_kv(self, request_id, layer_id):
"""
获取请求的所有K、V
返回:(seq_len, num_heads, head_dim)
"""
page_table = self.page_tables[request_id]
k_list = []
for page_idx in range(len(page_table.pages)):
physical_page_id = page_table.get_physical_page(page_idx)
physical_page = self.used_pages[physical_page_id]
# 读取该页的K
k_page = physical_page[layer_id] # (page_size, num_heads, head_dim)
k_list.append(k_page)
# 拼接所有页
k_all = torch.cat(k_list, dim=0) # (seq_len, num_heads, head_dim)
return k_all
def free_request(self, request_id):
"""释放请求占用的所有页"""
if request_id in self.page_tables:
page_table = self.page_tables[request_id]
# 将所有页归还到空闲池
self.free_pages.extend(page_table.pages)
del self.page_tables[request_id]
PagedAttention的优势
1. 内存利用率提升
示例:处理4个请求,实际长度分别为100、500、50、200
传统方式:
- 分配内存:4 × 4096 = 16384 个Token位置
- 实际使用:100 + 500 + 50 + 200 = 850 个Token位置
- 利用率:5.2%
PagedAttention(页大小=16):
- 请求1:需要 100/16 ≈ 7页
- 请求2:需要 500/16 ≈ 32页
- 请求3:需要 50/16 ≈ 4页
- 请求4:需要 200/16 ≈ 13页
- 总页数:56页
- 分配内存:56 × 16 = 896 个Token位置
- 实际使用:850 个Token位置
- 利用率:95%(提升18倍!)
2. 支持动态批处理
传统方式必须等待所有请求完成才能释放内存,PagedAttention可以:
# 请求1完成,立即释放内存
paged_cache.free_request(request_1)
# 立即用释放的内存处理新请求
paged_cache.append_kv(request_5, layer_id, k_new, v_new)
3. 内存共享
对于共享前缀的请求,可以共享页:
# 请求1:"解释一下深度学习"
# 请求2:"解释一下深度学习的原理"
# 前缀"解释一下深度学习"的KV Cache可以共享
prefix_pages = request_1_page_table.pages[:3] # 前3页
request_2_page_table.pages = prefix_pages + [new_page] # 共享前缀,只分配新页
应用场景:
- 批量处理相似问题
- Few-shot Learning(示例前缀相同)
- 系统提示词(所有请求共享)
实际性能提升
根据vLLM(实现了PagedAttention)的论文数据:
| 指标 | 传统KV Cache | PagedAttention | 提升 |
|---|---|---|---|
| 内存利用率 | 20% | 90% | 4.5倍 |
| 吞吐量(请求/秒) | 0.5 | 2.3 | 4.6倍 |
| 支持的最大batch size | 8 | 64 | 8倍 |
技术2:FlashAttention - 显存访问优化
GPU内存层次结构
在理解FlashAttention之前,需要深入了解GPU的硬件架构和内存系统。
GPU整体架构
现代GPU(以NVIDIA A100为例)的基本结构:
┌──────────────────────────────────────────────────────────────┐
│ GPU芯片 │
│ │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ SM (Streaming Multiprocessor) 1 │ │
│ │ ┌──────────┐ ┌──────────┐ ┌─────────────────┐ │ │
│ │ │ CUDA │ │ CUDA │ │ Shared Memory │ │ │
│ │ │ Cores │ │ Cores │ │ (SRAM) │ │ │
│ │ │ (计算单元)│ │ (计算单元)│ │ ~200 KB │ │ │
│ │ └──────────┘ └──────────┘ └─────────────────┘ │ │
│ │ ┌────────────────────────────────────────────┐ │ │
│ │ │ Registers (寄存器) │ │ │
│ │ │ ~256 KB │ │ │
│ │ └────────────────────────────────────────────┘ │ │
│ └─────────────────────────────────────────────────────┘ │
│ │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ SM 2 (类似结构) │ │
│ └─────────────────────────────────────────────────────┘ │
│ │
│ ... (A100有108个SM) │
│ │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ L2 Cache (40 MB) │ │
│ └─────────────────────────────────────────────────────┘ │
│ ↓↑ │
└──────────────────────│────────────────────────────────────────┘
│
内存总线
│
┌─────────────▼──────────────┐
│ HBM (High Bandwidth Memory) │
│ (显存: 40GB/80GB) │
└──────────────────────────────┘
关键组件:
-
SM (Streaming Multiprocessor):GPU的核心计算单元
- A100有108个SM
- 每个SM包含多个CUDA核心(计算单元)
- 每个SM有自己的SRAM(共享内存和寄存器)
-
CUDA Core:实际执行计算的单元
- 类似于CPU的核心
- 一个SM包含64个CUDA核心
-
L2 Cache:所有SM共享的缓存
- 40 MB(A100)
- 作为SRAM和HBM之间的中间层
HBM (High Bandwidth Memory) - 显存
什么是HBM?
HBM是GPU的主存储器,相当于计算机的RAM,但专为GPU设计。
物理特性:
HBM芯片结构:
┌──────────────────┐
│ GPU Die │ ← GPU核心芯片
├──────────────────┤
│ Interposer │ ← 硅中介层(高速连接)
├──────────────────┤
│ HBM Stack 1 │ ← 内存堆叠芯片
│ HBM Stack 2 │
│ HBM Stack 3 │
│ HBM Stack 4 │
└──────────────────┘
特点:
- 垂直堆叠多层DRAM芯片
- 通过TSV (Through-Silicon Via) 连接
- 极宽的总线 (5120-bit on A100)
HBM的特点:
-
容量大:
- A100:40GB/80GB
- H100:80GB
- 存储模型参数、激活值、KV Cache等
-
带宽高(相对于传统GDDR):
- A100 HBM2e:1.5-2.0 TB/s
- H100 HBM3:3.0 TB/s
- 比传统GDDR快2-3倍
-
延迟相对高:
- 访问延迟:~500-800纳秒
- 相比SRAM慢10-100倍
-
物理位置:
- 在GPU芯片外部(但在同一封装内)
- 需要通过内存总线访问
HBM存储的内容:
HBM中存储的数据(AI训练/推理):
1. 模型参数(W, b):几GB到几十GB
2. 激活值(中间结果):几GB
3. KV Cache(推理时):几GB
4. 梯度(训练时):与参数量相当
5. 优化器状态(训练时):参数量的2-3倍
例如:LLaMA-7B推理
- 模型参数:14 GB (FP16)
- KV Cache:~4 GB (batch=8, seq_len=2048)
- 激活值:~2 GB
总计:~20 GB HBM
SRAM (Static RAM) - 片上内存
什么是SRAM?
SRAM是集成在GPU芯片内部的高速缓存,直接在计算单元旁边。
物理特性:
SRAM与HBM的对比:
┌─────────────────────────┐
│ GPU芯片内部 │
│ ┌──────┐ ┌────────┐ │
│ │ CUDA │←→│ SRAM │ │ ← 片上,极近
│ │ Core │ │(100KB) │ │ 延迟: 几纳秒
│ └──────┘ └────────┘ │
└─────────────────────────┘
↓↑
内存总线 (较慢)
↓↑
┌─────────────────────────┐
│ HBM (40GB) │ ← 片外(同封装)
└─────────────────────────┘ 延迟: 几百纳秒
SRAM的类型:
-
Shared Memory(共享内存):
- 每个SM有一块共享内存
- A100:每个SM约164 KB
- 可以被该SM内的所有线程访问
- 程序员可以显式控制
-
Registers(寄存器):
- 每个SM有大量寄存器
- A100:每个SM约256 KB
- 每个线程独占自己的寄存器
- 访问速度最快
-
L1 Cache:
- 与Shared Memory共享物理空间
- 硬件自动管理
- 可配置大小比例
SRAM的特点:
-
速度极快:
- 访问延迟:1-2纳秒
- 带宽:A100每个SM约19 TB/s
- 比HBM快10-100倍
-
容量极小:
- 每个SM:~200 KB(Shared Memory)
- 全GPU:108个SM × 200KB ≈ 20 MB
- 比HBM小2000倍(40GB vs 20MB)
-
物理位置:
- 在GPU芯片内部
- 紧邻计算单元
- 无需经过内存总线
SRAM使用的数据:
SRAM中存储的数据(临时):
1. 正在计算的小块数据
2. 循环变量、临时结果
3. 频繁访问的数据
例如:FlashAttention
- Q的一块:64 × 128 × 2 bytes = 16 KB
- K的一块:64 × 128 × 2 bytes = 16 KB
- V的一块:64 × 128 × 2 bytes = 16 KB
- 中间结果:64 × 64 × 2 bytes = 8 KB
总计:~56 KB(远小于200KB限制)✓
内存层次对比
┌──────────────┬────────────┬──────────┬───────────┬─────────────┐
│ 内存类型 │ 容量 │ 带宽 │ 延迟 │ 物理位置 │
├──────────────┼────────────┼──────────┼───────────┼─────────────┤
│ Registers │ ~256 KB/SM │ 极高 │ 1 ns │ SM内部 │
├──────────────┼────────────┼──────────┼───────────┼─────────────┤
│ Shared Mem │ ~200 KB/SM │ 19 TB/s │ 1-2 ns │ SM内部 │
│ (SRAM) │ 全GPU~20MB │ /SM │ │ │
├──────────────┼────────────┼──────────┼───────────┼─────────────┤
│ L2 Cache │ 40 MB │ ~7 TB/s │ ~50 ns │ GPU芯片内 │
├──────────────┼────────────┼──────────┼───────────┼─────────────┤
│ HBM │ 40-80 GB │ 1.5 TB/s │ 500-800ns │ 芯片外 │
│ (显存) │ │ │ │ (同封装内) │
└──────────────┴────────────┴──────────┴───────────┴─────────────┘
速度对比:
Registers ≈ Shared Memory > L2 Cache >>> HBM
1倍 1倍 3倍 12倍慢
关键数据(A100 GPU):
- SRAM带宽:19 TB/s (每个SM)
- HBM带宽:1.5 TB/s (整个GPU)
- 速度差距:SRAM比HBM快12倍以上
- 容量差距:HBM比SRAM大2000倍以上
为什么速度差异如此大?
物理距离:
数据传输距离(简化示意):
SRAM访问:
CUDA Core → 几微米 → SRAM
传输时间:~1纳秒
HBM访问:
CUDA Core → 几毫米 → 内存控制器 → 几毫米 → HBM
传输时间:~500纳秒
差距来源:
1. 物理距离:500倍
2. 总线宽度:SRAM更宽
3. 访问机制:SRAM更简单
4. 电气特性:SRAM电压切换更快
类比理解:
CPU/GPU内存系统 ≈ 厨房
Registers/SRAM(片上内存):
= 厨师手里的锅和案板
- 极快(伸手就拿到)
- 容量小(只能放当前用的食材)
L2 Cache:
= 灶台旁边的调料架
- 较快(转身就拿到)
- 容量中等
HBM(显存):
= 厨房的冰箱
- 较慢(走几步才能拿到)
- 容量大(存放所有食材)
主内存(CPU RAM):
= 储藏室
- 很慢(要出厨房)
- 容量更大
硬盘:
= 超市
- 极慢(要出门)
- 容量巨大
FlashAttention的核心策略
理解了GPU内存层次后,FlashAttention的策略就很清楚了:
问题:
# 标准注意力
scores = Q @ K.T # 2048×2048矩阵 → 写入HBM(慢)
attn = softmax(scores) # 从HBM读取 → 计算 → 写回HBM(慢)
output = attn @ V # 从HBM读取 → 计算(慢)
# 瓶颈:在HBM和SRAM之间来回传输大矩阵
FlashAttention的解决方案:
# 分块加载到SRAM(快)
for Q_block in Q.chunks(): # 64×128 加载到SRAM
for K_block, V_block in zip(...): # 64×128 加载到SRAM
scores_block = Q_block @ K_block.T # 在SRAM内计算(快)
attn_block = softmax(scores_block) # 在SRAM内计算(快)
output += attn_block @ V_block # 在SRAM内计算(快)
# 只在循环结束时写回HBM
# 关键:大部分计算在SRAM内完成,HBM访问极少
结论:尽量在SRAM中完成计算,减少HBM访问次数,这就是FlashAttention快的根本原因。
标准注意力的显存访问
标准实现
def standard_attention(Q, K, V):
# Q, K, V: (batch, seq_len, num_heads, head_dim)
seq_len = Q.shape[1]
# 步骤1:计算注意力分数(O(n²) 空间)
scores = Q @ K.transpose(-2, -1) / math.sqrt(head_dim)
# scores: (batch, seq_len, num_heads, seq_len)
# 写入HBM:seq_len² × batch × num_heads 个float
# 步骤2:Softmax
attn_weights = F.softmax(scores, dim=-1)
# 从HBM读取scores,写回HBM:2次HBM访问
# 【重要】在训练时,attn_weights会被保存在显存中用于反向传播!
# 步骤3:加权求和
output = attn_weights @ V
# 从HBM读取attn_weights和V:2次HBM读取
return output
# 训练时:scores和attn_weights都会被PyTorch保存在显存中
# 推理时:不需要保存(因为不需要反向传播)
显存访问分析
对于序列长度 ,注意力头数 ,头维度 :
| 操作 | HBM读取 | HBM写入 | 总访问量 |
|---|---|---|---|
| Softmax | |||
总HBM访问量:
对于 :
- 中间矩阵 scores:
- 注意力权重 attn_weights:
- 需要多次读写,总访问量 > 1 GB
问题:HBM访问成为瓶颈,计算单元在等待数据。
训练时的显存占用
标准注意力在训练时需要保存的中间结果:
# 前向传播保存的中间结果(用于反向传播)
scores = Q @ K.T # O(n²) - 必须保存
attn_weights = softmax(scores) # O(n²) - 必须保存
output = attn_weights @ V # 输出
# 显存占用(单层,单头):
# - scores: n² × 2 bytes (FP16)
# - attn_weights: n² × 2 bytes (FP16)
# - 总计:2n² × 2 bytes
实际例子(,32层,32头):
仅注意力的中间结果就占用32GB!这还不包括模型参数、梯度、优化器状态等。
推理时的显存占用:
推理时不需要反向传播,因此不需要保存中间结果,显存占用大幅降低。但即使如此,仍然需要多次HBM访问来读写这些中间矩阵。
FlashAttention的核心思想
在SRAM中完成尽可能多的计算,减少HBM访问次数
关键技术:
- 分块(Tiling):将Q、K、V分成小块,每次只加载一块到SRAM
- 重计算(Recomputation):训练时不保存中间结果,反向传播时重新计算(推理无此开销)
- 在线Softmax:逐块计算Softmax,无需保存完整的注意力矩阵
FlashAttention算法详解
核心挑战:如何逐块计算Softmax?
标准Softmax需要全局信息:
问题:分母需要所有元素的和,但我们每次只加载一部分数据。
解决方案:在线更新Softmax统计量
关键观察:当读取新的数据块时,可以增量更新 和
算法步骤
输入:
- (序列长度 ,维度 )
- 块大小: (Q的块大小), (K、V的块大小)
输出:
- (注意力输出)
算法流程:
def flash_attention(Q, K, V, block_size=64):
"""
Q, K, V: (seq_len, head_dim)
block_size: SRAM块大小
"""
seq_len, head_dim = Q.shape
num_blocks = (seq_len + block_size - 1) // block_size
# 输出和统计量(在HBM中,但会逐块更新)
O = torch.zeros_like(Q)
m = torch.full((seq_len,), -float('inf')) # 每行的最大值
ell = torch.zeros(seq_len) # 每行的归一化分母
# 外层循环:遍历Q的块
for i in range(num_blocks):
# 加载Q的第i块到SRAM
Q_i = Q[i*block_size : (i+1)*block_size] # (block_size, head_dim)
# 初始化该块的输出和统计量
O_i = torch.zeros_like(Q_i)
m_i = torch.full((Q_i.shape[0],), -float('inf'))
ell_i = torch.zeros(Q_i.shape[0])
# 内层循环:遍历K、V的块
for j in range(num_blocks):
# 加载K、V的第j块到SRAM
K_j = K[j*block_size : (j+1)*block_size] # (block_size, head_dim)
V_j = V[j*block_size : (j+1)*block_size]
# 在SRAM中计算注意力分数
scores_ij = Q_i @ K_j.T / math.sqrt(head_dim)
# scores_ij: (Q_block_size, K_block_size)
# 更新统计量
m_i_new = torch.maximum(m_i, scores_ij.max(dim=1).values)
# 重新归一化之前的输出
correction_factor = torch.exp(m_i - m_i_new)
ell_i = ell_i * correction_factor
# 计算当前块的贡献
scores_ij_normalized = torch.exp(scores_ij - m_i_new.unsqueeze(1))
ell_i_new = ell_i + scores_ij_normalized.sum(dim=1)
# 更新输出(加权平均)
O_i = O_i * correction_factor.unsqueeze(1) + scores_ij_normalized @ V_j
# 更新统计量
m_i = m_i_new
ell_i = ell_i_new
# 最终归一化
O_i = O_i / ell_i.unsqueeze(1)
# 写回HBM
O[i*block_size : (i+1)*block_size] = O_i
return O
关键技术细节
1. 在线Softmax更新
假设已经处理了块 ,统计量为 ,现在处理块 :
# 新块的最大值
m_new = max(m_old, max(scores_k))
# 重新归一化之前的结果
correction = exp(m_old - m_new)
ell_old = ell_old * correction
O_old = O_old * correction
# 添加新块的贡献
scores_k_normalized = exp(scores_k - m_new)
ell_new = ell_old + sum(scores_k_normalized)
O_new = O_old + scores_k_normalized @ V_k
2. 不保存中间矩阵
标准注意力在训练时的行为:
在模型训练过程中,标准注意力实现必须保存注意力分数矩阵 :
def standard_attention_training(Q, K, V):
scores = Q @ K.T / sqrt(d_k)
attn_weights = softmax(scores) # 这个必须保存在显存中!
output = attn_weights @ V
# PyTorch会自动保存attn_weights,用于反向传播
# 因为计算梯度时需要:
# d(output)/dQ 依赖于 attn_weights
# d(output)/dK 依赖于 attn_weights
# d(output)/dV 依赖于 attn_weights
return output
为什么必须保存?
在反向传播时,根据链式法则计算梯度需要用到注意力权重:
因此,标准注意力必须在显存中保存 大小的注意力矩阵,这在长序列训练时会消耗大量显存。
FlashAttention的优化:
FlashAttention直接计算输出,不在显存中保存中间的注意力矩阵,只在SRAM中临时计算。
3. 反向传播时的重计算
既然FlashAttention没有保存注意力分数,那么反向传播时怎么办?解决方案:重新计算
def flash_attention_backward(dO, Q, K, V):
# dO: 输出的梯度 (从后续层传回来的)
# 需要计算 dQ, dK, dV
# 关键:我们没有保存attn_weights,所以需要重新计算
# 标准注意力的反向传播会直接使用保存的attn_weights
# FlashAttention选择重新计算,换取显存节省
# 前向重计算:再次分块计算注意力
for i in range(num_blocks):
Q_i = Q[i*block_size : (i+1)*block_size]
for j in range(num_blocks):
K_j = K[j*block_size : (j+1)*block_size]
V_j = V[j*block_size : (j+1)*block_size]
# 重新计算注意力分数(这就是"重计算")
scores_ij = Q_i @ K_j.T / sqrt(head_dim)
attn_ij = softmax(scores_ij) # 临时计算,不保存到显存
# 使用重新计算的attn_ij来计算梯度
dV_j = attn_ij.T @ dO_i # 计算V的梯度
dAttn = dO_i @ V_j.T # 计算attention的梯度
dScores = softmax_backward(dAttn, attn_ij) # Softmax的梯度
dQ_i += dScores @ K_j / sqrt(head_dim) # Q的梯度
dK_j += dScores.T @ Q_i / sqrt(head_dim) # K的梯度
权衡分析(Trade-off):
| 方法 | 前向显存占用 | 反向计算量 | HBM访问 | 总体性能 |
|---|---|---|---|---|
| 标准注意力 | 1倍 | 高 | 慢 | |
| FlashAttention | 2倍(重计算) | 低 | 快 |
为什么FlashAttention更快?
- 显存节省:,可以训练更长序列或更大batch
- HBM访问减少:减少10-100倍,这是主要瓶颈
- 计算量增加可以接受:现代GPU的计算能力远超显存带宽,多算一遍前向几乎不影响总时间
结论:虽然重计算增加了约2倍的计算量,但由于大幅减少了显存占用(10倍)和HBM访问(10倍),总体仍然比标准注意力快2-9倍。
训练 vs 推理:FlashAttention的不同优势
重要区分:FlashAttention在训练和推理场景下的优势是不同的。
训练场景
标准注意力:
# 前向传播
scores = Q @ K.T
attn_weights = softmax(scores) # 保存到显存(用于反向传播)
output = attn_weights @ V
# 反向传播
# 直接使用保存的 attn_weights 计算梯度
dV = attn_weights.T @ dO
dQ = ... # 使用 attn_weights
FlashAttention:
# 前向传播
output = flash_attention_forward(Q, K, V)
# 不保存 attn_weights,节省显存
# 反向传播
# 重新计算 attn_weights(重计算)
attn_weights = recompute_attention(Q, K, V)
dV = attn_weights.T @ dO
dQ = ...
训练时的优势:
- ✅ 显存节省:不保存 的中间矩阵,节省10-100倍显存
- ✅ HBM访问减少:减少10倍显存访问
- ⚠️ 计算量增加:反向传播时重新计算前向,增加约2倍计算量
- ✅ 总体更快:显存带宽是瓶颈,计算量增加影响小
推理场景
标准注意力:
# 推理时不需要反向传播
scores = Q @ K.T # 临时存储在HBM
attn_weights = softmax(scores) # 临时存储在HBM
output = attn_weights @ V
# 虽然不保存用于梯度,但计算过程仍需要临时HBM空间
# 仍然有大量的HBM读写
FlashAttention:
# 推理时同样不需要反向传播
output = flash_attention_forward(Q, K, V)
# 在SRAM中分块计算,几乎不使用HBM临时空间
推理时的优势:
- ✅ HBM访问减少:仍然减少10倍显存访问(主要优势)
- ✅ 临时显存减少:不需要分配 的临时空间
- ✅ 无重计算开销:推理不需要反向传播,没有重计算的概念
- ✅ 纯粹的性能提升:只有好处,没有权衡
关键点:
推理时,FlashAttention没有"重计算"的概念,因为根本不需要反向传播!
推理时FlashAttention的加速完全来自于:
- 减少HBM访问(主要)
- 更高效的内存访问模式
- 更好的SRAM利用
对比总结
| 场景 | 标准注意力 | FlashAttention | FlashAttention优势 |
|---|---|---|---|
| 训练(前向) | 计算并保存attn_weights | 分块计算,不保存 | 节省显存,减少HBM访问 |
| 训练(反向) | 使用保存的attn_weights | 重新计算attn_weights | 节省显存 > 计算量增加 |
| 推理 | 临时计算attn_weights | 分块计算,在SRAM中完成 | 减少HBM访问,无重计算开销 |
实际影响:
推理时的性能提升甚至比训练时更明显,因为:
- 训练时:有重计算的开销(虽然影响小)
- 推理时:没有任何额外开销,纯粹的加速
FlashAttention的性能提升
显存访问对比
| 方法 | HBM访问量 | 中间矩阵大小 |
|---|---|---|
| 标准注意力 | ||
| FlashAttention | (块大小) |
对于 :
- 标准: 个元素(16 MB)
- Flash: 个元素(1 MB),减少16倍
实际性能(A100 GPU)
| 序列长度 | 标准注意力 | FlashAttention | 加速比 |
|---|---|---|---|
| 512 | 0.5 ms | 0.3 ms | 1.7x |
| 1024 | 2.1 ms | 0.8 ms | 2.6x |
| 2048 | 8.5 ms | 2.0 ms | 4.3x |
| 4096 | 34 ms | 5.5 ms | 6.2x |
| 8192 | 142 ms | 15 ms | 9.5x |
结论:序列越长,FlashAttention优势越明显。
注意:以上是前向传播的性能数据。
- 推理时:加速比如上表所示,没有额外开销
- 训练时:总体加速比略低(因为反向传播有重计算),但仍然比标准注意力快2-5倍
- 训练时最大优势:显存节省10-100倍,可以训练更长序列或更大batch
FlashAttention的变体
FlashAttention-2
改进:
- 更好的并行化策略(按注意力头并行)
- 减少非矩阵乘法操作
- 更好的工作分配(避免线程空闲)
性能:比FlashAttention-1快2倍,比标准注意力快15倍。
FlashDecoding(专为推理优化)
应用场景:专门用于**推理(inference)**阶段的自回归生成,不涉及训练。
问题:推理时每次只生成1个Token(batch_size=1),并行度低。
解决方案:在序列维度并行,而不是批次维度。
# 标准:按批次并行
for b in parallel_batch:
output[b] = attention(Q[b], K[b], V[b])
# FlashDecoding:按序列并行(当batch=1时)
# 将K、V分成多个块,并行计算每个块的贡献
partial_outputs = parallel_for block in K_blocks:
scores = Q @ block.K.T
attn = softmax(scores)
partial = attn @ block.V
# 合并所有部分输出
output = merge(partial_outputs)
性能:单个Token生成速度提升8倍。
技术3:投机采样(Speculative Sampling)
自回归生成的低效率
问题:GPU利用率低
自回归生成的特点:
# 每次生成1个Token
tokens = [start_token]
for i in range(100):
next_token = model(tokens) # 模型前向传播
tokens.append(next_token) # 等待上一步完成
性能分析:
- 大模型(如LLaMA-70B):70B参数,每个Token需要加载140GB数据(FP16)
- A100 GPU内存带宽:1.5 TB/s
- 单个Token生成时间:140GB / 1.5TB/s ≈ 93ms
- 大部分时间在等待内存访问,GPU计算单元利用率<20%
结论:生成1个Token和生成多个Token的成本几乎一样(都是内存带宽瓶颈),但每次只生成1个Token。
投机采样的核心思想
用小模型快速生成多个候选Token,然后用大模型并行验证
类比理解:
- 传统方式:老师逐个批改作业(每次1份,批改100次)
- 投机采样:学生先自己写答案,老师一次性批改多份(批改10次,每次10份)
投机采样算法
步骤1:小模型生成候选序列
使用一个小而快的模型(draft model)生成 个候选Token:
# 小模型(例如:LLaMA-7B)
draft_model = load_small_model()
# 快速生成k个候选Token
candidates = []
for i in range(k):
next_token = draft_model(tokens + candidates)
candidates.append(next_token)
# 候选序列:[token_1, token_2, ..., token_k]
性能:小模型速度快(例如LLaMA-7B比LLaMA-70B快10倍)
步骤2:大模型并行验证
使用大模型(target model)并行验证所有候选Token:
# 大模型(例如:LLaMA-70B)
target_model = load_large_model()
# 并行验证:一次前向传播同时验证k个Token
# 输入:[token_0, candidate_1, candidate_2, ..., candidate_k]
logits = target_model(tokens + candidates) # (k+1, vocab_size)
# 对每个位置,计算大模型的概率分布
p_target = softmax(logits, dim=-1)
关键观察:
- 大模型可以并行处理所有候选Token(类似于训练时的teacher forcing)
- 计算 个Token的时间 ≈ 计算1个Token的时间(都是内存带宽瓶颈)
步骤3:接受或拒绝候选Token
从左到右逐个检查候选Token,根据概率分布决定是否接受:
def verify_candidates(candidates, p_draft, p_target):
"""
candidates: 候选Token列表
p_draft: 小模型的概率分布 (k, vocab_size)
p_target: 大模型的概率分布 (k, vocab_size)
"""
accepted = []
for i, token in enumerate(candidates):
# 计算接受概率
p_accept = min(1.0, p_target[i, token] / p_draft[i, token])
# 随机接受或拒绝
if random.random() < p_accept:
accepted.append(token)
else:
# 拒绝该Token,停止验证后续Token
# 从修正的分布中采样新Token
p_corrected = max(0, p_target[i] - p_draft[i])
p_corrected = p_corrected / p_corrected.sum()
new_token = sample(p_corrected)
accepted.append(new_token)
break
return accepted
接受条件:
直觉理解:
- 如果大模型也认为该Token概率高(),则接受
- 如果大模型认为该Token概率低,则以一定概率拒绝
完整流程示例
初始状态:
- 已生成:["今天", "天气"]
- 目标:继续生成
Step 1:小模型生成候选(k=4)
输入:["今天", "天气"]
小模型生成候选:["很", "不", "错", "呢"]
Step 2:大模型并行验证
输入:["今天", "天气", "很", "不", "错", "呢"]
大模型输出概率分布(简化):
- 位置0("今天"后):P_target("天气") = 0.6, P_draft("天气") = 0.5
- 位置1("天气"后):P_target("很") = 0.4, P_draft("很") = 0.3
- 位置2("很"后):P_target("不") = 0.1, P_draft("不") = 0.4
- 位置3("不"后):P_target("错") = 0.05, P_draft("错") = 0.3
Step 3:逐个验证
验证"很":
接受概率 = min(1, 0.4/0.3) = 1.0
随机数 = 0.5 < 1.0 → 接受✓
验证"不":
接受概率 = min(1, 0.1/0.4) = 0.25
随机数 = 0.8 > 0.25 → 拒绝✗
从修正分布中采样新Token → "好"
停止验证后续Token("错"、"呢")
最终接受:["很", "好"](2个Token)
投机采样的性能分析
加速比
假设:
- 候选序列长度:
- 平均接受率:(例如0.6,即60%的候选Token被接受)
- 小模型速度:(例如10 tokens/s)
- 大模型速度:(例如1 token/s)
每轮迭代:
- 时间成本:(小模型生成+大模型验证)
- 平均接受Token数:
有效速度:
实际例子():
加速比:倍
关键因素
1. 接受率 α
- α 越高,加速越明显
- α 取决于小模型和大模型的相似度
- 通常:α ∈ [0.5, 0.8]
2. 候选长度 k
- k 太小:加速不明显
- k 太大:接受率下降,浪费计算
- 最优值:通常 k ∈ [4, 10]
3. 小模型选择
- 越接近大模型,接受率越高
- 但小模型不能太大(否则速度慢)
- 实践:大模型的1/10参数量(如70B用7B)
实际性能数据
| 场景 | 标准生成 | 投机采样 | 加速比 |
|---|---|---|---|
| LLaMA-70B (7B draft, k=4) | 1.2 tokens/s | 2.1 tokens/s | 1.75x |
| GPT-3.5 (GPT-2 draft, k=6) | 15 tokens/s | 28 tokens/s | 1.87x |
| CodeLLaMA-34B (7B draft, k=5) | 2.5 tokens/s | 4.3 tokens/s | 1.72x |
投机采样的变体
1. Medusa:多头投机采样
问题:需要维护两个模型(大模型+小模型),增加部署复杂度。
解决方案:在大模型上添加多个轻量级"投机头",每个头预测未来某个位置的Token。
# 大模型的最后一层输出
hidden = large_model.forward(tokens) # (seq_len, hidden_dim)
# 添加多个投机头(每个头是一个轻量级MLP)
head_1 = mlp_head_1(hidden[-1]) # 预测下一个Token
head_2 = mlp_head_2(hidden[-1]) # 预测下下个Token
head_3 = mlp_head_3(hidden[-1]) # 预测下下下个Token
# 生成候选序列
candidates = [
sample(head_1),
sample(head_2),
sample(head_3)
]
# 大模型验证(与投机采样相同)
logits = large_model.forward(tokens + candidates)
# ...验证和接受
优势:
- 只需一个模型,部署简单
- 投机头可以并行计算,几乎无额外延迟
劣势:
- 需要额外训练投机头
- 接受率略低于使用独立小模型
2. 自投机采样(Self-Speculative Decoding)
思想:用大模型的早期退出(Early Exit)作为草稿模型。
# 大模型有L层
# 早期退出:只运行前L/2层
draft_output = large_model.forward(tokens, num_layers=L//2)
candidates = generate_candidates(draft_output)
# 完整运行:验证候选
target_output = large_model.forward(tokens + candidates, num_layers=L)
# ...验证
优势:
- 无需额外模型
- 早期层计算快
劣势:
- 早期层表达能力有限,接受率低
3. 批量投机采样
场景:同时处理多个请求(batch推理)。
挑战:不同请求的候选序列长度可能不同(有的接受4个,有的接受2个)。
解决方案:动态批处理
# 每个请求独立验证
for request in batch:
candidates = draft_model.generate(request, k=4)
accepted = verify(candidates)
request.append(accepted)
# 重新组织batch(根据接受的Token数)
# 继续下一轮
三大技术的协同作用
技术组合
在实际部署中,这三大技术通常组合使用:
┌─────────────────────────────────────────┐
│ PagedAttention │ 内存管理层
│ - 动态分配KV Cache页 │
│ - 内存利用率:20% → 90% │
└────────────┬────────────────────────────┘
│
┌────────────▼────────────────────────────┐
│ FlashAttention │ 计算优化层
│ - 分块计算,减少HBM访问 │
│ - 计算速度:提升2-9倍 │
└────────────┬────────────────────────────┘
│
┌────────────▼────────────────────────────┐
│ Speculative Sampling │ 算法优化层
│ - 并行验证多个Token │
│ - 生成速度:提升1.5-2倍 │
└─────────────────────────────────────────┘
实际系统性能
以vLLM(结合PagedAttention + FlashAttention)为例:
| 指标 | HuggingFace Transformers | vLLM | 提升 |
|---|---|---|---|
| 吞吐量(LLaMA-13B) | 0.3 req/s | 2.5 req/s | 8.3倍 |
| 内存利用率 | 20% | 90% | 4.5倍 |
| 最大batch size | 8 | 64 | 8倍 |
| 延迟(P50) | 2.1s | 0.8s | 2.6倍 |
再加上投机采样(Medusa):
| 指标 | vLLM | vLLM + Medusa | 额外提升 |
|---|---|---|---|
| 单请求延迟 | 800ms | 450ms | 1.8倍 |
| 吞吐量 | 2.5 req/s | 4.2 req/s | 1.7倍 |
总提升(相比原始实现):
- 吞吐量:8.3 × 1.7 = 14倍
- 内存效率:4.5倍
- 延迟:2.6 × 1.8 = 4.7倍
实际应用与最佳实践
技术选型建议
| 场景 | 推荐技术 | 原因 |
|---|---|---|
| 在线API服务(高并发) | PagedAttention + FlashAttention | 需要高吞吐量和内存效率 |
| 单用户聊天(低延迟) | FlashAttention + 投机采样 | 降低单请求延迟 |
| 长文本处理(>8K) | FlashAttention | 序列越长,加速越明显 |
| 资源受限(小显存) | PagedAttention | 提升内存利用率 |
| 代码生成 | 投机采样 | 代码模式性强,高接受率 |
开源实现
| 项目 | 技术 | 特点 |
|---|---|---|
| vLLM | PagedAttention + FlashAttention | 高吞吐量推理服务 |
| FlashAttention | FlashAttention-2 | 官方实现,性能最优 |
| Medusa | 多头投机采样 | 单模型投机采样 |
| TensorRT-LLM | FlashAttention + 多种优化 | NVIDIA官方,生产级 |
| Text Generation Inference | FlashAttention + 连续批处理 | HuggingFace官方 |
部署示例
vLLM部署(PagedAttention + FlashAttention)
from vllm import LLM, SamplingParams
# 初始化模型
llm = LLM(
model="meta-llama/Llama-2-13b-hf",
tensor_parallel_size=2, # 使用2个GPU
gpu_memory_utilization=0.95, # 高内存利用率(PagedAttention)
)
# 批量推理
prompts = ["Explain quantum computing", "Write a poem about AI"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
print(output.outputs[0].text)
自动启用:
- PagedAttention:自动管理KV Cache
- FlashAttention:自动使用(如果可用)
投机采样部署
from transformers import AutoModelForCausalLM, AutoTokenizer
# 加载大模型和小模型
target_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-70b-hf")
draft_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
# 投机采样生成
from transformers import SpeculativeDecoder
decoder = SpeculativeDecoder(
target_model=target_model,
draft_model=draft_model,
k=5 # 候选序列长度
)
output = decoder.generate(
input_ids,
max_new_tokens=100
)
小结
PagedAttention
核心思想:将KV Cache分页,按需分配,类似操作系统的虚拟内存。
优势:
- 内存利用率:20% → 90%(提升4.5倍)
- 支持动态批处理和内存共享
- 吞吐量提升5-10倍
适用场景:高并发API服务、批量推理
FlashAttention
核心思想:分块计算,在SRAM中完成尽可能多的操作,减少HBM访问。
优势:
- 显存访问:减少10-100倍
- 计算速度:提升2-9倍(序列越长越明显)
- 无精度损失(数学等价)
训练 vs 推理:
- 训练时:节省显存(不保存中间矩阵),但反向传播需要重计算,总体仍更快
- 推理时:无重计算开销,纯粹的性能提升,加速更明显
适用场景:
- 训练:长序列训练、显存受限场景
- 推理:所有推理场景(尤其是长文本)
投机采样
核心思想:用小模型快速生成候选,大模型并行验证。
优势:
- 生成速度:提升1.5-2倍
- 保持输出质量(与原模型等价)
- 无需修改模型结构
适用场景:低延迟生成、代码生成、模式性强的任务
综合效应
三大技术组合使用,可实现:
- 吞吐量:提升10-15倍
- 延迟:降低3-5倍
- 成本:降低50-70%(更少GPU,更高利用率)
这三大技术是大模型推理优化的基石,几乎所有现代推理系统都基于它们构建。理解这些技术的原理,对于优化部署、降低成本、提升用户体验至关重要。