大模型推理加速-KV Cache

1,598 阅读11分钟

1.LLM推理背景介绍

大型语言模型(llm)已经在广泛的自然语言处理应用程序中表现出了非常显著的泛化性,例如内容创建、内容总结、对话系统等。

现有的LLM模型通常是Transofmer架构的改进。例如自回归LLM为Transformer中的Decoder-only架构。Transorformer Decoder由多个Transformer Block组成,每个Transformer Block包含两个子模块,分别是Multi-Head Attention(MHA)和Feed Forward Neural Network(FFNN),其中MHA模块是Transformer的核心模块,负责对输入序列进行自注意力计算,得到新的序列表示。下面是自回归LLM推理流程。

image.png

自回归LLM,在每个step中,拼接上一次迭代生成的单词(黄色)作为新的输入来预测下一个单词。如上图 step1 针对用户的输入"中国的首" 生成第一个输出 token "都", 第一步对原始 token 的处理,领域内我们称作Prefill预填充阶段,后续自回归的每一步我们称作 Decode解码阶段 ,这两个阶段的特点,我们后续会重点说明,这里不再展开.

另外上图的TransformerBlock模块中最核心的是Attention计算层。在标准Attention计算注意力过程中,所有的输入单词都会与其它单词特征进行向量点积得到相似度计算(或重要度分数),再根据相似度分数,对其它单词特征进行求和规约,得到当前单词新的特征表示。

我们将模型网络结构聚焦到TransformerBlock中的单个Attention计算层,来简化模型推理流程,观察原始自回归LLM推理prefilldecoding 各自的特点,来说明为什么会出现基于 kv cache 的解决方案。下面是简化后的模型推理流程。(备注:对于 Attention 计算不熟悉的同学可以参考文章 Flash Attention v1里面的标准 Attention 说明介绍 )

image.png 在每个step迭代计算中,自回归LLM,除了 step1,其他的 step 只会使用最后一个输入单词 token(由上一次迭代生成)对应的Attention计算结果来进行预测,因此

  • step1 Prefill 需要计算所有token 之间的注意力
  • step2 Decode之后的每一次迭代,理论上仅需计算上一次迭代最后一个 token 完整的Attention,对于上一次最后一个 token 的Attention 计算, 仅仅依赖历史 token 的 K 和 V

为了更直观的说明,下面我们上图,基于展开Attention计算过程说明,具体如下。

image.png 我们观察到最后一个单词的Attention计算结果,只与前向单词序列的K,V向量有关,而K,V向量是由每个单词特征经过线程变换得到的,这里每一个单词特征又依赖于上一层各自的Attention计算结果,即每一个单词都必须参与上一层网络的Attention计算,才能得到当前层最后一个单词的Attention计算结果。

总结,在自回归LLM中,尽管我们只需要最后一个单词的Attention计算结果,但仍需要输入序列中每一个单词都参与所有网络层的Attention计算,迭代重复此过程,将会导致大量前序单词的重复Attention计算和KV线性映射的计算

因此我们可以考虑用空间换时间,将每一次迭代的每一层Attention计算的K,V缓存起来,在后续迭代过程中,直接使用前一次迭代缓存的前序单词的K,V向量进行Attention计算,同时又将当前新的KV向量加入KV Cache中,供下一次迭代计算使用,从而避免大量前序单词的Attention和KV线性映射的重复计算。这种技术就叫KV Cache。

2.KV Cache的计算流程

由于使用了KV Cache,当用户输入Prompt提示时,LLM会根据Prompt提示生成第一个单词,然后根据第一个单词的Attention计算结果,生成第二个单词,以此类推,直到生成用户期望的输出序列。

  • 上述生成第一个单词的过程,因为只负责写KV Cache,所以称为Prefill
  • 上述迭代生成后续单词的过程,主要为读KV Cache,不断生成新的单词,所以称为Decode

在Decode阶段,使用KV Cache的Attention计算流程如下,在进行Attention计算时,直接访问KV Cache中的K,V向量,而无需额外的计算。

image.png

3.KV Cache存在的问题

LLM的部署成本除了广泛研究的模型参数存储和注意层的二次代价瓶颈外,KV Cache的问题的也越来越突出。 KV cache主要存在以下两个问题:

3.1 占用显存过大

例如,输入batchsize为128、序列长度为1024的300亿参数模型需要180GB的KV缓存。下面是KV cache存储占用的分析。

假设输入序列的长度为N1N1 ,输出序列的长度为N2N2 ,transformer层数为L,单Attention head中向量大小为dd,KV Attention head个数为HH,则KV Cache 需要存储 N1+N2N1+N2 个 KV向量,形状为 [B,H,n,d][B, H, n, d], KV序列长度nn峰值为 N1+N2N1+N2 ,以float16来保存KV cache,那么KV cache的峰值显存占用大小为:

2LBH(N1+N2)dbpe2*L*B*H*(N1+N2)*d*bpe(这里第一个 2 表示 Key和Value的cache)。

以GPT3-175B为例,对比KV cache与模型参数占用显存的大小。模型配置如下:

模型名参数量LN1N2HddtypeB参数显存KV Cache峰值显存
GPT3175B965123296128fp1664324GB153GB
LLama270B80204820488128fp1664130GB80GB
  • GPT3 模型占用显存大小为350GB。假设批次大小b=64 ,输入序列长度N1=512 ,输出序列长度N2=32 ,则KV cache 峰值占用显存为 2LBH(N1+N2)dbpe2*L*B*H*(N1+N2)*d*bpe = 164,282,499,072 bytes ≈ 164 ,则batch_size=64时,GPT3-175B的KV Cache峰值显存大约是模型参数显存的0.5倍
  • llama2模型占用显存大小为130GB。假设批次大小b=64 ,输入序列长度N1=2048 ,输出序列长度N2=2048 ,则KV cache 峰值占用显存为 2LBH(N1+N2)dbpe2*L*B*H*(N1+N2)*d*bpe = 85,899,345,920 bytes ≈ 80GB,则batch_size=64时,llama2-70B的KV Cache峰值显存大约是模型参数显存的0.6倍

为了解决上述KV Cache存储占用过大的问题,提出了的KV Cache相关的稀疏化技术

3.2 显存利用率低

现有系统(HuggingFace 默认实现是pytorch的内存分配策略)的KV Cache由于内存碎片化和过度预留(为即将到来的请求预留Cache)而浪费了60%至80%的内存。 当KV Cache多次在显存上分配大量连续内存时,容易导致内存碎片化,从而浪费显存。

KV Cache存在三种类型的内存浪费:reserved、内部碎片和外部碎片。下图是kv cache的内存占用示例。

image.png

  • reserved:KV Cache预留的显存
  • 内部碎片:为KV Cache单请求内部碎片化的显存
  • 外部碎片:为KV Cache与推理请求的Cache之间碎片化的显存

下图统计不同推理系统中的内存reserved和内部碎片以及实际token的KV存储占用情况。

image.png 为了解决上述kv cache显存利用率不高,现有系统的内存管理产生内存浪费的问题,引入了paged attention。其思想本质是连续固定的大内存碎片问题,必定会转向离散分块的池化管理机制

4.KV Cache的优化

4.1 稀疏化

4.1.1 Window Attention

一种直观的方法,被称为窗口注意力机制,在当前时刻token的历史KV序列上保持一个固定大小的滑动窗口。即当前token只与其历史窗口内的token进行Attention计算。

基于窗口的注意力计算,可以将原始的KV Cache的峰值显存下降N/n倍,其中N为最大的序列长度,对于llama2 N为4096,假设n=4,则应用窗口注意力机制,可使峰值显存下降4096/4=1024倍。

image.png

4.1.2 Streaming LLM

使用Streaming LLM的注意力计算,可以使KV Cache峰值显存下降N/(n1+n2),其中N为最大的序列长度,n1为初始token个数,n2为窗口大小。对于llama2 N为4096,假设n1=4,n2=4,则应用Streaming LLM,可使KV Cache峰值显存下降4096/(4+4)=512倍。

基于窗口(Window)的注意力机制,尽管确保了恒定的内存占用和解码速度,但模型性能容易崩溃,甚至只是忽略序列中第一个token的KV,也会导致性能下降。

观察Llama-2-7B的注意力图,发现模型一致地关注初始tokens。移除这些初始tokens的KV将导致注意力分数的分布发生重大变化。

此原因归因于Softmax操作,它要求所有上下文token的注意力分数总和为1。因此,即使当前查询在许多历史token中没有强匹配(强相关),模型仍然需要将这些不需要的注意力值分配到某个地方,以便其汇总为1。初始token作为汇聚的主要token背后的原因是直观的:由于自回归语言建模的性质,初始token对几乎所有后续的token都是可见的,使它们更容易被训练为注意力汇聚。

基于上述发现,StreamingLLM框架提出,只是保持注意力汇聚的主要token的KV(只有4个初始token足够)和滑动窗口的KV,以锚定注意力计算,稳定模型的性能。即StreamingLLM用到的KV Cache=序列初始token的KV+滑动窗口的KV。

image.png

4.1.3 H2O

使用H2O的注意力计算,可以使KV Cache峰值显存下降N/(n1+n2),其中N为最大的序列长度,n1为贪心搜索保留的KV个数, n2为窗口大小。对于llama2 N为4096,假设n1=4,n2=4,则应用H20,可使KV Cache峰值显存下降4096/(4+4)=512倍。

我们发现,在计算注意力得分时,只有少数的词语(称之为重点词,H2)占据了大部分的价值。研究表明,这些重点词与文本数据中经常同时出现的单词表现出很强的相关性,一旦去除这些重点词,模型的性能会显著下降。

基于这一发现,H2O算法采用了一种贪心的KV缓存淘汰策略,它动态地保留了最近的滑动窗口和重点词。H2O对比StreamingLLM而言,从一种的固定的窗口方法+初始词的KV Cache方法,演进成一种基于固定窗口+动态决策重点词的KV Cache方法,模型性能更好。 具体的, H2O算法在每个解码step中,都求解所有历史单词的token的累计注意力分数。然后,根据累计注意力分数,选择累计注意力分数最高的n个token作为重点词,并保留它们。其它的token则被驱逐出KV Cache。

下面是H2O算法在第4和第5个解码step上,基于贪心策略筛选出3个重点词的示意图。

image.png

4.2 paged attention

4.2.1 Paged KV Cache

Paged KV Cache就是将KV Cache分块连续存储,通过块表维护逻辑块和物理块之间的映射关系。这样,显存空间就可以更有效地利用,从而减少内存浪费,提高推理的batchsize,增大LLM推理的吞吐量。使用了Paged KV Cache的Attention计算就叫Paged Attention。

image.png 图中①②③分别表示Prompt单词序列、生成的第一个单词、生成第二个单词。其中prompt序列(①)被拆分成了两个逻辑块,分别存储在第7和第1个物理块中,通过块表去映射。后续生成的单词特征向量(②和③),也根据块表映射到物理块中。

在Attention计算的时候,Block Table会转换成张量,通过索引的方式,得到逻辑块对应的物理块号,例如: block_tables = [7,1],表示当前序列的逻辑块0存储在物理块7中(共连续存储了block_size个token向量),逻辑块1存储在物理块1中(共连续存储了block_size个token向量)。 实际上block_tables是一个二维张量,表示多个序列的block_tables,例如: block_tables = [[7,1],[2,3]][[7,1],[2,3]],表示序列0和序列1的逻辑块和物理块映射关系。

4.2.2 Paged Attention计算流程

Paged Attention计算流程如下:

image.png

ref:

medium.com/@joaolages/…

arxiv.org/abs/2309.06…

jieyibu.net/a/336