14-PagedAttention、FlashAttention与投机采样:推理优化三大技术

5 阅读7分钟

PagedAttention、FlashAttention与投机采样:推理优化三大技术

大模型推理的三大瓶颈

在上一章中,我们学习了KV Cache如何通过缓存已计算的K和V来加速推理。但即使有了KV Cache,大模型推理仍然面临三个核心瓶颈:

瓶颈1:内存管理效率低(PagedAttention解决)

问题:传统的KV Cache需要预先分配连续内存块

# 传统方式:预分配最大长度的内存
max_seq_len = 4096
k_cache = torch.zeros(batch_size, max_seq_len, num_heads, head_dim)
v_cache = torch.zeros(batch_size, max_seq_len, num_heads, head_dim)

带来的问题

  • 即使实际只生成100个Token,也要分配4096个Token的空间
  • 内存利用率低(浪费率可达60-80%)
  • 无法支持动态长度的批处理

瓶颈2:显存访问效率低(FlashAttention解决)

问题:标准注意力计算需要大量的显存读写

注意力计算公式:

Output=softmax(QKTdk)V\text{Output} = \text{softmax}\left(\frac{Q \cdot K^T}{\sqrt{d_k}}\right) \cdot V

传统实现的显存访问:

# 步骤1:计算注意力分数(写入HBM)
scores = Q @ K.T  # 形状:(seq_len, seq_len),写入显存

# 步骤2:Softmax(从HBM读取,写回HBM)
attn_weights = softmax(scores / sqrt(d_k))  # 读取scores,写回显存

# 步骤3:加权求和(从HBM读取V和attn_weights)
output = attn_weights @ V  # 读取attn_weights和V

问题:中间结果(scores、attn_weights)频繁读写显存,显存带宽成为瓶颈。

对于序列长度 n=2048n=2048,中间矩阵 scoresR2048×2048\text{scores} \in \mathbb{R}^{2048 \times 2048} 占用 16MB(FP16),需要反复读写。

瓶颈3:自回归生成速度慢(投机采样解决)

问题:每次只生成1个Token,GPU利用率低

# 自回归生成:串行处理
for i in range(100):
    next_token = model(input_tokens)  # 生成1个Token
    input_tokens.append(next_token)   # 等待上一步完成

带来的问题

  • GPU大部分时间在等待内存访问
  • 计算单元利用率低(<20%)
  • 生成100个Token需要100次模型前向传播

本章将详细讲解这三大优化技术如何解决这些问题。

技术1:PagedAttention - 虚拟内存管理

核心思想

PagedAttention借鉴操作系统的虚拟内存思想,将KV Cache分成固定大小的"页(Page)",按需分配。

类比理解

  • 传统KV Cache:像酒店提前包下整层楼(100间房),即使只住10个人
  • PagedAttention:像酒店按需分配房间,只分配实际需要的房间数

传统KV Cache的问题

假设我们要处理一个batch的请求:

请求ID实际长度分配内存利用率
请求1100 Token4096 Token2.4%
请求2500 Token4096 Token12.2%
请求350 Token4096 Token1.2%
请求4200 Token4096 Token4.9%

总内存:4 × 4096 × 层数 × 头数 × 头维度 实际使用:(100+500+50+200) × 层数 × 头数 × 头维度 利用率:850 / (4×4096) = 5.2%

PagedAttention的解决方案

步骤1:将KV Cache分页

将连续的KV Cache分成固定大小的页(例如每页16个Token):

# 页的大小(每页存储多少个Token)
PAGE_SIZE = 16

# 传统KV Cache:连续内存
k_cache_traditional = torch.zeros(batch, max_seq_len, num_heads, head_dim)

# PagedAttention:分页存储
# 每一页是一个独立的tensor
k_cache_paged = [
    torch.zeros(batch, PAGE_SIZE, num_heads, head_dim)
    for _ in range(max_seq_len // PAGE_SIZE)
]
步骤2:维护页表(Page Table)

为每个请求维护一个页表,记录其KV Cache使用了哪些页:

class PageTable:
    def __init__(self):
        self.pages = []  # 存储该请求使用的物理页索引

    def allocate_page(self, physical_page_id):
        """分配一个新页"""
        self.pages.append(physical_page_id)

    def get_physical_page(self, virtual_page_id):
        """将虚拟页号转换为物理页号"""
        return self.pages[virtual_page_id]

# 每个请求有自己的页表
request_1_page_table = PageTable()
request_2_page_table = PageTable()
步骤3:按需分配页

当请求生成新Token时,动态分配页:

class PagedKVCache:
    def __init__(self, num_layers, num_heads, head_dim, page_size=16):
        self.page_size = page_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim

        # 物理内存池:所有空闲的页
        self.free_pages = []
        self.used_pages = {}  # page_id -> tensor

        # 为每个请求维护页表
        self.page_tables = {}  # request_id -> PageTable

    def allocate_page(self):
        """从内存池中分配一个空闲页"""
        if self.free_pages:
            page_id = self.free_pages.pop()
        else:
            # 创建新页
            page_id = len(self.used_pages)
            self.used_pages[page_id] = torch.zeros(
                self.num_layers, self.page_size, self.num_heads, self.head_dim
            )
        return page_id

    def append_kv(self, request_id, layer_id, k_new, v_new):
        """
        为请求添加新的K、V
        k_new, v_new: (1, num_heads, head_dim) - 单个Token的K、V
        """
        if request_id not in self.page_tables:
            # 新请求:创建页表并分配第一页
            self.page_tables[request_id] = PageTable()
            page_id = self.allocate_page()
            self.page_tables[request_id].allocate_page(page_id)

        page_table = self.page_tables[request_id]
        num_tokens = len(page_table.pages) * self.page_size

        # 计算当前Token在哪一页
        page_idx = num_tokens // self.page_size
        offset_in_page = num_tokens % self.page_size

        # 如果当前页已满,分配新页
        if offset_in_page == 0 and num_tokens > 0:
            page_id = self.allocate_page()
            page_table.allocate_page(page_id)
            page_idx = len(page_table.pages) - 1

        # 获取物理页
        physical_page_id = page_table.get_physical_page(page_idx)
        physical_page = self.used_pages[physical_page_id]

        # 写入K、V
        physical_page[layer_id, offset_in_page] = k_new

    def get_kv(self, request_id, layer_id):
        """
        获取请求的所有K、V
        返回:(seq_len, num_heads, head_dim)
        """
        page_table = self.page_tables[request_id]
        k_list = []

        for page_idx in range(len(page_table.pages)):
            physical_page_id = page_table.get_physical_page(page_idx)
            physical_page = self.used_pages[physical_page_id]

            # 读取该页的K
            k_page = physical_page[layer_id]  # (page_size, num_heads, head_dim)
            k_list.append(k_page)

        # 拼接所有页
        k_all = torch.cat(k_list, dim=0)  # (seq_len, num_heads, head_dim)
        return k_all

    def free_request(self, request_id):
        """释放请求占用的所有页"""
        if request_id in self.page_tables:
            page_table = self.page_tables[request_id]
            # 将所有页归还到空闲池
            self.free_pages.extend(page_table.pages)
            del self.page_tables[request_id]

PagedAttention的优势

1. 内存利用率提升

示例:处理4个请求,实际长度分别为100、500、50、200

传统方式

  • 分配内存:4 × 4096 = 16384 个Token位置
  • 实际使用:100 + 500 + 50 + 200 = 850 个Token位置
  • 利用率:5.2%

PagedAttention(页大小=16)

  • 请求1:需要 100/16 ≈ 7页
  • 请求2:需要 500/16 ≈ 32页
  • 请求3:需要 50/16 ≈ 4页
  • 请求4:需要 200/16 ≈ 13页
  • 总页数:56页
  • 分配内存:56 × 16 = 896 个Token位置
  • 实际使用:850 个Token位置
  • 利用率:95%(提升18倍!)
2. 支持动态批处理

传统方式必须等待所有请求完成才能释放内存,PagedAttention可以:

# 请求1完成,立即释放内存
paged_cache.free_request(request_1)

# 立即用释放的内存处理新请求
paged_cache.append_kv(request_5, layer_id, k_new, v_new)
3. 内存共享

对于共享前缀的请求,可以共享页:

# 请求1:"解释一下深度学习"
# 请求2:"解释一下深度学习的原理"

# 前缀"解释一下深度学习"的KV Cache可以共享
prefix_pages = request_1_page_table.pages[:3]  # 前3页
request_2_page_table.pages = prefix_pages + [new_page]  # 共享前缀,只分配新页

应用场景

  • 批量处理相似问题
  • Few-shot Learning(示例前缀相同)
  • 系统提示词(所有请求共享)

实际性能提升

根据vLLM(实现了PagedAttention)的论文数据:

指标传统KV CachePagedAttention提升
内存利用率20%90%4.5倍
吞吐量(请求/秒)0.52.34.6倍
支持的最大batch size8648倍

技术2:FlashAttention - 显存访问优化

GPU内存层次结构

在理解FlashAttention之前,需要深入了解GPU的硬件架构和内存系统。

GPU整体架构

现代GPU(以NVIDIA A100为例)的基本结构:

┌──────────────────────────────────────────────────────────────┐
│                      GPU芯片                                  │
│                                                               │
│  ┌─────────────────────────────────────────────────────┐    │
│  │ SM (Streaming Multiprocessor) 1                      │    │
│  │ ┌──────────┐  ┌──────────┐  ┌─────────────────┐   │    │
│  │ │  CUDA    │  │  CUDA    │  │  Shared Memory  │   │    │
│  │ │  Cores   │  │  Cores   │  │    (SRAM)       │   │    │
│  │ │ (计算单元)│  │ (计算单元)│  │   ~200 KB       │   │    │
│  │ └──────────┘  └──────────┘  └─────────────────┘   │    │
│  │ ┌────────────────────────────────────────────┐     │    │
│  │ │         Registers (寄存器)                 │     │    │
│  │ │         ~256 KB                            │     │    │
│  │ └────────────────────────────────────────────┘     │    │
│  └─────────────────────────────────────────────────────┘    │
│                                                               │
│  ┌─────────────────────────────────────────────────────┐    │
│  │ SM 2 (类似结构)                                      │    │
│  └─────────────────────────────────────────────────────┘    │
│                                                               │
│  ...  (A100有108个SM)                                        │
│                                                               │
│  ┌─────────────────────────────────────────────────────┐    │
│  │         L2 Cache (40 MB)                             │    │
│  └─────────────────────────────────────────────────────┘    │
│                      ↓↑                                       │
└──────────────────────│────────────────────────────────────────┘
                       │
                   内存总线
                       │
         ┌─────────────▼──────────────┐
         │  HBM (High Bandwidth Memory) │
         │  (显存: 40GB/80GB)           │
         └──────────────────────────────┘

关键组件

  1. SM (Streaming Multiprocessor):GPU的核心计算单元

    • A100有108个SM
    • 每个SM包含多个CUDA核心(计算单元)
    • 每个SM有自己的SRAM(共享内存和寄存器)
  2. CUDA Core:实际执行计算的单元

    • 类似于CPU的核心
    • 一个SM包含64个CUDA核心
  3. L2 Cache:所有SM共享的缓存

    • 40 MB(A100)
    • 作为SRAM和HBM之间的中间层
HBM (High Bandwidth Memory) - 显存

什么是HBM?

HBM是GPU的主存储器,相当于计算机的RAM,但专为GPU设计。

物理特性

HBM芯片结构:
┌──────────────────┐
│   GPU Die        │  ← GPU核心芯片
├──────────────────┤
│  Interposer      │  ← 硅中介层(高速连接)
├──────────────────┤
│ HBM Stack 1      │  ← 内存堆叠芯片
│ HBM Stack 2      │
│ HBM Stack 3      │
│ HBM Stack 4      │
└──────────────────┘

特点:
- 垂直堆叠多层DRAM芯片
- 通过TSV (Through-Silicon Via) 连接
- 极宽的总线 (5120-bit on A100)

HBM的特点

  1. 容量大

    • A100:40GB/80GB
    • H100:80GB
    • 存储模型参数、激活值、KV Cache等
  2. 带宽高(相对于传统GDDR):

    • A100 HBM2e:1.5-2.0 TB/s
    • H100 HBM3:3.0 TB/s
    • 比传统GDDR快2-3倍
  3. 延迟相对高

    • 访问延迟:~500-800纳秒
    • 相比SRAM慢10-100倍
  4. 物理位置

    • 在GPU芯片外部(但在同一封装内)
    • 需要通过内存总线访问

HBM存储的内容

HBM中存储的数据(AI训练/推理):
1. 模型参数(W, b):几GB到几十GB
2. 激活值(中间结果):几GB
3. KV Cache(推理时):几GB
4. 梯度(训练时):与参数量相当
5. 优化器状态(训练时):参数量的2-3倍

例如:LLaMA-7B推理
- 模型参数:14 GB (FP16)
- KV Cache:~4 GB (batch=8, seq_len=2048)
- 激活值:~2 GB
总计:~20 GB HBM
SRAM (Static RAM) - 片上内存

什么是SRAM?

SRAM是集成在GPU芯片内部的高速缓存,直接在计算单元旁边。

物理特性

SRAM与HBM的对比:

┌─────────────────────────┐
│    GPU芯片内部           │
│  ┌──────┐  ┌────────┐  │
│  │ CUDA │←→│ SRAM   │  │  ← 片上,极近
│  │ Core │  │(100KB) │  │     延迟: 几纳秒
│  └──────┘  └────────┘  │
└─────────────────────────┘
          ↓↑
      内存总线 (较慢)
          ↓↑
┌─────────────────────────┐
│       HBM (40GB)        │  ← 片外(同封装)
└─────────────────────────┘     延迟: 几百纳秒

SRAM的类型

  1. Shared Memory(共享内存)

    • 每个SM有一块共享内存
    • A100:每个SM约164 KB
    • 可以被该SM内的所有线程访问
    • 程序员可以显式控制
  2. Registers(寄存器)

    • 每个SM有大量寄存器
    • A100:每个SM约256 KB
    • 每个线程独占自己的寄存器
    • 访问速度最快
  3. L1 Cache

    • 与Shared Memory共享物理空间
    • 硬件自动管理
    • 可配置大小比例

SRAM的特点

  1. 速度极快

    • 访问延迟:1-2纳秒
    • 带宽:A100每个SM约19 TB/s
    • 比HBM快10-100倍
  2. 容量极小

    • 每个SM:~200 KB(Shared Memory)
    • 全GPU:108个SM × 200KB ≈ 20 MB
    • 比HBM小2000倍(40GB vs 20MB)
  3. 物理位置

    • 在GPU芯片内部
    • 紧邻计算单元
    • 无需经过内存总线

SRAM使用的数据

SRAM中存储的数据(临时):
1. 正在计算的小块数据
2. 循环变量、临时结果
3. 频繁访问的数据

例如:FlashAttention
- Q的一块:64 × 128 × 2 bytes = 16 KB
- K的一块:64 × 128 × 2 bytes = 16 KB
- V的一块:64 × 128 × 2 bytes = 16 KB
- 中间结果:64 × 64 × 2 bytes = 8 KB
总计:~56 KB(远小于200KB限制)✓
内存层次对比
┌──────────────┬────────────┬──────────┬───────────┬─────────────┐
│   内存类型   │   容量     │  带宽    │  延迟     │   物理位置  │
├──────────────┼────────────┼──────────┼───────────┼─────────────┤
│ Registers    │ ~256 KB/SM │ 极高     │ 1 ns      │ SM内部      │
├──────────────┼────────────┼──────────┼───────────┼─────────────┤
│ Shared Mem   │ ~200 KB/SM │ 19 TB/s  │ 1-2 ns    │ SM内部      │
│ (SRAM)       │ 全GPU~20MB │ /SM      │           │             │
├──────────────┼────────────┼──────────┼───────────┼─────────────┤
│ L2 Cache     │ 40 MB      │ ~7 TB/s  │ ~50 ns    │ GPU芯片内   │
├──────────────┼────────────┼──────────┼───────────┼─────────────┤
│ HBM          │ 40-80 GB   │ 1.5 TB/s │ 500-800ns │ 芯片外      │
│ (显存)       │            │          │           │ (同封装内)  │
└──────────────┴────────────┴──────────┴───────────┴─────────────┘

速度对比:
Registers ≈ Shared Memory > L2 Cache >>> HBM
   1倍         1倍              3倍        12倍慢

关键数据(A100 GPU):

  • SRAM带宽19 TB/s (每个SM)
  • HBM带宽1.5 TB/s (整个GPU)
  • 速度差距:SRAM比HBM快12倍以上
  • 容量差距:HBM比SRAM大2000倍以上
为什么速度差异如此大?

物理距离

数据传输距离(简化示意):

SRAM访问:
CUDA Core → 几微米 → SRAM
传输时间:~1纳秒

HBM访问:
CUDA Core → 几毫米 → 内存控制器 → 几毫米 → HBM
传输时间:~500纳秒

差距来源:
1. 物理距离:500倍
2. 总线宽度:SRAM更宽
3. 访问机制:SRAM更简单
4. 电气特性:SRAM电压切换更快

类比理解

CPU/GPU内存系统 ≈ 厨房

Registers/SRAM(片上内存):
= 厨师手里的锅和案板
- 极快(伸手就拿到)
- 容量小(只能放当前用的食材)

L2 Cache:
= 灶台旁边的调料架
- 较快(转身就拿到)
- 容量中等

HBM(显存):
= 厨房的冰箱
- 较慢(走几步才能拿到)
- 容量大(存放所有食材)

主内存(CPU RAM):
= 储藏室
- 很慢(要出厨房)
- 容量更大

硬盘:
= 超市
- 极慢(要出门)
- 容量巨大
FlashAttention的核心策略

理解了GPU内存层次后,FlashAttention的策略就很清楚了:

问题

# 标准注意力
scores = Q @ K.T          # 2048×2048矩阵 → 写入HBM(慢)
attn = softmax(scores)     # 从HBM读取 → 计算 → 写回HBM(慢)
output = attn @ V          # 从HBM读取 → 计算(慢)

# 瓶颈:在HBM和SRAM之间来回传输大矩阵

FlashAttention的解决方案

# 分块加载到SRAM(快)
for Q_block in Q.chunks():           # 64×128 加载到SRAM
    for K_block, V_block in zip(...): # 64×128 加载到SRAM
        scores_block = Q_block @ K_block.T  # 在SRAM内计算(快)
        attn_block = softmax(scores_block)   # 在SRAM内计算(快)
        output += attn_block @ V_block       # 在SRAM内计算(快)
        # 只在循环结束时写回HBM

# 关键:大部分计算在SRAM内完成,HBM访问极少

结论:尽量在SRAM中完成计算,减少HBM访问次数,这就是FlashAttention快的根本原因。

标准注意力的显存访问

标准实现
def standard_attention(Q, K, V):
    # Q, K, V: (batch, seq_len, num_heads, head_dim)
    seq_len = Q.shape[1]

    # 步骤1:计算注意力分数(O(n²) 空间)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(head_dim)
    # scores: (batch, seq_len, num_heads, seq_len)
    # 写入HBM:seq_len² × batch × num_heads 个float

    # 步骤2:Softmax
    attn_weights = F.softmax(scores, dim=-1)
    # 从HBM读取scores,写回HBM:2次HBM访问
    # 【重要】在训练时,attn_weights会被保存在显存中用于反向传播!

    # 步骤3:加权求和
    output = attn_weights @ V
    # 从HBM读取attn_weights和V:2次HBM读取

    return output
    # 训练时:scores和attn_weights都会被PyTorch保存在显存中
    # 推理时:不需要保存(因为不需要反向传播)
显存访问分析

对于序列长度 nn,注意力头数 hh,头维度 dd

操作HBM读取HBM写入总访问量
QKTQ \cdot K^T2nhd2n \cdot h \cdot dn2hn^2 \cdot h2nhd+n2h2nhd + n^2h
Softmaxn2hn^2 \cdot hn2hn^2 \cdot h2n2h2n^2h
attnV\text{attn} \cdot Vn2h+nhdn^2 \cdot h + n \cdot h \cdot dnhdn \cdot h \cdot dn2h+2nhdn^2h + 2nhd

总HBM访问量O(n2h+nhd)O(n^2 h + nhd)

对于 n=2048,h=32,d=128n=2048, h=32, d=128

  • 中间矩阵 scores:20482×32×2bytes=256MB2048^2 \times 32 \times 2 \text{bytes} = 256 \text{MB}
  • 注意力权重 attn_weights:20482×32×2bytes=256MB2048^2 \times 32 \times 2 \text{bytes} = 256 \text{MB}
  • 需要多次读写,总访问量 > 1 GB

问题:HBM访问成为瓶颈,计算单元在等待数据。

训练时的显存占用

标准注意力在训练时需要保存的中间结果

# 前向传播保存的中间结果(用于反向传播)
scores = Q @ K.T         # O(n²) - 必须保存
attn_weights = softmax(scores)  # O(n²) - 必须保存
output = attn_weights @ V       # 输出

# 显存占用(单层,单头):
# - scores: n² × 2 bytes (FP16)
# - attn_weights: n² × 2 bytes (FP16)
# - 总计:2n² × 2 bytes

实际例子n=2048n=2048,32层,32头):

总显存=2×20482×2 bytes×32×32=32 GB\text{总显存} = 2 \times 2048^2 \times 2 \text{ bytes} \times 32 \text{层} \times 32 \text{头} = 32 \text{ GB}

仅注意力的中间结果就占用32GB!这还不包括模型参数、梯度、优化器状态等。

推理时的显存占用

推理时不需要反向传播,因此不需要保存中间结果,显存占用大幅降低。但即使如此,仍然需要多次HBM访问来读写这些中间矩阵。

FlashAttention的核心思想

在SRAM中完成尽可能多的计算,减少HBM访问次数

关键技术

  1. 分块(Tiling):将Q、K、V分成小块,每次只加载一块到SRAM
  2. 重计算(Recomputation):训练时不保存中间结果,反向传播时重新计算(推理无此开销)
  3. 在线Softmax:逐块计算Softmax,无需保存完整的注意力矩阵

FlashAttention算法详解

核心挑战:如何逐块计算Softmax?

标准Softmax需要全局信息:

softmax(xi)=exp(xi)j=1nexp(xj)\text{softmax}(x_i) = \frac{\exp(x_i)}{\sum_{j=1}^{n} \exp(x_j)}

问题:分母需要所有元素的和,但我们每次只加载一部分数据。

解决方案:在线更新Softmax统计量

m=max(x1,x2,,xn)(全局最大值)=i=1nexp(xim)(归一化分母)softmax(xi)=exp(xim)\begin{aligned} m &= \max(x_1, x_2, \ldots, x_n) \quad \text{(全局最大值)} \\ \ell &= \sum_{i=1}^{n} \exp(x_i - m) \quad \text{(归一化分母)} \\ \text{softmax}(x_i) &= \frac{\exp(x_i - m)}{\ell} \end{aligned}

关键观察:当读取新的数据块时,可以增量更新 mm\ell

算法步骤

输入

  • Q,K,VRn×dQ, K, V \in \mathbb{R}^{n \times d}(序列长度 nn,维度 dd
  • 块大小:BrB_r (Q的块大小),BcB_c (K、V的块大小)

输出

  • ORn×dO \in \mathbb{R}^{n \times d}(注意力输出)

算法流程

def flash_attention(Q, K, V, block_size=64):
    """
    Q, K, V: (seq_len, head_dim)
    block_size: SRAM块大小
    """
    seq_len, head_dim = Q.shape
    num_blocks = (seq_len + block_size - 1) // block_size

    # 输出和统计量(在HBM中,但会逐块更新)
    O = torch.zeros_like(Q)
    m = torch.full((seq_len,), -float('inf'))  # 每行的最大值
    ell = torch.zeros(seq_len)                 # 每行的归一化分母

    # 外层循环:遍历Q的块
    for i in range(num_blocks):
        # 加载Q的第i块到SRAM
        Q_i = Q[i*block_size : (i+1)*block_size]  # (block_size, head_dim)

        # 初始化该块的输出和统计量
        O_i = torch.zeros_like(Q_i)
        m_i = torch.full((Q_i.shape[0],), -float('inf'))
        ell_i = torch.zeros(Q_i.shape[0])

        # 内层循环:遍历K、V的块
        for j in range(num_blocks):
            # 加载K、V的第j块到SRAM
            K_j = K[j*block_size : (j+1)*block_size]  # (block_size, head_dim)
            V_j = V[j*block_size : (j+1)*block_size]

            # 在SRAM中计算注意力分数
            scores_ij = Q_i @ K_j.T / math.sqrt(head_dim)
            # scores_ij: (Q_block_size, K_block_size)

            # 更新统计量
            m_i_new = torch.maximum(m_i, scores_ij.max(dim=1).values)

            # 重新归一化之前的输出
            correction_factor = torch.exp(m_i - m_i_new)
            ell_i = ell_i * correction_factor

            # 计算当前块的贡献
            scores_ij_normalized = torch.exp(scores_ij - m_i_new.unsqueeze(1))
            ell_i_new = ell_i + scores_ij_normalized.sum(dim=1)

            # 更新输出(加权平均)
            O_i = O_i * correction_factor.unsqueeze(1) + scores_ij_normalized @ V_j

            # 更新统计量
            m_i = m_i_new
            ell_i = ell_i_new

        # 最终归一化
        O_i = O_i / ell_i.unsqueeze(1)

        # 写回HBM
        O[i*block_size : (i+1)*block_size] = O_i

    return O
关键技术细节

1. 在线Softmax更新

假设已经处理了块 j=1,2,,k1j=1, 2, \ldots, k-1,统计量为 (mold,old)(m_{old}, \ell_{old}),现在处理块 kk

# 新块的最大值
m_new = max(m_old, max(scores_k))

# 重新归一化之前的结果
correction = exp(m_old - m_new)
ell_old = ell_old * correction
O_old = O_old * correction

# 添加新块的贡献
scores_k_normalized = exp(scores_k - m_new)
ell_new = ell_old + sum(scores_k_normalized)
O_new = O_old + scores_k_normalized @ V_k

2. 不保存中间矩阵

标准注意力在训练时的行为

在模型训练过程中,标准注意力实现必须保存注意力分数矩阵 attn_weightsRn×n\text{attn\_weights} \in \mathbb{R}^{n \times n}

def standard_attention_training(Q, K, V):
    scores = Q @ K.T / sqrt(d_k)
    attn_weights = softmax(scores)  # 这个必须保存在显存中!
    output = attn_weights @ V

    # PyTorch会自动保存attn_weights,用于反向传播
    # 因为计算梯度时需要:
    # d(output)/dQ 依赖于 attn_weights
    # d(output)/dK 依赖于 attn_weights
    # d(output)/dV 依赖于 attn_weights

    return output

为什么必须保存

在反向传播时,根据链式法则计算梯度需要用到注意力权重:

LV=attn_weightsTLoutputLQ 和 LK 也都依赖于 attn_weights\begin{aligned} \frac{\partial \mathcal{L}}{\partial V} &= \text{attn\_weights}^T \cdot \frac{\partial \mathcal{L}}{\partial \text{output}} \\ \frac{\partial \mathcal{L}}{\partial Q} &\text{ 和 } \frac{\partial \mathcal{L}}{\partial K} \text{ 也都依赖于 attn\_weights} \end{aligned}

因此,标准注意力必须在显存中保存 O(n2)O(n^2) 大小的注意力矩阵,这在长序列训练时会消耗大量显存。

FlashAttention的优化

FlashAttention直接计算输出,不在显存中保存中间的注意力矩阵,只在SRAM中临时计算。

3. 反向传播时的重计算

既然FlashAttention没有保存注意力分数,那么反向传播时怎么办?解决方案:重新计算

def flash_attention_backward(dO, Q, K, V):
    # dO: 输出的梯度 (从后续层传回来的)
    # 需要计算 dQ, dK, dV

    # 关键:我们没有保存attn_weights,所以需要重新计算
    # 标准注意力的反向传播会直接使用保存的attn_weights
    # FlashAttention选择重新计算,换取显存节省

    # 前向重计算:再次分块计算注意力
    for i in range(num_blocks):
        Q_i = Q[i*block_size : (i+1)*block_size]
        for j in range(num_blocks):
            K_j = K[j*block_size : (j+1)*block_size]
            V_j = V[j*block_size : (j+1)*block_size]

            # 重新计算注意力分数(这就是"重计算")
            scores_ij = Q_i @ K_j.T / sqrt(head_dim)
            attn_ij = softmax(scores_ij)  # 临时计算,不保存到显存

            # 使用重新计算的attn_ij来计算梯度
            dV_j = attn_ij.T @ dO_i  # 计算V的梯度
            dAttn = dO_i @ V_j.T     # 计算attention的梯度
            dScores = softmax_backward(dAttn, attn_ij)  # Softmax的梯度
            dQ_i += dScores @ K_j / sqrt(head_dim)  # Q的梯度
            dK_j += dScores.T @ Q_i / sqrt(head_dim)  # K的梯度

权衡分析(Trade-off)

方法前向显存占用反向计算量HBM访问总体性能
标准注意力O(n2)O(n^2)1倍
FlashAttentionO(n)O(n)2倍(重计算)

为什么FlashAttention更快

  1. 显存节省O(n2)O(n)O(n^2) \rightarrow O(n),可以训练更长序列或更大batch
  2. HBM访问减少:减少10-100倍,这是主要瓶颈
  3. 计算量增加可以接受:现代GPU的计算能力远超显存带宽,多算一遍前向几乎不影响总时间

结论:虽然重计算增加了约2倍的计算量,但由于大幅减少了显存占用(10倍)和HBM访问(10倍),总体仍然比标准注意力快2-9倍。

训练 vs 推理:FlashAttention的不同优势

重要区分:FlashAttention在训练和推理场景下的优势是不同的。

训练场景

标准注意力

# 前向传播
scores = Q @ K.T
attn_weights = softmax(scores)  # 保存到显存(用于反向传播)
output = attn_weights @ V

# 反向传播
# 直接使用保存的 attn_weights 计算梯度
dV = attn_weights.T @ dO
dQ = ...  # 使用 attn_weights

FlashAttention

# 前向传播
output = flash_attention_forward(Q, K, V)
# 不保存 attn_weights,节省显存

# 反向传播
# 重新计算 attn_weights(重计算)
attn_weights = recompute_attention(Q, K, V)
dV = attn_weights.T @ dO
dQ = ...

训练时的优势

  1. 显存节省:不保存 O(n2)O(n^2) 的中间矩阵,节省10-100倍显存
  2. HBM访问减少:减少10倍显存访问
  3. ⚠️ 计算量增加:反向传播时重新计算前向,增加约2倍计算量
  4. 总体更快:显存带宽是瓶颈,计算量增加影响小
推理场景

标准注意力

# 推理时不需要反向传播
scores = Q @ K.T          # 临时存储在HBM
attn_weights = softmax(scores)  # 临时存储在HBM
output = attn_weights @ V

# 虽然不保存用于梯度,但计算过程仍需要临时HBM空间
# 仍然有大量的HBM读写

FlashAttention

# 推理时同样不需要反向传播
output = flash_attention_forward(Q, K, V)
# 在SRAM中分块计算,几乎不使用HBM临时空间

推理时的优势

  1. HBM访问减少:仍然减少10倍显存访问(主要优势)
  2. 临时显存减少:不需要分配 O(n2)O(n^2) 的临时空间
  3. 无重计算开销:推理不需要反向传播,没有重计算的概念
  4. 纯粹的性能提升:只有好处,没有权衡

关键点

推理时,FlashAttention没有"重计算"的概念,因为根本不需要反向传播!

推理时FlashAttention的加速完全来自于:

  • 减少HBM访问(主要)
  • 更高效的内存访问模式
  • 更好的SRAM利用
对比总结
场景标准注意力FlashAttentionFlashAttention优势
训练(前向)计算并保存attn_weights分块计算,不保存节省显存,减少HBM访问
训练(反向)使用保存的attn_weights重新计算attn_weights节省显存 > 计算量增加
推理临时计算attn_weights分块计算,在SRAM中完成减少HBM访问,无重计算开销

实际影响

推理时的性能提升甚至比训练时更明显,因为:

  • 训练时:有重计算的开销(虽然影响小)
  • 推理时:没有任何额外开销,纯粹的加速

FlashAttention的性能提升

显存访问对比
方法HBM访问量中间矩阵大小
标准注意力O(n2+nd)O(n^2 + nd)O(n2)O(n^2)
FlashAttentionO(nd)O(nd)O(n)O(n) (块大小)

对于 n=2048,d=128n=2048, d=128

  • 标准:20482=4M2048^2 = 4M 个元素(16 MB)
  • Flash:2048×128=256K2048 \times 128 = 256K 个元素(1 MB),减少16倍
实际性能(A100 GPU)
序列长度标准注意力FlashAttention加速比
5120.5 ms0.3 ms1.7x
10242.1 ms0.8 ms2.6x
20488.5 ms2.0 ms4.3x
409634 ms5.5 ms6.2x
8192142 ms15 ms9.5x

结论:序列越长,FlashAttention优势越明显。

注意:以上是前向传播的性能数据。

  • 推理时:加速比如上表所示,没有额外开销
  • 训练时:总体加速比略低(因为反向传播有重计算),但仍然比标准注意力快2-5倍
  • 训练时最大优势:显存节省10-100倍,可以训练更长序列或更大batch

FlashAttention的变体

FlashAttention-2

改进

  1. 更好的并行化策略(按注意力头并行)
  2. 减少非矩阵乘法操作
  3. 更好的工作分配(避免线程空闲)

性能:比FlashAttention-1快2倍,比标准注意力快15倍。

FlashDecoding(专为推理优化)

应用场景:专门用于**推理(inference)**阶段的自回归生成,不涉及训练。

问题:推理时每次只生成1个Token(batch_size=1),并行度低。

解决方案:在序列维度并行,而不是批次维度。

# 标准:按批次并行
for b in parallel_batch:
    output[b] = attention(Q[b], K[b], V[b])

# FlashDecoding:按序列并行(当batch=1时)
# 将K、V分成多个块,并行计算每个块的贡献
partial_outputs = parallel_for block in K_blocks:
    scores = Q @ block.K.T
    attn = softmax(scores)
    partial = attn @ block.V

# 合并所有部分输出
output = merge(partial_outputs)

性能:单个Token生成速度提升8倍。

技术3:投机采样(Speculative Sampling)

自回归生成的低效率

问题:GPU利用率低

自回归生成的特点:

# 每次生成1个Token
tokens = [start_token]
for i in range(100):
    next_token = model(tokens)  # 模型前向传播
    tokens.append(next_token)   # 等待上一步完成

性能分析

  • 大模型(如LLaMA-70B):70B参数,每个Token需要加载140GB数据(FP16)
  • A100 GPU内存带宽:1.5 TB/s
  • 单个Token生成时间:140GB / 1.5TB/s ≈ 93ms
  • 大部分时间在等待内存访问,GPU计算单元利用率<20%

结论:生成1个Token和生成多个Token的成本几乎一样(都是内存带宽瓶颈),但每次只生成1个Token。

投机采样的核心思想

用小模型快速生成多个候选Token,然后用大模型并行验证

类比理解

  • 传统方式:老师逐个批改作业(每次1份,批改100次)
  • 投机采样:学生先自己写答案,老师一次性批改多份(批改10次,每次10份)

投机采样算法

步骤1:小模型生成候选序列

使用一个小而快的模型(draft model)生成 kk 个候选Token:

# 小模型(例如:LLaMA-7B)
draft_model = load_small_model()

# 快速生成k个候选Token
candidates = []
for i in range(k):
    next_token = draft_model(tokens + candidates)
    candidates.append(next_token)

# 候选序列:[token_1, token_2, ..., token_k]

性能:小模型速度快(例如LLaMA-7B比LLaMA-70B快10倍)

步骤2:大模型并行验证

使用大模型(target model)并行验证所有候选Token:

# 大模型(例如:LLaMA-70B)
target_model = load_large_model()

# 并行验证:一次前向传播同时验证k个Token
# 输入:[token_0, candidate_1, candidate_2, ..., candidate_k]
logits = target_model(tokens + candidates)  # (k+1, vocab_size)

# 对每个位置,计算大模型的概率分布
p_target = softmax(logits, dim=-1)

关键观察

  • 大模型可以并行处理所有候选Token(类似于训练时的teacher forcing)
  • 计算 kk 个Token的时间 ≈ 计算1个Token的时间(都是内存带宽瓶颈)
步骤3:接受或拒绝候选Token

从左到右逐个检查候选Token,根据概率分布决定是否接受:

def verify_candidates(candidates, p_draft, p_target):
    """
    candidates: 候选Token列表
    p_draft: 小模型的概率分布 (k, vocab_size)
    p_target: 大模型的概率分布 (k, vocab_size)
    """
    accepted = []

    for i, token in enumerate(candidates):
        # 计算接受概率
        p_accept = min(1.0, p_target[i, token] / p_draft[i, token])

        # 随机接受或拒绝
        if random.random() < p_accept:
            accepted.append(token)
        else:
            # 拒绝该Token,停止验证后续Token
            # 从修正的分布中采样新Token
            p_corrected = max(0, p_target[i] - p_draft[i])
            p_corrected = p_corrected / p_corrected.sum()
            new_token = sample(p_corrected)
            accepted.append(new_token)
            break

    return accepted

接受条件

接受概率=min(1,Ptarget(token)Pdraft(token))\text{接受概率} = \min\left(1, \frac{P_{\text{target}}(\text{token})}{P_{\text{draft}}(\text{token})}\right)

直觉理解

  • 如果大模型也认为该Token概率高(PtargetPdraftP_{\text{target}} \geq P_{\text{draft}}),则接受
  • 如果大模型认为该Token概率低,则以一定概率拒绝
完整流程示例

初始状态

  • 已生成:["今天", "天气"]
  • 目标:继续生成

Step 1:小模型生成候选(k=4)

输入:["今天", "天气"]
小模型生成候选:["很", "不", "错", "呢"]

Step 2:大模型并行验证

输入:["今天", "天气", "很", "不", "错", "呢"]
大模型输出概率分布(简化):
- 位置0("今天"后):P_target("天气") = 0.6, P_draft("天气") = 0.5
- 位置1("天气"后):P_target("很") = 0.4, P_draft("很") = 0.3
- 位置2("很"后):P_target("不") = 0.1, P_draft("不") = 0.4
- 位置3("不"后):P_target("错") = 0.05, P_draft("错") = 0.3

Step 3:逐个验证

验证"很":
  接受概率 = min(1, 0.4/0.3) = 1.0
  随机数 = 0.5 < 1.0 → 接受✓

验证"不":
  接受概率 = min(1, 0.1/0.4) = 0.25
  随机数 = 0.8 > 0.25 → 拒绝✗
  从修正分布中采样新Token → "好"

停止验证后续Token("错""呢"

最终接受:["很", "好"](2个Token)

投机采样的性能分析

加速比

假设:

  • 候选序列长度:kk
  • 平均接受率:α\alpha(例如0.6,即60%的候选Token被接受)
  • 小模型速度:sdrafts_{\text{draft}}(例如10 tokens/s)
  • 大模型速度:stargets_{\text{target}}(例如1 token/s)

每轮迭代

  • 时间成本:ksdraft+1starget\frac{k}{s_{\text{draft}}} + \frac{1}{s_{\text{target}}}(小模型生成+大模型验证)
  • 平均接受Token数:αk\alpha \cdot k

有效速度

sspeculative=αkksdraft+1stargets_{\text{speculative}} = \frac{\alpha \cdot k}{\frac{k}{s_{\text{draft}}} + \frac{1}{s_{\text{target}}}}

实际例子k=4,α=0.6,sdraft=10,starget=1k=4, \alpha=0.6, s_{\text{draft}}=10, s_{\text{target}}=1):

sspeculative=0.6×4410+11=2.40.4+1=1.71 tokens/ss_{\text{speculative}} = \frac{0.6 \times 4}{\frac{4}{10} + \frac{1}{1}} = \frac{2.4}{0.4 + 1} = 1.71 \text{ tokens/s}

加速比1.71/1=1.711.71 / 1 = 1.71

关键因素

1. 接受率 α

  • α 越高,加速越明显
  • α 取决于小模型和大模型的相似度
  • 通常:α ∈ [0.5, 0.8]

2. 候选长度 k

  • k 太小:加速不明显
  • k 太大:接受率下降,浪费计算
  • 最优值:通常 k ∈ [4, 10]

3. 小模型选择

  • 越接近大模型,接受率越高
  • 但小模型不能太大(否则速度慢)
  • 实践:大模型的1/10参数量(如70B用7B)
实际性能数据
场景标准生成投机采样加速比
LLaMA-70B (7B draft, k=4)1.2 tokens/s2.1 tokens/s1.75x
GPT-3.5 (GPT-2 draft, k=6)15 tokens/s28 tokens/s1.87x
CodeLLaMA-34B (7B draft, k=5)2.5 tokens/s4.3 tokens/s1.72x

投机采样的变体

1. Medusa:多头投机采样

问题:需要维护两个模型(大模型+小模型),增加部署复杂度。

解决方案:在大模型上添加多个轻量级"投机头",每个头预测未来某个位置的Token。

# 大模型的最后一层输出
hidden = large_model.forward(tokens)  # (seq_len, hidden_dim)

# 添加多个投机头(每个头是一个轻量级MLP)
head_1 = mlp_head_1(hidden[-1])  # 预测下一个Token
head_2 = mlp_head_2(hidden[-1])  # 预测下下个Token
head_3 = mlp_head_3(hidden[-1])  # 预测下下下个Token

# 生成候选序列
candidates = [
    sample(head_1),
    sample(head_2),
    sample(head_3)
]

# 大模型验证(与投机采样相同)
logits = large_model.forward(tokens + candidates)
# ...验证和接受

优势

  • 只需一个模型,部署简单
  • 投机头可以并行计算,几乎无额外延迟

劣势

  • 需要额外训练投机头
  • 接受率略低于使用独立小模型
2. 自投机采样(Self-Speculative Decoding)

思想:用大模型的早期退出(Early Exit)作为草稿模型。

# 大模型有L层
# 早期退出:只运行前L/2层
draft_output = large_model.forward(tokens, num_layers=L//2)
candidates = generate_candidates(draft_output)

# 完整运行:验证候选
target_output = large_model.forward(tokens + candidates, num_layers=L)
# ...验证

优势

  • 无需额外模型
  • 早期层计算快

劣势

  • 早期层表达能力有限,接受率低
3. 批量投机采样

场景:同时处理多个请求(batch推理)。

挑战:不同请求的候选序列长度可能不同(有的接受4个,有的接受2个)。

解决方案:动态批处理

# 每个请求独立验证
for request in batch:
    candidates = draft_model.generate(request, k=4)
    accepted = verify(candidates)
    request.append(accepted)

# 重新组织batch(根据接受的Token数)
# 继续下一轮

三大技术的协同作用

技术组合

在实际部署中,这三大技术通常组合使用

┌─────────────────────────────────────────┐
│  PagedAttention                          │  内存管理层
│  - 动态分配KV Cache页                    │
│  - 内存利用率:20% → 90%                 │
└────────────┬────────────────────────────┘
             │
┌────────────▼────────────────────────────┐
│  FlashAttention                          │  计算优化层
│  - 分块计算,减少HBM访问                 │
│  - 计算速度:提升2-9倍                   │
└────────────┬────────────────────────────┘
             │
┌────────────▼────────────────────────────┐
│  Speculative Sampling                    │  算法优化层
│  - 并行验证多个Token                     │
│  - 生成速度:提升1.5-2倍                 │
└─────────────────────────────────────────┘

实际系统性能

以vLLM(结合PagedAttention + FlashAttention)为例:

指标HuggingFace TransformersvLLM提升
吞吐量(LLaMA-13B)0.3 req/s2.5 req/s8.3倍
内存利用率20%90%4.5倍
最大batch size8648倍
延迟(P50)2.1s0.8s2.6倍

再加上投机采样(Medusa):

指标vLLMvLLM + Medusa额外提升
单请求延迟800ms450ms1.8倍
吞吐量2.5 req/s4.2 req/s1.7倍

总提升(相比原始实现):

  • 吞吐量:8.3 × 1.7 = 14倍
  • 内存效率:4.5倍
  • 延迟:2.6 × 1.8 = 4.7倍

实际应用与最佳实践

技术选型建议

场景推荐技术原因
在线API服务(高并发)PagedAttention + FlashAttention需要高吞吐量和内存效率
单用户聊天(低延迟)FlashAttention + 投机采样降低单请求延迟
长文本处理(>8K)FlashAttention序列越长,加速越明显
资源受限(小显存)PagedAttention提升内存利用率
代码生成投机采样代码模式性强,高接受率

开源实现

项目技术特点
vLLMPagedAttention + FlashAttention高吞吐量推理服务
FlashAttentionFlashAttention-2官方实现,性能最优
Medusa多头投机采样单模型投机采样
TensorRT-LLMFlashAttention + 多种优化NVIDIA官方,生产级
Text Generation InferenceFlashAttention + 连续批处理HuggingFace官方

部署示例

vLLM部署(PagedAttention + FlashAttention)
from vllm import LLM, SamplingParams

# 初始化模型
llm = LLM(
    model="meta-llama/Llama-2-13b-hf",
    tensor_parallel_size=2,  # 使用2个GPU
    gpu_memory_utilization=0.95,  # 高内存利用率(PagedAttention)
)

# 批量推理
prompts = ["Explain quantum computing", "Write a poem about AI"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    print(output.outputs[0].text)

自动启用

  • PagedAttention:自动管理KV Cache
  • FlashAttention:自动使用(如果可用)
投机采样部署
from transformers import AutoModelForCausalLM, AutoTokenizer

# 加载大模型和小模型
target_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-70b-hf")
draft_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

# 投机采样生成
from transformers import SpeculativeDecoder

decoder = SpeculativeDecoder(
    target_model=target_model,
    draft_model=draft_model,
    k=5  # 候选序列长度
)

output = decoder.generate(
    input_ids,
    max_new_tokens=100
)

小结

PagedAttention

核心思想:将KV Cache分页,按需分配,类似操作系统的虚拟内存。

优势

  • 内存利用率:20% → 90%(提升4.5倍)
  • 支持动态批处理和内存共享
  • 吞吐量提升5-10倍

适用场景:高并发API服务、批量推理

FlashAttention

核心思想:分块计算,在SRAM中完成尽可能多的操作,减少HBM访问。

优势

  • 显存访问:减少10-100倍
  • 计算速度:提升2-9倍(序列越长越明显)
  • 无精度损失(数学等价)

训练 vs 推理

  • 训练时:节省显存(不保存中间矩阵),但反向传播需要重计算,总体仍更快
  • 推理时:无重计算开销,纯粹的性能提升,加速更明显

适用场景

  • 训练:长序列训练、显存受限场景
  • 推理:所有推理场景(尤其是长文本)

投机采样

核心思想:用小模型快速生成候选,大模型并行验证。

优势

  • 生成速度:提升1.5-2倍
  • 保持输出质量(与原模型等价)
  • 无需修改模型结构

适用场景:低延迟生成、代码生成、模式性强的任务

综合效应

三大技术组合使用,可实现:

  • 吞吐量:提升10-15倍
  • 延迟:降低3-5倍
  • 成本:降低50-70%(更少GPU,更高利用率)

这三大技术是大模型推理优化的基石,几乎所有现代推理系统都基于它们构建。理解这些技术的原理,对于优化部署、降低成本、提升用户体验至关重要。