nano-vllm (2): KV Cache, PagedAttention, and the nano-vllm Implementation


0. What is KV Cache

In 0. nano-vllm: LLM Inference Principles and Workflow we walked through the basic workflow of a large language model as an autoregressive model. Its two main phases are:

  1. Prefill phase: the model processes the entire prompt in a forward pass, up to producing the first output token; the key metric here is TTFT (Time To First Token), the time needed to generate the first token.
  2. Decode phase: once prefill completes, the model enters the decode phase and generates the remaining tokens one by one; the key metric here is TPOT (Time Per Output Token), the average time to generate each output token.

0.1 Autoregressive Models

The decode phase above is the defining trait of an autoregressive model: the model's output at one step becomes its input at the next. The Transformer is a neural network architecture, and it is currently the most powerful and most widely used architecture for implementing autoregressive language models. The overall Transformer architecture is shown below:

As shown above, Scaled Dot-Product Attention is the most central and fundamental computational unit in the Transformer. Its basic formula is:

Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V

Where (we won't dig into the precise details here):

  • Q, the query vector, represents "what information the current token is looking for".
  • K, the key vector, represents "the labels of what the historical information contains".
  • V, the value vector, represents "the actual content of the historical information".
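As a concrete reference, here is a minimal NumPy sketch of causal scaled dot-product attention (illustrative only; the shapes and function name are ours, not nano-vllm's):

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V, causal=True):
    """Q, K, V: arrays of shape (seq_len, d_k)."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)          # (seq_len, seq_len)
    if causal:
        # Mask the upper triangle so token i cannot attend to tokens > i.
        mask = np.triu(np.ones_like(scores, dtype=bool), k=1)
        scores = np.where(mask, -np.inf, scores)
    # Numerically stable softmax over each row.
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V

rng = np.random.default_rng(0)
Q = rng.standard_normal((4, 8))
K = rng.standard_normal((4, 8))
V = rng.standard_normal((4, 8))
out = scaled_dot_product_attention(Q, K, V)
print(out.shape)  # (4, 8)
```

With the causal mask, the first token can only attend to itself, so its output is exactly V[0].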

0.2 Without KV Cache

As shown in the figure above, the attention computation for each token is:

Att_1(Q, K, V) = softmax(Q_1 K_1^T / sqrt(d_k)) V_1
Att_2(Q, K, V) = softmax(Q_2 K_1^T / sqrt(d_k)) V_1 + softmax(Q_2 K_2^T / sqrt(d_k)) V_2
Att_n(Q, K, V) = softmax(Q_n K_1^T / sqrt(d_k)) V_1 + softmax(Q_n K_2^T / sqrt(d_k)) V_2 + ... + softmax(Q_n K_n^T / sqrt(d_k)) V_n

To strictly preserve the model's causality (its autoregressive property), the decoder applies a causal mask: at inference time, already-generated tokens must not attend to tokens that come after them, so the upper-right half of the QK^T matrix is masked. That is, when deriving Token_k, it must not be influenced by K_{k+1}, K_{k+2}, ..., K_{k+n}.

Looking at the computations above, there is a large amount of redundant work at every step: the K and V that were already computed get recomputed, even though the K and V for each step are fixed, so they can simply be cached. Meanwhile, each Att_k depends only on the previous K/V and on Q_k, so there is no need to cache Q; only the current Q needs to be computed each step.
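This observation can be checked numerically. Below is a small NumPy sketch (illustrative, not nano-vllm code) of a decode loop that caches K/V and computes only the current token's query each step:

```python
import numpy as np

def attend(q, K, V):
    """Attention output for a single query q against keys/values K, V."""
    d_k = q.shape[-1]
    scores = K @ q / np.sqrt(d_k)
    w = np.exp(scores - scores.max())
    w /= w.sum()
    return w @ V

rng = np.random.default_rng(1)
d = 8
K_cache, V_cache = [], []
outputs = []
for step in range(5):
    q = rng.standard_normal(d)   # query for the new token
    k = rng.standard_normal(d)   # key/value for the new token
    v = rng.standard_normal(d)
    K_cache.append(k)            # append to the cache instead of recomputing
    V_cache.append(v)
    outputs.append(attend(q, np.array(K_cache), np.array(V_cache)))

# Thanks to the causal mask, each step needs only the cached K/V plus the
# current q — identical to recomputing attention over the full prefix.
print(len(outputs), outputs[-1].shape)
```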

0.3 Introducing KV Cache

Steps 1-4 below are the comparison from "Transformers KV Caching Explained" between decoding without and with KV Cache:

Let's test this with the example from "Transformers KV Caching Explained":

import numpy as np
import time
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

path = os.path.expanduser(
        "~/.cache/huggingface/hub/models--Qwen--Qwen3-0.6B/snapshots/c1899de289a04d12100db370d81485cdf75e47ca")
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(path)
model = AutoModelForCausalLM.from_pretrained(path).to(device)

for use_cache in (True, False):
  times = []
  for _ in range(10):  # measuring 10 generations
    start = time.time()
    model.generate(**tokenizer("What is KV caching?", return_tensors="pt").to(device), use_cache=use_cache, max_new_tokens=2000)
    times.append(time.time() - start)
  print(f"{'TPOT with' if use_cache else 'TPOT without'} KV caching: {round(np.mean(times), 3)} +- {round(np.std(times), 3)} seconds")



for use_cache in (True, False):
  times = []
  for _ in range(10):  # measuring 10 generations
    start = time.time()
    model.generate(**tokenizer("What is KV caching?", return_tensors="pt").to(device), use_cache=use_cache, max_new_tokens=1)
    times.append(time.time() - start)
  print(f"{'TTFT with' if use_cache else 'TTFT without'} KV caching: {round(np.mean(times), 3)} +- {round(np.std(times), 3)} seconds")

The results show that KV Cache yields a large improvement in TPOT, which makes sense, since it works precisely in the autoregressive decode phase; it has almost no effect on TTFT in the prefill phase.

TPOT with KV caching: 19.32 +- 0.058 seconds
TPOT without KV caching: 123.952 +- 0.363 seconds
TTFT with KV caching: 0.012 +- 0.0 seconds
TTFT without KV caching: 0.011 +- 0.0 seconds

0.3.1 Problems Introduced by KV Cache

KV Cache is fundamentally a space-for-time trade-off, so it inevitably brings some problems:

  1. Higher memory usage: memory grows linearly with sequence length. On resource-constrained devices (e.g., consumer GPUs), the KV Cache may exhaust GPU memory first, severely limiting the batch size and the maximum producible sequence length.
  2. Memory bandwidth bottleneck: although the amount of computation drops, the pressure shifts to memory access, moving the inference bottleneck from compute to memory bandwidth; frequent cache allocation and deallocation can also fragment GPU memory.
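To make the memory growth concrete, here is a back-of-the-envelope calculation using the Qwen3-0.6B shapes quoted later in this article (28 layers, 8 KV heads, head_dim 128, bfloat16 = 2 bytes per element); the function name is ours, not nano-vllm's:

```python
def kv_bytes_per_token(layers=28, kv_heads=8, head_dim=128, dtype_bytes=2):
    # 2 tensors (K and V) per layer, one head_dim vector per KV head.
    return 2 * layers * kv_heads * head_dim * dtype_bytes

per_token = kv_bytes_per_token()     # 114688 bytes = 112 KiB per token
seq_2048 = 2048 * per_token          # one request at 2048 tokens
print(f"{per_token} B/token, {seq_2048 / 2**20:.0f} MiB for 2048 tokens")
# Memory grows linearly with sequence length, and again with batch size.
```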

1. PagedAttention

1.1 KV Cache Memory Fragmentation

Early LLM serving systems also used contiguous memory allocation: the KV Cache of a single request was stored in one contiguous region, and the pre-allocation scheme reserved, at the start of each request, a whole contiguous chunk sized for the maximum possible generation length (e.g., 2048 tokens). This contiguous allocation suffers from severe memory fragmentation. As the PagedAttention paper, "Efficient Memory Management for Large Language Model Serving with PagedAttention", shows, the effective memory utilization of traditional serving systems (Orca, for example) is extremely low: everything other than the green portion is wasted. The PagedAttention mechanism introduced by vLLM brings KV Cache utilization up to 96.3%.

1.2 The PagedAttention Mechanism

Analogous to virtual memory in operating systems, vLLM proposes PagedAttention. Concretely, PagedAttention divides each sequence's KV Cache into blocks, each holding a fixed number of key and value vectors. Just like a virtual-memory page table, PagedAttention maintains a mapping table: when the attention computation needs the KV Cache, it finds the sequence's blocks through the Block Table and fetches the corresponding KV vectors from those blocks. The basic picture is as follows:
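The lookup can be sketched in a few lines of Python (toy code with invented names; the real kernels do this inside the attention computation):

```python
BLOCK_SIZE = 4

def locate(block_table, token_idx, block_size=BLOCK_SIZE):
    """Map a logical token index to a flat slot in the physical KV store."""
    logical_block = token_idx // block_size    # which block the token is in
    offset = token_idx % block_size            # position inside that block
    physical_block = block_table[logical_block]
    # "slot" = flat position inside the physical KV Cache tensor
    return physical_block * block_size + offset

block_table = [7, 2, 9]        # logical blocks 0,1,2 -> physical blocks 7,2,9
print(locate(block_table, 0))  # token 0 -> block 7, offset 0 -> slot 28
print(locate(block_table, 5))  # token 5 -> block 2, offset 1 -> slot 9
```

Because the block table indirects every access, the physical blocks of one sequence need not be contiguous, which is exactly what eliminates the fragmentation described above.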

1.3 Reference Counting

The paper "Efficient Memory Management for Large Language Model Serving with PagedAttention" describes how vLLM supports diverse decoding strategies and performs well with them, such as Parallel Sampling, Beam Search, and Shared Prefix. All of them add a reference count to memory blocks so the KV Cache can be reused, significantly saving memory and increasing throughput. Note that in vLLM v1, Beam Search has been moved out of the core and Parallel Sampling no longer triggers copy-on-write, so we won't go into detail here.

2. The nano-vllm Implementation

2.1 Storage-Related Module Relationships

graph TB
    subgraph "Physical layer - BlockManager"
        BM["BlockManager<br/>nanovllm/engine/block_manager.py"]
        BLOCKS["Physical block array<br/>Block[0..N-1]"]
        FREE["Free queue<br/>free_block_ids"]
        USED["In-use set<br/>used_block_ids"]
        HASH["Hash map<br/>hash_to_block_id"]
        
        BM --> BLOCKS
        BM --> FREE
        BM --> USED
        BM --> HASH
    end
    
    subgraph "Storage layer - ModelRunner"
        MR["ModelRunner<br/>nanovllm/engine/model_runner.py"]
        KVCACHE["KV Cache tensor<br/>[2, layers, blocks, block_size, heads, dim]"]
        
        MR --> KVCACHE
    end
    
    subgraph "Logical layer - Sequence"
        SEQ["Sequence<br/>nanovllm/engine/sequence.py"]
        BLOCKTABLE["Logical block table<br/>block_table: list[int]"]
        TOKENS["Token sequence<br/>token_ids: list[int]"]
        
        SEQ --> BLOCKTABLE
        SEQ --> TOKENS
    end
    
    subgraph "Compute layer - Attention"
        ATT["Attention<br/>nanovllm/layers/attention.py"]
        KCACHE["k_cache view"]
        VCACHE["v_cache view"]
        
        ATT --> KCACHE
        ATT --> VCACHE
    end
    
    BLOCKTABLE -.->|maps to| BLOCKS
    BLOCKS -.->|data stored in| KVCACHE
    KVCACHE -.->|per-layer view| KCACHE
    KVCACHE -.->|per-layer view| VCACHE
    
    style BLOCKS fill:#e1f5ff
    style BLOCKTABLE fill:#fff4e1
    style KVCACHE fill:#ffe1e1

The storage-related modules and their roles:

  1. Physical blocks are managed by BlockManager; they hold metadata but not the actual KV data
  2. Logical blocks live in Sequence.block_table, an ordered list of physical block ids
  3. The actual storage is the ModelRunner.kv_cache tensor, indexed by physical block id
  4. The Attention layer accesses each layer's KV data through the k_cache and v_cache views
  5. slot_mapping provides the precise token-to-physical-slot mapping that implements PagedAttention

This design cleanly separates the logical from the physical: Sequence deals with logical sequence structure, BlockManager manages physical resource allocation, ModelRunner provides the actual storage, and the Attention layer focuses on computation.

2.2 Initialization Phase

graph TB
    subgraph "Initialization"
        A[ModelRunner.__init__] --> B[allocate_kv_cache]
        B --> C[Compute available GPU memory]
        C --> D[Allocate KV Cache tensor]
        D --> E[Create BlockManager]
    end

nano-vllm pre-allocates the memory in one shot. The allocation logic is:

def allocate_kv_cache(self):
    config = self.config
    hf_config = config.hf_config
    free, total = torch.cuda.mem_get_info()
    used = total - free
    peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
    current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
    num_kv_heads = hf_config.num_key_value_heads // self.world_size
    head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
    block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize
    config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
    assert config.num_kvcache_blocks > 0
    self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim)
    layer_id = 0
    for module in self.model.modules():
        if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
            module.k_cache = self.kv_cache[0, layer_id]
            module.v_cache = self.kv_cache[1, layer_id]
            layer_id += 1

To understand the code above, first look at two pieces of configuration. One is in the Config module:

class Config:
    # ...
    tensor_parallel_size: int = 1
    gpu_memory_utilization: float = 0.9
    kvcache_block_size: int = 256
    num_kvcache_blocks: int = -1
    # ...

Only the relevant parameters are shown, where:

  • tensor_parallel_size: tensor-parallel degree, between 1 and 8; the number of GPUs used for model parallelism
  • gpu_memory_utilization: GPU memory utilization, between 0.0 and 1.0; the fraction of GPU memory the KV Cache may use
  • kvcache_block_size: KV Cache block size, the number of tokens per block (must be a multiple of 256)
  • num_kvcache_blocks: total number of KV Cache blocks, computed dynamically from GPU memory (-1 means uninitialized)

The other is the model's config.json, located in our example under ~/.cache/huggingface/hub/models--Qwen--Qwen3-0.6B/snapshots/c1899de289a04d12100db370d81485cdf75e47ca:

{                                                                                                                                                                                                                                                                                  
  "architectures": [
    "Qwen3ForCausalLM"
  ],  
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151645,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 40960,
  "max_window_layers": 28, 
  "model_type": "qwen3",
  "num_attention_heads": 16, 
  "num_hidden_layers": 28, 
  "num_key_value_heads": 8,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 1000000,
  "sliding_window": null,
  "tie_word_embeddings": true,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.51.0",
  "use_cache": true,
  "use_sliding_window": false,
  "vocab_size": 151936
}

With that, the memory computation breaks down as follows:

1. Reading GPU memory state

free, total = torch.cuda.mem_get_info()           # free and total device memory
used = total - free                               # memory currently in use
peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]    # historical peak allocation
current = torch.cuda.memory_stats()["allocated_bytes.all.current"] # current allocation

The system reads the GPU's live memory state via torch.cuda.mem_get_info() and combines it with the allocator statistics to compute the memory available for the KV Cache.

2. Per-block memory size

num_kv_heads = hf_config.num_key_value_heads // self.world_size  
head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)  
block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize

The key formula, term by term:

  • 2: the Key and Value tensors
  • hf_config.num_hidden_layers: number of model layers, 28 here
  • self.block_size: tokens per block (default 256)
  • num_kv_heads: number of KV heads (accounting for tensor parallelism); with a single GPU this is 8/1 = 8
  • head_dim: head dimension, 128 as above
  • hf_config.torch_dtype.itemsize: bytes per element; torch_dtype is bfloat16, so 2 bytes

So block_bytes = 2 × 28 × 256 × 8 × 128 × 2 = 29,360,128 bytes = 28 MiB.
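The arithmetic can be double-checked in plain Python (values taken from Qwen3-0.6B's config.json, bfloat16 itemsize = 2):

```python
num_hidden_layers = 28
block_size = 256
num_kv_heads = 8          # 8 KV heads / tensor_parallel_size 1
head_dim = 128
itemsize = 2              # bfloat16

# Same product as the allocate_kv_cache code above.
block_bytes = 2 * num_hidden_layers * block_size * num_kv_heads * head_dim * itemsize
print(block_bytes, block_bytes // 2**20)  # 29360128 bytes = 28 MiB
```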

3. Number of allocatable blocks

config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes

The system computes the available memory from the configured gpu_memory_utilization (default 0.9) and divides by the per-block size to get the number of allocatable blocks. The formula is a bit subtle:

  • total * gpu_memory_utilization: reserve 90% of GPU memory for this process (total on my machine is about 23.48 GiB, so ×0.9 gives roughly 21.13 GiB)
  • peak: the peak memory this process has allocated since model warmup (about 1.58 GiB at my breakpoint)
  • used: the current memory usage of the whole GPU, including other processes plus this process (about 3.69 GiB at the breakpoint)
  • current: this process's current allocation (about 1.14 GiB at the breakpoint)

total * config.gpu_memory_utilization - used is easy to understand, but why subtract another peak - current? The idea is to reserve peak - current for the framework, in case some later operation needs to reach that peak again. With the debug numbers above, the final available memory is about 17 GiB = 17408 MiB, so with 28 MiB blocks there are 621 blocks in total, i.e., config.num_kvcache_blocks = 621.
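Plugging the approximate debug figures above into the same formula reproduces that block count (pure arithmetic; the GiB values are the article's rough measurements):

```python
GiB = 2**30
total   = 23.48 * GiB
used    = 3.69 * GiB
peak    = 1.58 * GiB
current = 1.14 * GiB
gpu_memory_utilization = 0.9
block_bytes = 28 * 2**20   # 28 MiB per block, computed above

# Same expression as allocate_kv_cache: reserve 90% of total, subtract what
# is already used, and keep (peak - current) as framework headroom.
available = total * gpu_memory_utilization - used - peak + current
num_kvcache_blocks = int(available) // block_bytes
print(num_kvcache_blocks)  # 621
```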

4. KV Cache tensor structure

The 6-dimensional tensor allocation:

self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks,   
                           self.block_size, num_kv_heads, head_dim)

Matching the key formula above, kv_cache has 6 dimensions:

  • 2: the separate Key and Value tensors
  • hf_config.num_hidden_layers: number of model layers
  • config.num_kvcache_blocks: number of physical blocks
  • self.block_size: tokens per block (default 256)
  • num_kv_heads: number of KV attention heads
  • head_dim: dimension of each attention head

Note that, compared with the byte computation above, this is missing a factor of hf_config.torch_dtype.itemsize. That is because the framework sizes the allocation from the dtype internally, so there is no need to worry that this torch.empty call requests the wrong amount of memory.

Below is the memory-allocation view I captured while debugging with breakpoints: the torch.empty call allocates roughly 17 GB of memory.

2.3 Execution Flow

graph TB
    subgraph "Request handling"
        F[User request] --> G[Create Sequence]
        G --> H[Scheduler.add]
        H --> I[Scheduler.schedule]
    end
    
    subgraph "Memory allocation"
        I --> J{Prefill phase?}
        J -->|yes| K[BlockManager.allocate]
        J -->|no| L[BlockManager.may_append]
        K --> M[Allocate physical blocks]
        L --> N[Extend sequence blocks]
    end
    
    subgraph "Attention computation"
        M --> O[ModelRunner.prepare_prefill]
        N --> P[ModelRunner.prepare_decode]
        O --> Q[Set slot_mapping]
        P --> Q
        Q --> R[Attention.forward]
        R --> S[FlashAttention compute]
        S --> T[Store KV into cache]
    end
    
    subgraph "Block management"
        T --> U[Update block ref counts]
        U --> V[Check sequence finished]
        V -->|finished| W[BlockManager.deallocate]
        V -->|continue| X[Next decode step]
    end

The KV Cache-related flow of a request during inference is roughly as shown above. Below, we break down the key operations at each layer step by step.

2.3.1 Prefill

nano-vllm's scheduler does not support Chunked Prefill. The scheduler is fairly simple: it first serves requests in the prefill phase, then requests in the decode phase. As the code shows, if a new request arrives during decode, the scheduler will also serve the newly arrived request first.

1. Scheduler: scheduling prefill sequences
def schedule(self) -> tuple[list[Sequence], bool]:
    # prefill
    scheduled_seqs = []
    num_seqs = 0
    num_batched_tokens = 0
    while self.waiting and num_seqs < self.max_num_seqs:
        seq = self.waiting[0]
        if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.block_manager.can_allocate(seq):
            break
        num_seqs += 1
        self.block_manager.allocate(seq)
        num_batched_tokens += len(seq) - seq.num_cached_tokens
        seq.status = SequenceStatus.RUNNING
        self.waiting.popleft()
        self.running.append(seq)
        scheduled_seqs.append(seq)
    if scheduled_seqs:
        return scheduled_seqs, True

Besides some status checks and updates, dequeuing from the waiting queue, and enqueuing onto the running queue, schedule mainly performs the self.block_manager.allocate(seq) operation.

2. BlockManager: allocating memory
def allocate(self, seq: Sequence):
    assert not seq.block_table
    h = -1
    cache_miss = False
    for i in range(seq.num_blocks):
        token_ids = seq.block(i)
        h = self.compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
        block_id = self.hash_to_block_id.get(h, -1)
        if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
            cache_miss = True
        if cache_miss:
            block_id = self.free_block_ids[0]
            block = self._allocate_block(block_id)
        else:
            seq.num_cached_tokens += self.block_size
            if block_id in self.used_block_ids:
                block = self.blocks[block_id]
                block.ref_count += 1
            else:
                block = self._allocate_block(block_id)
        if h != -1:
            block.update(h, token_ids)
            self.hash_to_block_id[h] = block_id
        seq.block_table.append(block_id)

As you can see, hash computation plus token_ids matching is used here to detect identical tokens; on a match, the block's reference count is incremented by 1. The current physical block id is then recorded through seq.block_table.

Then, the step function goes on to call ModelRunner.run, and in prepare_prefill, slots are related to block ids by a simple linear computation.

3. ModelRunner: data preparation

The core mapping formula

In ModelRunner.prepare_prefill(), the slot_mapping computation logic is:

for i in range(seq.num_cached_blocks, seq.num_blocks):  
    start = seq.block_table[i] * self.block_size  
    if i != seq.num_blocks - 1:  
        end = start + self.block_size  
    else:  
        end = start + seq.last_block_num_tokens   
    slot_mapping.extend(list(range(start, end)))

The basic formula:

slot = block_id * block_size + token_offset_in_block

  • block_id: from seq.block_table[i], the physical block identifier
  • block_size: default 256, the number of tokens per block
  • token_offset_in_block: the token's offset within the block (0 to block_size-1)

A concrete example. Assume:

  • block_size = 256
  • seq.block_table = [5, 12]
  • a sequence of 300 tokens

The computation:

Block 0 (block_id=5): slots 5*256 to 5*256+255 = 1280-1535
Block 1 (block_id=12): slots 12*256 to 12*256+43 = 3072-3115 (the last block holds only 44 tokens)

Key data structures

Sequence.block_table stores the mapping from logical position to physical block_id:

self.block_table = []  # [5, 12, ...]

Context.slot_mapping stores each token's physical storage slot:

slot_mapping: torch.Tensor | None = None  # [1280, 1281, ..., 3115]
Attention: storage access

In the Attention layer, the KV Cache tensor is indexed directly through slot_mapping:

def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):  
    context = get_context()  
    if k_cache.numel() and v_cache.numel():  
        store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)

The Triton kernel uses slots as indices to access the physical storage locations directly.
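The prefill slot_mapping loop can be replayed in plain Python with the example values above (block_size=256, block_table=[5, 12], 300 tokens); the helper name is ours:

```python
def build_slot_mapping(block_table, seq_len, block_size=256):
    """Mirror prepare_prefill's slot computation for one sequence."""
    num_blocks = (seq_len + block_size - 1) // block_size
    last_block_num_tokens = seq_len - (num_blocks - 1) * block_size
    slot_mapping = []
    for i in range(num_blocks):
        start = block_table[i] * block_size
        # Full blocks take block_size slots; the last block only as many
        # slots as it has tokens.
        end = start + (block_size if i != num_blocks - 1 else last_block_num_tokens)
        slot_mapping.extend(range(start, end))
    return slot_mapping

slots = build_slot_mapping([5, 12], 300)
print(slots[0], slots[255], slots[256], slots[-1])  # 1280 1535 3072 3115
```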

2.3.2 Decode

1. Scheduler: scheduling decode sequences

In the decode phase of Scheduler.schedule(), the system processes the sequences in the running queue:

# decode  
while self.running and num_seqs < self.max_num_seqs:  
    seq = self.running.popleft()  
    while not self.block_manager.can_append(seq):  
        if self.running:  
            self.preempt(self.running.pop())  
        else:  
            self.preempt(seq)  
            break  
    else:  
        num_seqs += 1  
        self.block_manager.may_append(seq)  
        scheduled_seqs.append(seq)

Scheduling logic:

  • Pop a sequence from the running queue for decode
  • Check whether it can be extended, preempting other sequences if necessary
  • Call BlockManager.may_append() to extend the KV Cache

2. BlockManager: dynamic extension

BlockManager.may_append() handles extending a sequence's KV Cache:

def may_append(self, seq: Sequence):  
    block_table = seq.block_table  
    last_block = self.blocks[block_table[-1]]  
    if len(seq) % self.block_size == 1:  
        assert last_block.hash != -1  
        block_id = self.free_block_ids[0]  
        self._allocate_block(block_id)  
        block_table.append(block_id)  
    elif len(seq) % self.block_size == 0:  
        assert last_block.hash == -1  
        token_ids = seq.block(seq.num_blocks-1)  
        prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1  
        h = self.compute_hash(token_ids, prefix)  
        last_block.update(h, token_ids)  
        self.hash_to_block_id[h] = last_block.block_id  
    else:  
        assert last_block.hash == -1

Extension strategy:

  • Crossing a block boundary: when the sequence length crosses into a new block, allocate a new physical block
  • Block completed: when a block fills up, compute its hash and add it to the prefix cache
  • In-block growth: keep filling within the same block; no extra work needed
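The three branches boil down to the value of len(seq) % block_size after the new token is appended; a toy restatement (our names, not nano-vllm's):

```python
def append_action(seq_len, block_size=256):
    """Which may_append() branch fires once the sequence has seq_len tokens."""
    r = seq_len % block_size
    if r == 1:
        return "allocate new block"   # length just crossed into a new block
    if r == 0:
        return "hash last block"      # the last block just became full
    return "no-op"                    # still filling the current block

print(append_action(257), append_action(512), append_action(300))
```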
3. ModelRunner: data preparation

ModelRunner.prepare_decode() prepares the execution data for the decode phase:

def prepare_decode(self, seqs: list[Sequence]):  
    input_ids = []  
    positions = []  
    slot_mapping = []  
    context_lens = []  
    for seq in seqs:  
        input_ids.append(seq.last_token)  
        positions.append(len(seq) - 1)  
        context_lens.append(len(seq))  
        slot_mapping.append(seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1)  
    input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)  
    positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)  
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)  
    context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)  
    block_tables = self.prepare_block_tables(seqs)  
    set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)  
    return input_ids, positions

Key preparation:

  • New token slot: the new token's slot is seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1
  • Sequence lengths: record each sequence's current length for the attention computation
  • Block tables: prepare the block-table tensor for FlashAttention
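The decode-phase slot for the single new token, redone in plain Python (same formula as prepare_decode; the example values are illustrative):

```python
def decode_slot(block_table, seq_len, block_size=256):
    """Slot of the newest token: last block's base + its in-block offset."""
    last_block_num_tokens = seq_len - (len(block_table) - 1) * block_size
    return block_table[-1] * block_size + last_block_num_tokens - 1

# A 300-token sequence in physical blocks [5, 12]: the newest token is the
# 44th token of block 12, i.e. slot 12*256 + 44 - 1 = 3115.
print(decode_slot([5, 12], 300))
```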
4. Context: setting decode mode

set_context() sets up the execution context for the decode phase:

set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)

Decode context characteristics:

  • is_prefill = False: marks the decode phase
  • context_lens: each sequence's length, used by FlashAttention
  • slot_mapping: the storage position of the new token
  • block_tables: the full block-table mapping
5. Attention: decode computation

In the decode phase, Attention.forward() uses a dedicated FlashAttention interface:

else:    # decode  
    o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,  
                                cache_seqlens=context.context_lens, block_table=context.block_tables,   
                                softmax_scale=self.scale, causal=True)

Decode optimizations:

  • KV Cache reuse: historical KV data is read directly from the cache
  • Efficient attention: flash_attn_with_kvcache is optimized for single-token generation
  • Block-table indexing: block_table and cache_seqlens give efficient access to the KV data

2.4 Prefix Cache

1. BlockManager reference counting

First, to be clear: in nano-vllm, Prefix Cache rests entirely on reference counting. In BlockManager.allocate, the code checks whether a cache hit occurs:

def allocate(self, seq: Sequence):
    assert not seq.block_table
    h = -1
    cache_miss = False
    for i in range(seq.num_blocks):
        token_ids = seq.block(i)
        h = self.compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
        block_id = self.hash_to_block_id.get(h, -1)
        if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
            cache_miss = True
        if cache_miss:
            block_id = self.free_block_ids[0]
            block = self._allocate_block(block_id)
        else:
            seq.num_cached_tokens += self.block_size
            if block_id in self.used_block_ids:
                block = self.blocks[block_id]
                block.ref_count += 1
            else:
                block = self._allocate_block(block_id)
        if h != -1:
            block.update(h, token_ids)
            self.hash_to_block_id[h] = block_id
        seq.block_table.append(block_id)

Reading the code above reveals:

  • Cache hits happen only in the prefill phase, and only for a contiguous run of blocks starting from the first one: once any block misses, cache_miss is set to True, and even if a later block's hash and token_ids match, cache_miss can never be reset. So the reference counting serves Prefix Cache only.
  • On a cache hit, num_cached_tokens records the number of cached tokens.
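The chained hashing is what makes the cache prefix-sensitive: each full block's hash folds in the previous block's hash. A toy simulation (not nano-vllm's actual compute_hash) shows that a hit at block i requires hits at every block before it:

```python
def chain_hash(token_ids, prefix_hash):
    # Stand-in for BlockManager.compute_hash: fold in the previous hash.
    return hash((prefix_hash, tuple(token_ids)))

def register(seq, cache, block_size=4):
    """Simulate allocate() on a miss: record every full block's chained hash."""
    h = -1
    for i in range(len(seq) // block_size):
        h = chain_hash(seq[i * block_size:(i + 1) * block_size], h)
        cache.add(h)

def count_hits(seq, cache, block_size=4):
    """Number of leading full blocks whose chained hash is already cached."""
    h, hits = -1, 0
    for i in range(len(seq) // block_size):
        h = chain_hash(seq[i * block_size:(i + 1) * block_size], h)
        if h not in cache:
            break            # one miss latches cache_miss: no later hits
        hits += 1
    return hits

cache = set()
register(list(range(10)), cache)   # full blocks [0..3], [4..7]; tail ignored
print(count_hits(list(range(8)), cache))                # 2: both blocks hit
print(count_hits([0, 1, 2, 3, 42, 43, 44, 45], cache))  # 1: second block misses
print(count_hits([9, 9, 9, 9, 4, 5, 6, 7], cache))      # 0: chained hash differs
```

Note the last case: its second block has the same tokens as a cached block, but because its first block differs, the chained hash differs and nothing hits.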
2. ModelRunner: data preparation
def prepare_prefill(self, seqs: list[Sequence]):
    input_ids = []
    positions = []
    cu_seqlens_q = [0]
    cu_seqlens_k = [0]
    max_seqlen_q = 0
    max_seqlen_k = 0
    slot_mapping = []
    block_tables = None
    for seq in seqs:
        seqlen = len(seq)
        input_ids.extend(seq[seq.num_cached_tokens:])
        positions.extend(list(range(seq.num_cached_tokens, seqlen)))
        seqlen_q = seqlen - seq.num_cached_tokens
        seqlen_k = seqlen
        cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
        cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
        max_seqlen_q = max(seqlen_q, max_seqlen_q)
        max_seqlen_k = max(seqlen_k, max_seqlen_k)
        if not seq.block_table:    # warmup
            continue
        for i in range(seq.num_cached_blocks, seq.num_blocks):
            start = seq.block_table[i] * self.block_size
            if i != seq.num_blocks - 1:
                end = start + self.block_size
            else:
                end = start + seq.last_block_num_tokens 
            slot_mapping.extend(list(range(start, end)))
    if cu_seqlens_k[-1] > cu_seqlens_q[-1]:    # prefix cache
        block_tables = self.prepare_block_tables(seqs)
    input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
    positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
    cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
    cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
    set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
    return input_ids, positions

The code above treats if cu_seqlens_k[-1] > cu_seqlens_q[-1] as a prefix-cache hit. Note that seqlen_q is the query length still to be processed, i.e., the sequence's total token count minus its cached token count, while seqlen_k is the total key length, i.e., the total token count. cu_seqlens_q is the cumulative length of the uncached (query) tokens, and cu_seqlens_k is the cumulative length of the full sequences. Suppose there are two sequences:

  • seq1: length 100, 80 tokens cached → seqlen_q=20, seqlen_k=100
  • seq2: length 50, 0 tokens cached → seqlen_q=50, seqlen_k=50

The result:

cu_seqlens_q = [0, 20, 70]    # [start, end of seq1, end of seq2]
cu_seqlens_k = [0, 100, 150]  # [start, end of seq1, end of seq2]

So treating if cu_seqlens_k[-1] > cu_seqlens_q[-1] as a cache hit is easy to understand.
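Re-computing the two-sequence example above in plain Python:

```python
seqs = [
    {"len": 100, "cached": 80},   # seq1
    {"len": 50,  "cached": 0},    # seq2
]
cu_seqlens_q, cu_seqlens_k = [0], [0]
for s in seqs:
    # Query length = uncached tokens; key length = full sequence.
    cu_seqlens_q.append(cu_seqlens_q[-1] + (s["len"] - s["cached"]))
    cu_seqlens_k.append(cu_seqlens_k[-1] + s["len"])
print(cu_seqlens_q, cu_seqlens_k)        # [0, 20, 70] [0, 100, 150]

prefix_cache_hit = cu_seqlens_k[-1] > cu_seqlens_q[-1]
print(prefix_cache_hit)                  # True: 80 tokens cached in total
```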

One more thing to note: input_ids stores the token_ids that are not yet cached, and the uncached token_ids can be taken as below precisely because a cache hit can only cover a contiguous prefix of blocks.

input_ids.extend(seq[seq.num_cached_tokens:])

When a cache hit occurs, the block tables are prepared through prepare_block_tables:

def prepare_block_tables(self, seqs: list[Sequence]):
    max_len = max(len(seq.block_table) for seq in seqs)
    block_tables = [seq.block_table + [-1] * (max_len - len(seq.block_table)) for seq in seqs]
    block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
    return block_tables

This code simply combines the different sequences' block_table vectors into one block_tables matrix, padding the unaligned tails with -1, and then converts it into a tensor.
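The padding step without the torch conversion (a hypothetical standalone helper mirroring prepare_block_tables):

```python
def pad_block_tables(tables, pad=-1):
    """Right-pad ragged per-sequence block tables into a rectangular matrix."""
    max_len = max(len(t) for t in tables)
    return [t + [pad] * (max_len - len(t)) for t in tables]

print(pad_block_tables([[5, 12], [3], [7, 8, 9]]))
# [[5, 12, -1], [3, -1, -1], [7, 8, 9]]
```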

3. Context: passing the parameters
set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
4. Attention layer: prefix matching
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    context = get_context()
    k_cache, v_cache = self.k_cache, self.v_cache
    if k_cache.numel() and v_cache.numel():
        store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
    if context.is_prefill:
        if context.block_tables is not None:    # prefix cache
            k, v = k_cache, v_cache
        o = flash_attn_varlen_func(q, k, v,
                                   max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
                                   max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
                                   softmax_scale=self.scale, causal=True, block_table=context.block_tables)
    else:    # decode
        o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
                                    cache_seqlens=context.context_lens, block_table=context.block_tables, 
                                    softmax_scale=self.scale, causal=True)
    return o

Note that k_cache and v_cache hold all the KV Cache of this attention layer, while the k and v passed in are the KV computed from the uncached tokens. When prefix cache applies, the computed KV covers only the uncached token_ids, so store_kvcache(k, v, k_cache, v_cache, context.slot_mapping) writes the newly computed KV into the paged KV cache.

So when prefix cache applies (the if context.block_tables is not None check), k_cache and v_cache replace k and v in the computation, and the block_table mapping passed in lets the kernel retrieve the complete KV (newly computed + cached) during the computation.

5. An example
import os
from nanovllm import LLM, SamplingParams
from transformers import AutoTokenizer


def main():
    path = os.path.expanduser(
        "~/.cache/huggingface/hub/models--Qwen--Qwen3-0.6B/snapshots/c1899de289a04d12100db370d81485cdf75e47ca")
    tokenizer = AutoTokenizer.from_pretrained(path)
    llm = LLM(path, enforce_eager=True, tensor_parallel_size=1)

    # Build a long prompt (more than one block, e.g. ~300 tokens)
    long_text = " ".join(["test"] * 300)
    prompt = tokenizer.apply_chat_template([{"role": "user", "content": long_text}], tokenize=False,
                                           add_generation_prompt=True)

    # First request: max_tokens=1, so it prefills, decodes once, and leaves its KV cache behind
    sp1 = SamplingParams(temperature=0.6, max_tokens=1)
    # Second request: max_tokens=256, normal generation
    sp2 = SamplingParams(temperature=0.6, max_tokens=256)

    prompts = [prompt, prompt]
    sampling_params = [sp1, sp2]

    outputs = llm.generate(prompts, sampling_params)

    for i, out in enumerate(outputs):
        print(f"\n--- Output {i + 1} ---")
        print(f"Completion: {out['text']!r}")


if __name__ == "__main__":
    main()

The example is above. Clearly the two prompts are identical, precisely to trigger the prefix cache. To prove that a prefix-cache hit actually happens, we add the following prints:

1. BlockManager.allocate() - cache allocation phase

Add prints in the allocate() method of nanovllm/engine/block_manager.py:

def allocate(self, seq: Sequence):
    assert not seq.block_table
    h = -1
    cache_miss = False
    for i in range(seq.num_blocks):
        token_ids = seq.block(i)
        h = self.compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
        block_id = self.hash_to_block_id.get(h, -1)
        if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
            cache_miss = True
        if cache_miss:
            block_id = self.free_block_ids[0]
            block = self._allocate_block(block_id)
            print(f"Seq {seq.seq_id}: Cache miss at block {i}, allocated new block {block_id}")
        else:
            seq.num_cached_tokens += self.block_size
            if block_id in self.used_block_ids:
                block = self.blocks[block_id]
                block.ref_count += 1
            else:
                block = self._allocate_block(block_id)
            print(f"Seq {seq.seq_id}: Cache hit at block {i}, using shared block {block_id}")
        if h != -1:
            block.update(h, token_ids)
            self.hash_to_block_id[h] = block_id
        seq.block_table.append(block_id)

2. prepare_prefill() - prefix cache detection

Add prints in the prepare_prefill() method of nanovllm/engine/model_runner.py:

if cu_seqlens_k[-1] > cu_seqlens_q[-1]:    # prefix cache
    print(f"Prefix cache detected! cu_seqlens_k={cu_seqlens_k[-1]}, cu_seqlens_q={cu_seqlens_q[-1]}")
    print(f"Cached tokens: {cu_seqlens_k[-1] - cu_seqlens_q[-1]}")
    block_tables = self.prepare_block_tables(seqs)
else:
    print("No prefix cache - processing full sequence")

3. Attention.forward() - usage phase

Add prints in the forward() method of nanovllm/layers/attention.py:

if context.block_tables is not None:    # prefix cache
    print(f"Using prefix cache in attention layer - block_tables shape: {context.block_tables.shape}")
    k, v = k_cache, v_cache
else:
    print("Using computed KV in attention layer (no prefix cache)")

With the earlier initialization-phase output stripped, the final prints are:

Generating:   0%|          | 0/2 [00:00<?, ?it/s]
Seq 4: Cache miss at block 0, allocated new block 0
Seq 4: Cache miss at block 1, allocated new block 1
Seq 5: Cache hit at block 0, using shared block 0
Seq 5: Cache miss at block 1, allocated new block 2
Prefix cache detected! cu_seqlens_k=616, cu_seqlens_q=360
Cached tokens: 256
Using prefix cache in attention layer - block_tables shape: torch.Size([2, 2])
Using prefix cache in attention layer - block_tables shape: torch.Size([2, 2])
Using prefix cache in attention layer - block_tables shape: torch.Size([2, 2])
Using prefix cache in attention layer - block_tables shape: torch.Size([2, 2])
Using prefix cache in attention layer - block_tables shape: torch.Size([2, 2])
Using prefix cache in attention layer - block_tables shape: torch.Size([2, 2])
Using prefix cache in attention layer - block_tables shape: torch.Size([2, 2])
Using prefix cache in attention layer - block_tables shape: torch.Size([2, 2])
Using prefix cache in attention layer - block_tables shape: torch.Size([2, 2])
Using prefix cache in attention layer - block_tables shape: torch.Size([2, 2])
Using prefix cache in attention layer - block_tables shape: torch.Size([2, 2])
Using prefix cache in attention layer - block_tables shape: torch.Size([2, 2])
Using prefix cache in attention layer - block_tables shape: torch.Size([2, 2])
Using prefix cache in attention layer - block_tables shape: torch.Size([2, 2])
Using prefix cache in attention layer - block_tables shape: torch.Size([2, 2])
Using prefix cache in attention layer - block_tables shape: torch.Size([2, 2])
Using prefix cache in attention layer - block_tables shape: torch.Size([2, 2])
Using prefix cache in attention layer - block_tables shape: torch.Size([2, 2])
Using prefix cache in attention layer - block_tables shape: torch.Size([2, 2])
Using prefix cache in attention layer - block_tables shape: torch.Size([2, 2])
Using prefix cache in attention layer - block_tables shape: torch.Size([2, 2])
Using prefix cache in attention layer - block_tables shape: torch.Size([2, 2])
Using prefix cache in attention layer - block_tables shape: torch.Size([2, 2])
Using prefix cache in attention layer - block_tables shape: torch.Size([2, 2])
Using prefix cache in attention layer - block_tables shape: torch.Size([2, 2])
Using prefix cache in attention layer - block_tables shape: torch.Size([2, 2])
Using prefix cache in attention layer - block_tables shape: torch.Size([2, 2])
Using prefix cache in attention layer - block_tables shape: torch.Size([2, 2])

As you can see, the first sequence in the example is Seq 4. With about 300 tokens it necessarily occupies two blocks (a block holds 256 tokens), so it allocates block 0 and block 1, both cache misses: it is the very first sequence processed, so a cache hit is impossible. The second sequence is Seq 5. Since its prompt is identical to Seq 4's, it also occupies two blocks; its first block is a cache hit, because that block is full, while the second block is not full, so no hash is computed, no hit is possible, and block 2 is allocated.

The two sequences are then processed together. Since prefix cache applies, Using prefix cache in attention layer is printed in each of the model's 28 layers.

2.5 A Small Bug

In the Config module, kvcache_block_size defaults to 256, but it is clearly meant to be a tunable variable. Meanwhile, the Sequence module contains the following definition, where block_size is hard-coded to 256, even though the two values obviously mean the same thing.

class Sequence:
    block_size = 256
    counter = count()

When Config's kvcache_block_size is set to something other than 256, ModelRunner computes slot_mapping and physical offsets with config.kvcache_block_size, while Sequence computes block counts that no longer match ModelRunner's block-table dimensions, so KV cache reads and writes become misaligned or go out of bounds. Here is an example:

import os
from nanovllm import LLM, SamplingParams
from transformers import AutoTokenizer


def main():
    path = os.path.expanduser(
        "~/.cache/huggingface/hub/models--Qwen--Qwen3-0.6B/snapshots/c1899de289a04d12100db370d81485cdf75e47ca")
    tokenizer = AutoTokenizer.from_pretrained(path)

    # Deliberately set the block size to 512, inconsistent with Sequence.block_size=256
    llm = LLM(path, enforce_eager=True, tensor_parallel_size=1, kvcache_block_size=512) # kvcache_block_size=256 avoids the error

    # Build a prompt longer than 256 tokens to force crossing a block boundary
    long_text = " ".join(["test"] * 300)  # roughly 300 tokens
    prompt_ids = tokenizer.encode(long_text)
    print(f"Prompt token count: {len(prompt_ids)}")

    # Use a max_tokens large enough that decode also crosses a block boundary
    sampling_params = SamplingParams(temperature=0.6, max_tokens=300, ignore_eos=True)

    outputs = llm.generate([prompt_ids], sampling_params)

    for output in outputs:
        print(f"Completion length: {len(output['token_ids'])}")


if __name__ == "__main__":
    main()

Running it fails as follows, confirming our guess.

/home/iguochan/miniconda3/bin/conda run -n nano-vllm-env --no-capture-output python /home/iguochan/workspace/remote/example_cross.py 
`torch_dtype` is deprecated! Use `dtype` instead!
Prompt token count: 300
Generating:   0%|                                         | 0/1 [00:00<?, ?it/s][rank0]: Traceback (most recent call last):
[rank0]:   File "/home/iguochan/workspace/remote/example_cross.py", line 29, in <module>
[rank0]:     main()
...
"/home/iguochan/workspace/remote/nanovllm/layers/attention.py", line 39, in store_kvcache
[rank0]:     assert slot_mapping.numel() == N
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: AssertionError
Generating:   0%|                                         | 0/1 [00:00<?, ?it/s]
ERROR conda.cli.main_run:execute(127): `conda run python /home/iguochan/workspace/remote/example_cross.py` failed. (See above for error)

Process finished with exit code 1

So let's make a small change: pass block_size into Sequence.

class Sequence:
    counter = count()

    def __init__(self, token_ids: list[int], sampling_params = SamplingParams(), block_size: int = 256):
        self.block_size = block_size

Also add a config field to the LLMEngine module and pass it into Sequence:

class LLMEngine:

    def __init__(self, model, **kwargs):
        # ...
        self.config = config
        # ...

    def add_request(self, prompt: str | list[int], sampling_params: SamplingParams):
        if isinstance(prompt, str):
            prompt = self.tokenizer.encode(prompt)
        seq = Sequence(prompt, sampling_params, block_size=self.config.kvcache_block_size)
        self.scheduler.add(seq)

This avoids the error.