随着模型变得越来越复杂和庞大,推理速度优化变得尤为重要,尤其是对于用户期望即时回复的聊天应用。键值 (Key-Value, KV) 缓存用于提升transformer架构模型的响应速度,本文将深入了解它的工作原理及其应用场景。
Transformer 架构概述
在深入探讨 KV 缓存之前,我们需要先简要了解 Transformer 中的注意力机制。从而更方便的讲解 KV 缓存如何优化 Transformer 的推理过程。
本文主要关注用于文本生成的自回归模型(Autoregressive Models)。比如 GPT 系列、Gemini、Claude 等。它们的训练任务非常简单:预测序列中的下一个 token。在推理时,模型接收部分文本,任务是预测文本的后续内容。
从整体上看,大多数 Transformer 由以下几个基本模块组成:
- 分词器 (Tokenizer) :将输入文本拆分为更小的部分,如单词或子词。这一步骤将自然语言转化为模型可以处理的格式。
- 嵌入层 (Embedding Layer) :将分词后的 token 及其在文本中的位置转换为向量。这些向量包含了词汇的语义信息。
- 基本神经网络层:包括 dropout、层归一化(Layer Normalization)和前馈线性层(Feed-Forward Linear Layers)。
- 自注意力模块 (Self-Attention Module) :这是 Transformer 架构中最核心的部分,使模型能够在生成输出时关注输入序列的不同部分,从而有效建模长距离依赖关系。
基本自注意力模块
自注意力机制允许模型在生成下一个 token 时,重点关注输入序列的特定部分。例如,在生成句子“她将咖啡倒入杯中”时,模型可能会更关注“倒”和“咖啡”这两个词,以预测下一个词“入”,因为这些词提供了上下文信息。
从数学角度来看,自注意力的目标是将每个输入 token 转换为一个上下文向量,该向量综合了文本中所有输入的信息。以“她倒咖啡”为例,注意力机制将为每个词生成一个上下文向量。
计算上下文向量时,自注意力会生成三种中间向量:查询 (Query) 、键 (Key) 和 值 (Value) 。以下是如何计算第二个词“poured”的上下文向量的步骤:
- 线性变换:每个输入 token 分别与三个权重矩阵 、 和 相乘,生成查询、键和值向量。这些权重矩阵是在训练过程中学习到的参数,用于捕捉不同的语义关系。
- 计算注意力分数:将查询向量与所有键向量相乘,得到注意力分数。这些分数表示当前查询与所有键的匹配程度。
- 归一化权重:通过 softmax 函数将注意力分数归一化,得到注意力权重。这些权重用于衡量各个值向量的重要性。
- 生成上下文向量:将注意力权重与值向量相乘并求和,得到最终的上下文向量。这一向量包含了与当前 token 相关的所有上下文信息。
以下是 Sebastian Raschka 在其书籍《从零开始构建大语言模型》中提供的基本自注意力模块的代码实现示例,帮助理解其具体操作:
import torch
class SelfAttention_v2(torch.nn.Module):
def __init__(self, d_in, d_out, qkv_bias=False):
super().__init__()
self.W_query = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
def forward(self, x):
keys = self.W_key(x)
queries = self.W_query(x)
values = self.W_value(x)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(
attn_scores / keys.shape[-1]**0.5, dim=-1
)
context_vec = attn_weights @ values
return context_vec
Sebastian 的代码在矩阵上操作:他的 forward() 方法中的 x 是三个堆叠在一起的的 x1、x2 和 x3 向量,构成了一个具有三行的矩阵。这允许他简单地将 x 与 W_key 相乘以获得key矩阵,key矩阵是一个由三行组成的矩阵(在例子中是 k1、k2 和 k3)。
这里需要注意的是,在每个前向传递中,我们将键与查询相乘,然后再与值相乘。继续阅读时请记住这一点。
高级自注意力模块
上述描述的是自注意力的基础形式。当前最大的 LLM 通常使用稍加修改的自注意力机制,主要有以下三个方面的改进:
- 因果注意力 (Causal Attention) :模型在预测下一个 token 时,只考虑之前的 token,避免“提前”看到未来的词。这通过在注意力权重矩阵中掩盖未来的信息来实现。例如,在生成句子“她倒咖啡”时,当模型接收到单词“她”并试图预测下一个单词“倒”时,它不应访问“咖啡”与任何其他单词之间的注意力权重,因为“咖啡”这个词尚未出现在文本中。因果注意力通常通过将注意力权重矩阵的“前视”部分设置为零来实现。
- 多头注意力 (Multi-Head Attention) :基础注意力可以被称为单头注意力,意味着只有一组 Wk、Wq 和 Wv 矩阵。增加模型容量的一个简单方法是切换到多头注意力。这归结为拥有多组 W 矩阵,因此,每个输入有多个查询、键和值矩阵,以及多个上下文向量。多头注意力使模型能够从不同的子空间捕捉信息,提高模型的表现力。
此外,一些 Transformer 实现了注意力模块的额外优化,旨在提高速度或准确性。三个流行的优化方法是:
- 分组查询注意力 (Grouped-Query Attention) :不再单独查看每个输入 token,而是将 token 分组,使模型能够一次关注一组相关的词,从而加快处理速度。Llama 3、Mixtral 和 Gemini 使用了这种方法。
- 分页注意力 (Paged Attention) :将注意力分解为“页”或 token 块,因此模型一次处理一页,使其在处理非常长的序列时更快。
- 滑动窗口注意力 (Sliding-Window Attention) :模型仅关注每个 token 周围固定“窗口”内的附近 token,因此它专注于局部上下文,无需查看整个序列。
所有这些最先进的自注意力实现方法都没有改变其基本前提和所依赖的基本机制:始终需要将键乘以查询,然后再乘以值。然而在推理时,这些重复的乘法计算会带来显著的效率低下。
什么是键值 (KV) 缓存?
在推理过程中,Transformer 模型一次生成一个 token。当我们通过传递“ she ”来提示模型开始生成时,它会产生一个词,例如“ poured ”(为了避免干扰,我们继续假设一个 token 是一个词)。然后,我们可以将“She poured”传递给模型,它会生成“coffee”。接下来,我们传递“She poured coffee”,并从模型中获取序列结束 token,表明它认为生成已完成。
这意味着我们已经运行了三次前向传递,每次都将查询与键相乘以获得注意力分数。
在第一次前向传递中,只有一个输入 token(“ she ”),因此只有一个键向量和一个查询向量。我们将它们相乘以获得 q1k1 注意力分数。
接下来,我们将“She poured”传递给模型。它现在看到两个输入 token,因此注意力模块内的计算如下:
我们进行了乘法计算以计算三个项,但 q1k1 其实不需要再计算了!这个 q1k1 元素与之前的前向传递中的相同,因为:
q1是通过将输入(“She”)的嵌入乘以Wq矩阵计算得出的,k1是通过将输入(“She”)的嵌入乘以Wk矩阵计算得出的,- 在推理时,嵌入和权重矩阵都是恒定的,不会改变。
注意注意力分数矩阵中灰色的条目:这些被用零掩盖以实现因果注意力。例如,q1k3 的右上元素在生成第二个词时不会显示给模型,因为我们尚不知道第三个词和 k3。
最后,这是我们第三次前向传递中查询乘以键的计算示意图。
在图中的六个值中,有一半都是已经计算过而不需要再计算的
在推理时,当我们计算键 (K) 和值 (V) 矩阵时,我们将它们的元素存储在缓存中。缓存是一种辅助内存,可以高速检索。随着后续 token 的生成,我们只需计算新 token 的键和值,从而避免重复计算,从而节省时间。
例如,第三次前向传递在使用缓存时将如下所示:
在处理第三个 token 时,我们不需要重新计算前两个 token 的注意力分数。我们可以从缓存中检索前两个 token 的键和值,从而节省计算时间。
评估键值缓存的影响
键值缓存可能对推理时间产生显著影响。这种影响的大小取决于模型的架构。可缓存的计算越多,减少推理时间的潜力就越大。
让我们使用 Hugging Face Hub 上的 EleutherAI 的 GPT-Neo-1.3B 模型来分析 KV 缓存对生成时间的影响。代码对应环境如下:
Google Colab 上的 T4 GPU 执行了此代码,torch==2.5.1+cu121 、 transformers==4.46.2 、Python 3.10.12。
我们将首先定义一个计时器上下文管理器来计算生成时间:
import time
class Timer:
def __enter__(self):
self._start = time.time()
return self
def __exit__(self, exc_type, exc_value, traceback):
self._end = time.time()
self.duration = self._end - self._start
def get_duration(self) -> float:
return self.duration
接下来,我们从 Hugging Face Hub 加载模型,设置分词器,并定义提示:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
model_name = "EleutherAI/gpt-neo-1.3B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
input_text = "Why is a pour-over the only acceptable way to drink coffee?"
最后,我们可以定义运行模型推理的函数:
def generate(use_cache):
input_ids = tokenizer.encode(
input_text,
return_tensors="pt"
).to(device)
output_ids = model.generate(
input_ids,
max_new_tokens=100,
use_cache=use_cache,
)
注意我们传递给 model.generate 的 use_cache 参数:它控制是否使用 KV 缓存。
通过这种设置,我们可以测量有无 KV 缓存的平均生成时间:
import numpy as np
for use_cache in (False, True):
gen_times = []
for _ in range(10):
with Timer() as t:
generate(use_cache=use_cache)
gen_times += [t.duration]
print(
f"Average inference time with use_cache={use_cache}: ",
f"{np.round(np.mean(gen_times), 2)} seconds",
)
最终结果为:
Average inference time with use_cache=False: 9.28 seconds
Average inference time with use_cache=True: 3.19 seconds
如上,缓存带来的加速几乎是三倍。
kv缓存的难点
尽管 KV 缓存能显著提升生成速度,但也带来了内存使用的增加,需要在生产系统中仔细管理。
如果缓存消耗的内存成为问题,可以通过牺牲一些模型准确性来换取额外的内存:
- 序列截断 (Sequence Truncation) :限制最大输入序列长度,从而在牺牲长期上下文的情况下限制缓存大小。在长期上下文相关的任务中,模型的准确性可能会受到影响。
- 减少注意力头的数量 (Pruning Attention Heads) :通过减少模型层数或注意力头的数量,可以降低模型的复杂性和缓存内存需求,从而回收一些内存。然而,这可能会影响模型的表现和准确性。
- 量化 (Quantization) :使用较低精度的数据类型(例如 float16 而不是 float32)进行缓存,以减少内存使用。同样,这可能会对模型的准确性产生一定影响。
生产中的 KV 缓存管理
在具有大量用户的大规模生产系统中,最需要注意的两个问题是缓存失效(何时清除缓存) 和 缓存重用(如何多次使用相同的缓存)。
缓存失效
最常用的三种的缓存失效策略是基于会话的清除、基于生存时间 (TTL) 的失效和基于上下文相关性的策略。
-
基于会话的清除 (Session-Based Clearing) :在用户会话或与模型的对话结束时清除缓存。适用于对话简短且彼此独立的应用。
例如,想象一个客户支持聊天机器人应用,每个用户会话通常代表一个独立的对话,用户在其中寻求特定问题的帮助。在这种情况下,这些缓存内容不太可能再次需要。此时就可以在用户结束聊天或会话因不活动而超时后清除 KV 缓存。
-
基于生存时间 (TTL) 的失效 (Time-To-Live Invalidation) :在一定时间后自动清除缓存内容。适用于当缓存数据的相关性随着时间的推移可预测地减弱时。
例如,考虑一个提供实时更新的新闻聚合应用。缓存的键和值可能只在新闻热门期间相关。实施一个 TTL 策略,例如缓存条目在一天后过期,可以确保对新发展的类似查询快速生成响应,同时旧新闻不会占用内存。
-
基于上下文相关性的失效 (Contextual Relevance-Based Invalidation) :一旦缓存内容对当前上下文或用户交互变得不相关,就会立即清除。适用于先前的上下文对新上下文没有贡献价值。
例如,想象一个作为 IDE 插件工作的编码助手。当用户正在处理一组特定的文件时,缓存应保留。然而,一旦他们切换到不同的代码库,之前的键和值变得不相关,可以删除以释放内存。然而,基于上下文相关性的策略需要准确定位上下文切换发生的事件或时间点。
缓存重用
缓存管理的另一个重要方面是其重用。在某些情况下,曾经生成的缓存可以再次使用,通过避免在不同用户的缓存实例中存储相同的数据,从而加快生成速度并节省内存。
缓存重用的机会通常出现在存在共享上下文和/或需要预热启动的情况下。
-
共享上下文 (Shared Context) :在多个请求共享共同上下文的场景中,可以重用该共享部分的缓存。
例如,在电子商务平台上,某些产品可能有标准描述或规格,多个客户经常询问。这些可能包括产品详情(“55 英寸 4K 超高清智能 LED 电视”)、保修信息(“附带 2 年制造商保修,涵盖零件和劳务。”)或客户说明(“为了获得最佳效果,请使用兼容的墙壁支架安装电视,另售。”)。通过缓存这些共享产品描述的键值对,客户支持聊天机器人将更快地生成对常见问题的回答。
-
预计算缓存 (Precomputed Cache) :可以预计算并缓存初始的 KV 对于频繁使用的提示或查询。
例如,考虑一个语音激活的虚拟助手应用。用户经常以诸如“今天的天气如何?”或“设定一个 10 分钟的计时器。”这样的短语开始互动。助手可以通过预计算并缓存这些常用查询的键值对来更快地响应。
结论
键值 (KV) 缓存是 Transformer 模型中的一种重要技术,通过存储和重用前一步骤生成的键和值向量,减少了冗余计算,显著提升了推理速度。然而,这种加速是需要增加内存使用,在内存有限的系统中,需要通过优化模型或调整缓存策略来平衡。