1. nano-vllm:KV Cache、PagedAttention和nano-vllm的实现

0 阅读13分钟

0. 什么是KV Cache

0. nano-vllm:大模型推理原理和流程中我们阐述了大语言模型作为一种自回归模型的基本工作流程,其主要工作阶段分为:

  1. prefill阶段:模型处理全部的Prompt,进行前向计算,直到生成第一个输出token;这个阶段的关键指标是TTFT(Time To First Token),即生成第一个token所需的时间;
  2. decode阶段:一旦prefill完成,模型进入decode阶段,逐个生成剩余token;这个阶段的关键指标是TPOT(Time Per Output Token),即生成每个响应token所需的平均时间。

0.1 自回归模型

前面提到的decode阶段,就是自回归模型的典型特征:模型上一步的输出会被当作是下一步的输入。而transformer是一种神经网络架构,是实现自回归语言模型当前最强大、最主流的架构。transformer模型的整体架构如下:

如上所示,Scaled Dot-Product Attention(缩放点积注意力)是Transformer模型中最核心、最基本的运算单元,其基本的公式如下:

Attention(Q,K,V)=softmax((QKT)/sqrt(dk))VAttention(Q, K, V) = softmax( (Q K^T) / sqrt(d_k) ) V

其中(具体含义咱们在这就不深究了):

  • Q查询向量,代表当前需要“寻找什么信息”。
  • K键向量,代表历史信息“有什么内容的标签”。
  • V值向量,代表历史信息“具体的内容是什么”。

0.2 没有KV Cache

上图所示,对应token的Attention计算如下:

Att1(Q,K,V)=softmax((Q1K1T)/sqrt(dk))V1Att_1(Q, K, V) = softmax( (Q_1K_1^T) / sqrt(d_k) )V_1
Att2(Q,K,V)=softmax((Q2K1T)/sqrt(dk))V1+softmax((Q2K2T)/sqrt(dk))V2Att_2(Q, K, V) = softmax( (Q_2K_1^T) / sqrt(d_k) )V_1 + softmax( (Q_2K_2^T) / sqrt(d_k) )V_2
Attn(Q,K,V)=softmax((QnK1T)/sqrt(dk))V1+softmax((QnK2T)/sqrt(dk))V2+...+softmax((QnKnT)/sqrt(dk))VnAtt_n(Q, K, V) = softmax( (Q_nK_1^T) / sqrt(d_k) )V_1 + softmax( (Q_nK_2^T) / sqrt(d_k) )V_2 + ... + softmax( (Q_nK_n^T) / sqrt(d_k) )V_n

为了严格保证模型的“因果性”或“自回归属性”,Decoder有Causal Mask,在推理的时候前面已经生成的字符不需要与后面的字符产生attention,所以QKTQK^T矩阵的右上半部会被mask,即推导TokenkToken_k时不应该受到Kk+1,Kk+2,...,Kk+nK_k+1,K_k+2,...,K_k+n对其的影响。

从而观察以上计算的时候,会发现每次会有大量冗余的计算,即原本计算的K和V都被重新计算了一次,而每一步的KV都是确定的,因此完全可以被缓存起来。而每次计算的AttkAtt_k除了和之前的KV相关,只和QkQ_k相关,因此Q没有必要缓存,每次只需计算当前Q即可。

0.3 引入KV Cache

下面的Step1-4即为“Transformers KV Caching Explained”中提供的不使用KV Cache和使用KV Cache的对比:

下面借助“Transformers KV Caching Explained”的例子测试一下:

import numpy as np
import time
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

path = os.path.expanduser(
        "~/.cache/huggingface/hub/models--Qwen--Qwen3-0.6B/snapshots/c1899de289a04d12100db370d81485cdf75e47ca")
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(path)
model = AutoModelForCausalLM.from_pretrained(path).to(device)

for use_cache in (True, False):
  times = []
  for _ in range(10):  # measuring 10 generations
    start = time.time()
    model.generate(**tokenizer("What is KV caching?", return_tensors="pt").to(device), use_cache=use_cache, max_new_tokens=2000)
    times.append(time.time() - start)
  print(f"{'TPOT with' if use_cache else 'TPOT without'} KV caching: {round(np.mean(times), 3)} +- {round(np.std(times), 3)} seconds")



for use_cache in (True, False):
  times = []
  for _ in range(10):  # measuring 10 generations
    start = time.time()
    model.generate(**tokenizer("What is KV caching?", return_tensors="pt").to(device), use_cache=use_cache, max_new_tokens=1)
    times.append(time.time() - start)
  print(f"{'TTFT with' if use_cache else 'TTFT without'} KV caching: {round(np.mean(times), 3)} +- {round(np.std(times), 3)} seconds")

结果显示,KV Cache对于TPOT的提升较大,这很好理解,因为其发挥作用就是在自回归解码阶段;而对于预填充阶段的TTFT耗时几乎没有影响。

TPOT with KV caching: 19.32 +- 0.058 seconds
TPOT without KV caching: 123.952 +- 0.363 seconds
TTFT with KV caching: 0.012 +- 0.0 seconds
TTFT without KV caching: 0.011 +- 0.0 seconds

0.3.1 KV Cache带来的问题

KV Cache 本质上是拿空间换时间的操作,因此不可避免带来一些问题:

  1. 显存占用更大:显存增长与序列长度成正比,在资源受限的设备上(比如消费级显卡),显存可能首先被KV Cache耗尽,从而严重限制了批量大小(batch size)和可生产的最大序列长度。
  2. 显存带宽瓶颈:虽然减少了计算量,但是压力转移到显存访问中,推理速度的瓶颈从计算瓶颈转移到内存带宽瓶颈;且频繁地分配和释放缓存,还可能导致显存碎片化。

1. Paged Attention

1.1 KV Cache的内存碎片

早期的大模型推理系统,采用的也是连续的显存分配系统,即将同一个请求的KV Cache存储在一个连续空间里,预分配的机制会在请求开始时,就按最大可能生成的长度(如 2048 token)为每个请求分配一整块连续内存空间,这种连续分配机制存在着严重的显存碎片。结合PagedAttention 的论文《Efficient Memory Management for Large Language Model Serving with PagedAttention》,可以看到,传统的推理系统(以Orca为例)的显存有效占比极低,除绿色部分外皆为浪费;而vLLM引入的PagedAttention机制则使得其KV Cache的使用率达到了96.3%。

1.2 PagedAttention机制

类似于操作系统虚拟内存,vLLM提出了PagedAttention,具体来说,PagedAttention将每个生成序列的KV Cache划分为多个block,每个block中包含固定数量的key和value向量。和虚拟内存的映射表一样,PagedAttention也存在一个映射表,在Attention计算要用到KV Cache的时候,通过Block Table找到这个序列对应的block,进而从block中取出对应的KV向量。其基本图示如下:

1.3 引用计数

论文《Efficient Memory Management for Large Language Model Serving with PagedAttention》中介绍了vLLM支持多样化的推理策略,并在这些推理策略中表现优异,比如并行采样(Parallel Sampling)束搜索(Beam Search) 以及 共享前缀(Shared Prefix) ,其实他们都通过给内存块(Block)添加引用计数,来实现KV Cache的复用,从而显著节省显存、提高吞吐量。其实在vLLM的v1中,束搜索(Beam Search已经被移出核心,并行采样(Parallel Sampling)也并不会发生写时复制,这里就不详述了。

2. nano-vllm的实现

2.1 存储相关模块关系

graph TB
    subgraph "物理层 - BlockManager"
        BM["BlockManager<br/>nanovllm/engine/block_manager.py"]
        BLOCKS["物理Block数组<br/>Block[0..N-1]"]
        FREE["空闲队列<br/>free_block_ids"]
        USED["使用集合<br/>used_block_ids"]
        HASH["哈希映射<br/>hash_to_block_id"]
        
        BM --> BLOCKS
        BM --> FREE
        BM --> USED
        BM --> HASH
    end
    
    subgraph "存储层 - ModelRunner"
        MR["ModelRunner<br/>nanovllm/engine/model_runner.py"]
        KVCACHE["KV Cache张量<br/>[2, layers, blocks, block_size, heads, dim]"]
        
        MR --> KVCACHE
    end
    
    subgraph "逻辑层 - Sequence"
        SEQ["Sequence<br/>nanovllm/engine/sequence.py"]
        BLOCKTABLE["逻辑块表<br/>block_table: list[int]"]
        TOKENS["Token序列<br/>token_ids: list[int]"]
        
        SEQ --> BLOCKTABLE
        SEQ --> TOKENS
    end
    
    subgraph "计算层 - Attention"
        ATT["Attention<br/>nanovllm/layers/attention.py"]
        KCACHE["k_cache视图"]
        VCACHE["v_cache视图"]
        
        ATT --> KCACHE
        ATT --> VCACHE
    end
    
    BLOCKTABLE -.->|映射到| BLOCKS
    BLOCKS -.->|数据存储在| KVCACHE
    KVCACHE -.->|层视图| KCACHE
    KVCACHE -.->|层视图| VCACHE
    
    style BLOCKS fill:#e1f5ff
    style BLOCKTABLE fill:#fff4e1
    style KVCACHE fill:#ffe1e1

存储相关的模块和作用如下:

  1. 物理BlockBlockManager管理,包含元数据但不存储实际的KV数据
  2. 逻辑Block体现在Sequence.block_table中,是物理block_id的有序列表
  3. 真实存储ModelRunner.kv_cache张量中,按物理block_id索引
  4. Attention层通过k_cachev_cache视图访问对应层的KV数据
  5. slot_mapping提供token到物理存储位置的精确映射,实现PagedAttention

这种设计实现了逻辑与物理的清晰分离:Sequence关注逻辑序列结构,BlockManager管理物理资源分配,ModelRunner提供实际存储,Attention层专注于计算。

2.2 初始化阶段

graph TB
    subgraph "初始化阶段"
        A[ModelRunner.__init__] --> B[allocate_kv_cache]
        B --> C[计算GPU可用内存]
        C --> D[分配KV Cache张量]
        D --> E[创建BlockManager]
    end

在nano-vllm中,采用预分配的方式一次性申请内存,其申请逻辑如下:

def allocate_kv_cache(self):
    config = self.config
    hf_config = config.hf_config
    free, total = torch.cuda.mem_get_info()
    used = total - free
    peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
    current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
    num_kv_heads = hf_config.num_key_value_heads // self.world_size
    head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
    block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize
    config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
    assert config.num_kvcache_blocks > 0
    self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim)
    layer_id = 0
    for module in self.model.modules():
        if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
            module.k_cache = self.kv_cache[0, layer_id]
            module.v_cache = self.kv_cache[1, layer_id]
            layer_id += 1

以上代码要想理解,得先看一下我们两个参数配置,一个是Config模块中的配置:

class Config:
    # ...
    tensor_parallel_size: int = 1
    gpu_memory_utilization: float = 0.9
    kvcache_block_size: int = 256
    num_kvcache_blocks: int = -1
    # ...

我们只摘取与我们相关的参数,其中:

  • tensor_parallel_size:张量并行大小,1-8之间,指定使用的GPU数量进行模型并行
  • gpu_memory_utilization:GPU内存利用率,0.0-1.0之间,指定KV Cache可使用的GPU内存比例;
  • kvcache_block_size:KV Cache块大小,每个块包含的token数量(必须能被256整除);
  • num_kvcache_blocks:KV Cache总块数,根据GPU内存动态计算(-1表示未初始化);

另一个就是模型的config.json文件,在我们例子中模型文件夹下~/.cache/huggingface/hub/models--Qwen--Qwen3-0.6B/snapshots/c1899de289a04d12100db370d81485cdf75e47ca下:

{                                                                                                                                                                                                                                                                                  
  "architectures": [
    "Qwen3ForCausalLM"
  ],  
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151645,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 40960,
  "max_window_layers": 28, 
  "model_type": "qwen3",
  "num_attention_heads": 16, 
  "num_hidden_layers": 28, 
  "num_key_value_heads": 8,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 1000000,
  "sliding_window": null,
  "tie_word_embeddings": true,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.51.0",
  "use_cache": true,
  "use_sliding_window": false,
  "vocab_size": 151936
}

所以,内存计算的详解如下:

1. GPU内存状态获取

free, total = torch.cuda.mem_get_info()           # 获取空闲和总内存  
used = total - free                               # 计算已使用内存  
peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]    # 历史峰值  
current = torch.cuda.memory_stats()["allocated_bytes.all.current"] # 当前分配

系统通过torch.cuda.mem_get_info()获取GPU内存的实时状态,并结合内存统计信息计算可用于KV Cache的内存空间。

2. 单块内存计算

num_kv_heads = hf_config.num_key_value_heads // self.world_size  
head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)  
block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize

关键计算公式

  • 2:Key和Value两个张量
  • hf_config.num_hidden_layers:模型层数,这里是28
  • self.block_size:每块token数(默认256)
  • num_kv_heads:KV头数(考虑tensor并行),因为我只有单卡,所以这里是 8/1=8
  • head_dim:头维度,如上是128
  • hf_config.torch_dtype.itemsize:数据类型字节数,因为torch_type是bfloat16,所以应该是两个字节;

所以block_bytes计算出来是 blockbytes=2×28×256×8×128×2=28×1024×1024block_bytes = 2 \times 28 \times 256 \times 8 \times 128 \times 2 = 28 \times 1024 \times 1024,也就是28M字节。

3. 可分配块数计算

config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes

系统根据配置的gpu_memory_utilization(默认0.9)计算可用内存,并除以单块大小得到可分配的块数。这里的公式有点难以理解:

  • total * gpu_memory_utilization:预留90%的显存给本进程;(本机total大概是23.48G,乘以0.9就是21.13G左右)
  • peak:从模型warmup开始计算本进程申请的显存峰值(断点调试时,peak大概是1.58G);
  • used:当前整个机器上所有的显存占用,包含其他服务占用的显存+本进程现阶段占用的显存(断点调试时,used大概是3.69G);
  • current:本进程现阶段占用的显存(断点调试时,current大概是1.14G);

total * config.gpu_memory_utilization - used这很好理解,但是为什么要再减去peak-current呢?其实就是预留peak-current的显存给框架,害怕框架后面有什么操作还需要达到这个峰值大小。根据以上的调试数据,最后可用的显存在17G = 17408M,所以对于28M字节的block块大小,可以一共有621块block,即config.num_kvcache_blocks = 621

4.KV Cache张量结构

6维张量分配

self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks,   
                           self.block_size, num_kv_heads, head_dim)

根据前面的关键计算公式那里,kv_cache有6个维度:

  • 2:Key和Value两个分离的张量
  • hf_config.num_hidden_layers:模型层数
  • config.num_kvcache_blocks:物理块数量
  • self.block_size:每块包含的token数(默认256)
  • num_kv_heads:KV注意力头数
  • head_dim:每个注意力头的维度

注意这里的公式比上面计算字节数少了一个hf_config.torch_dtype.itemsize,这是因为框架内部会根据dtype来确认这个大小,所以不用担心此时的torch.empty操作申请的显存不对。

以下是我打断点调试时的显存申请视图,可以发现,在进行torch.empty操作时,大约申请了17GB的显存。

2.3 申请流程

graph TB
    subgraph "请求处理"
        F[用户请求] --> G[创建Sequence]
        G --> H[Scheduler.add]
        H --> I[Scheduler.schedule]
    end
    
    subgraph "内存分配"
        I --> J{Prefill阶段?}
        J -->|是| K[BlockManager.allocate]
        J -->|否| L[BlockManager.may_append]
        K --> M[分配物理块]
        L --> N[扩展序列块]
    end
    
    subgraph "注意力计算"
        M --> O[ModelRunner.prepare_prefill]
        N --> P[ModelRunner.prepare_decode]
        O --> Q[设置slot_mapping]
        P --> Q
        Q --> R[Attention.forward]
        R --> S[FlashAttention计算]
        S --> T[存储KV到cache]
    end
    
    subgraph "块管理"
        T --> U[更新Block引用计数]
        U --> V[检查序列完成]
        V -->|完成| W[BlockManager.deallocate]
        V -->|继续| X[下一轮decode]
    end

在推理过程中,请求到来后有关KV Cache的请求过程基本如上所示。下面我们将分步拆解一下每层重点的操作。

2.3.1 prefill

nano-vllm的调度器不支持Chunked Prefill的调度模式,其调度器比较简单,首先处理prefill阶段的请求,然后处理decode阶段的请求,从代码可以看出来,在decode阶段来了一个其他请求,调度器也会优先处理新到的请求。

1.Scheduler的prefill序列
def schedule(self) -> tuple[list[Sequence], bool]:
    # prefill
    scheduled_seqs = []
    num_seqs = 0
    num_batched_tokens = 0
    while self.waiting and num_seqs < self.max_num_seqs:
        seq = self.waiting[0]
        if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.block_manager.can_allocate(seq):
            break
        num_seqs += 1
        self.block_manager.allocate(seq)
        num_batched_tokens += len(seq) - seq.num_cached_tokens
        seq.status = SequenceStatus.RUNNING
        self.waiting.popleft()
        self.running.append(seq)
        scheduled_seqs.append(seq)
    if scheduled_seqs:
        return scheduled_seqs, True

除了做了一些状态判断、状态设置,waiting队列的出队,running队列的入队,schedule主要做了进行了self.block_manager.allocate(seq)操作。

2.BlockManager 显存申请
def allocate(self, seq: Sequence):
    assert not seq.block_table
    h = -1
    cache_miss = False
    for i in range(seq.num_blocks):
        token_ids = seq.block(i)
        h = self.compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
        block_id = self.hash_to_block_id.get(h, -1)
        if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
            cache_miss = True
        if cache_miss:
            block_id = self.free_block_ids[0]
            block = self._allocate_block(block_id)
        else:
            seq.num_cached_tokens += self.block_size
            if block_id in self.used_block_ids:
                block = self.blocks[block_id]
                block.ref_count += 1
            else:
                block = self._allocate_block(block_id)
        if h != -1:
            block.update(h, token_ids)
            self.hash_to_block_id[h] = block_id
        seq.block_table.append(block_id)

可以看到,在这里会通过hash计算和token_ids的匹配,看看是否有相同的tokens,如果有的话会引用计数+1。然后通过seq.block_table映射此时的物理块id。

然后,在step函数中会继续调用ModelRunner.run,然后在prepare_prefill阶段,slot通过简单的线性计算与block_id建立对应关系。

ModelRunner 数据准备

核心映射公式

ModelRunner.prepare_prefill()中,slot_mapping的计算逻辑如下:

for i in range(seq.num_cached_blocks, seq.num_blocks):  
    start = seq.block_table[i] * self.block_size  
    if i != seq.num_blocks - 1:  
        end = start + self.block_size  
    else:  
        end = start + seq.last_block_num_tokens   
    slot_mapping.extend(list(range(start, end)))

基础计算公式:

slot = block_id * block_size + token_offset_in_block  
  • block_id: 来自seq.block_table[i],是物理块标识符
  • block_size: 默认256,每个块包含的token数量
  • token_offset_in_block: 块内token的偏移量(0到block_size-1)

实际示例,假设:

  • block_size = 256
  • seq.block_table = [5, 12, 8]
  • 序列长度为300个token

计算过程:

Block 0 (block_id=5): slots 5*2565*256+255 = 1280-1535  
Block 1 (block_id=12): slots 12*25612*256+43 = 3072-3115 (最后一块只有44个token)  
关键数据结构 Sequence.block_table

存储逻辑位置到物理block_id的映射:

self.block_table = []  # [5, 12, 8, ...]
Context.slot_mapping

存储每个token的物理存储slot位置:

slot_mapping: torch.Tensor | None = None  # [1280, 1281, ..., 3115]
Attention 存储访问机制

在Attention层,通过slot_mapping直接索引KV Cache张量:

def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):  
    context = get_context()  
    if k_cache.numel() and v_cache.numel():  
        store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)

Triton内核使用slot作为索引直接访问物理存储位置。

2.3.2 decode

1. Scheduler调度Decode序列

Scheduler.schedule()的decode阶段,系统处理运行队列中的序列:

# decode  
while self.running and num_seqs < self.max_num_seqs:  
    seq = self.running.popleft()  
    while not self.block_manager.can_append(seq):  
        if self.running:  
            self.preempt(self.running.pop())  
        else:  
            self.preempt(seq)  
            break  
    else:  
        num_seqs += 1  
        self.block_manager.may_append(seq)  
        scheduled_seqs.append(seq)

调度逻辑

  • 从运行队列取出序列进行decode
  • 检查是否可以扩展序列,必要时抢占其他序列
  • 调用BlockManager.may_append()扩展KV Cache
2. BlockManager动态扩展

BlockManager.may_append()处理序列的KV Cache扩展:

def may_append(self, seq: Sequence):  
    block_table = seq.block_table  
    last_block = self.blocks[block_table[-1]]  
    if len(seq) % self.block_size == 1:  
        assert last_block.hash != -1  
        block_id = self.free_block_ids[0]  
        self._allocate_block(block_id)  
        block_table.append(block_id)  
    elif len(seq) % self.block_size == 0:  
        assert last_block.hash == -1  
        token_ids = seq.block(seq.num_blocks-1)  
        prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1  
        h = self.compute_hash(token_ids, prefix)  
        last_block.update(h, token_ids)  
        self.hash_to_block_id[h] = last_block.block_id  
    else:  
        assert last_block.hash == -1

扩展策略

  • 跨块边界:当序列长度跨块边界时分配新物理块
  • 块完成:当块填满时计算哈希并加入prefix cache
  • 块内扩展:在同一块内继续填充,无需额外操作
3. ModelRunner数据准备

ModelRunner.prepare_decode()为decode阶段准备执行数据:

def prepare_decode(self, seqs: list[Sequence]):  
    input_ids = []  
    positions = []  
    slot_mapping = []  
    context_lens = []  
    for seq in seqs:  
        input_ids.append(seq.last_token)  
        positions.append(len(seq) - 1)  
        context_lens.append(len(seq))  
        slot_mapping.append(seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1)  
    input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)  
    positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)  
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)  
    context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)  
    block_tables = self.prepare_block_tables(seqs)  
    set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)  
    return input_ids, positions

关键准备

  • 新token位置:计算新token的slot位置seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1
  • 序列长度:记录每个序列的当前长度用于attention计算
  • 块表准备:为FlashAttention准备块表张量
4. Context设置Decode模式

set_context()设置decode阶段的执行上下文:

set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)

Decode上下文特点

  • is_prefill = False:标识decode阶段
  • context_lens:每个序列的长度,用于FlashAttention
  • slot_mapping:新token的存储位置
  • block_tables:完整的块表映射
5. Attention层Decode计算

Attention.forward()在decode阶段使用专门的FlashAttention接口:

else:    # decode  
    o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,  
                                cache_seqlens=context.context_lens, block_table=context.block_tables,   
                                softmax_scale=self.scale, causal=True)

Decode优化

  • KV Cache复用:直接从cache读取历史KV数据
  • 高效attention:使用flash_attn_with_kvcache优化单token生成
  • 块表索引:通过block_tablecache_seqlens高效访问KV数据