transformer中KV需要缓存而Q不需要缓存的看法?

774 阅读2分钟

token的生成过程

此处假设一个词就是一个token

## 给LLM一个生成任务:请介绍以下RAG是什么?
LLM的回答如下:
RAG
RAG是
RAG是检
RAG是检索
RAG是检索增
RAG是检索增强
RAG是检索增强生
RAG是检索增强生成

在下一轮推理迭代过程中,本轮生成的token会拼加上之前生成的tokens,一起作为input输入到模型中生成下一个token,比如"RAG是检" + "索" -> "增"。

从token的生成流程我们可以看出,当前token的生存只依赖于以前的token,到这里可能会有疑惑token与以前的tokens有关也就是与之前的attention的QKV有关那么为什么只cache KV不cache Q呢? 下面我从attention计算角度进步分析KV可以被cache而Q不需要。 在自注意力机制中,查询(Q)、键(K)和值(V)之间的计算是非常关键的,特别是当模型在生成新的 token 时。每个新生成的 token 对应一个新的查询向量 qnq_n,而这个查询向量需要与之前的所有键向量进行运算来计算注意力权重,好在我们只关注语句的上下文的上部分信息即只与前面生成的token有关,decoder中有一处Masked Multi-Head Attention,将矩阵右上三角设置为**-inf(负无穷大)**  ,这样方便softmax的时候置为0,屏蔽未来数据的计算。

image.png

自注意力计算过程:

  1. 查询(Q)生成:每个新的输入 token xnx_n​ 通过与权重矩阵 WQW_Q​ 相乘生成查询向量 qnq_n

    qn=xnWQq_n = x_n W_Q

  2. 计算与所有键 KK 的注意力分数:然后新的查询 qnq_n 会与所有之前生成的键 k1k_1,k2k_2,…,knk_n 进行运算(通常是计算点积)来得到注意力得分。

    score(qn,ki)=qnKiTscore(q_n,k_i)=q_nK_i^T

    其中 kik_i 是第 i 个键向量。

  3. 缓存键向量的必要性:对于每个新的查询向量 qnq_n,我们不需要重新计算之前所有键的向量 k1k_1,k2k_2,…,kn1k_{n-1} ,因为这些键的计算已经完成,并且它们会在生成过程中不断被使用。因此,**缓存这些键向量k1k_1,k2k_2,…,kn1k_{n-1} ** 可以显著减少每次生成新 token 时的计算量。

  4. 当然上述只是可以缓存的一部分原因,此处只是减少了KiK_i的重复计算也就是xiWKx_i W_K计算,同理Value也是一样,这还远远不够,下文直接引用一位博主的分享。 QKTQK^T = [q1,q2,...,qn]KT[q_1,q_2,...,q_n]*K^T

  5. 矩阵乘法的每个元素是行向量与列向量的点积:

具体来说,矩阵 ( QK^T ) 的第 ( (i, j) ) 个元素是查询向量 ( q_i ) 和键向量 ( k_j ) 的点积:

(QKT)ij=qikj=t=1dkqitkjt(QK^T)_{ij} = q_i \cdot k_j = \sum_{t=1}^{d_k} q_{it} \cdot k_{jt}
  1. 展开矩阵乘法:
QKT=[q1k1Tq1k2Tq1kn1Tq1knTq2k1Tq2k2Tq2kn1Tq2knTq3k1Tq3k2Tq3k3Tq3knTqn1kn1Tqn1knTqnk1Tqnk2Tqnk3TqnknT]QK^T = \begin{bmatrix} {q_1 \cdot k_1^T} & \textcolor{red}{q_1 \cdot k_2^T} & \dots & \textcolor{red}{q_1 \cdot k_{n-1}^T} & \textcolor{red}{q_1 \cdot k_n^T} \\ q_2 \cdot k_1^T & \textcolor{red}{q_2 \cdot k_2^T} & \dots & \textcolor{red}{q_2 \cdot k_{n-1}^T} & \textcolor{red}{q_2 \cdot k_n^T} \\ q_3 \cdot k_1^T & q_3 \cdot k_2^T & \textcolor{red}{q_3 \cdot k_3^T} & \dots & \textcolor{red}{q_3 \cdot k_n^T} \\ \vdots & \vdots & \vdots & \textcolor{red}{q_{n-1} \cdot k_{n-1}^T} & \textcolor{red}{q_{n-1} \cdot k_n^T} \\ q_n \cdot k_1^T & q_n \cdot k_2^T & q_n \cdot k_3^T & \dots & q_n \cdot k_n^T \end{bmatrix}

注意:此处的矩阵的最后一行输入的xnx_n通过WQW_Q计算得到qnq_nqnq_n分别与k1Tk_1^Tk22k_2^2 ... knTk_n^T计算,其中1到n-1的kiTk_i^T在上面几轮都出现过,因此可以缓存下来加速推理,避免重复计算kik_i,但矩阵的最后一列中的knTk_n^Tq1q_1q2q_2 ... ,qn1q_{n-1}都有计算关系,那么qq是不是也需要缓存呢?而且从矩阵计算看后续的xnx_n对以前的xix_i注意力分数产生了影响,但在真实的推理中我们并不知道未来的token可能是什么,未来的token xnx_n也不应该影响过去的注意力计算。在每一轮计算中有右上三角红色部分对我们并没有什么好处反而有影响,为了屏蔽这部分影响我们将矩阵右上三角设置为**-inf(负无穷大)**  ,这样方便softmax的时候置为0。这就是为什么K需要缓存而Q不需要缓存的原因,同理V需要缓存也是一样的原因

  1. 查询矩阵 Q 和 键矩阵 K 的点积:
QKT=[q1q2qn][k1Tk2TknT]QK^T = \begin{bmatrix} q_1 \\ q_2 \\ \vdots \\ q_n \end{bmatrix} \begin{bmatrix} k_1^T & k_2^T & \dots & k_n^T \end{bmatrix}
  1. 缩放操作:
QKTdk\frac{QK^T}{\sqrt{d_k}}
  1. Softmax 操作:
Attention Weights=Softmax(QKTdk)\text{Attention Weights} = \text{Softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)
  1. 下面是引用几位博主的图解attention计算图解

image.png

引用自:"为什么KV Cache只需缓存K矩阵和V矩阵,无需缓存Q矩阵?",博主讲解的更易懂,安利!!

image.png

image.png

图解

KV 可以被cache源于token的attention仅依赖于以前的token KV计算, masked屏蔽了未来tokens对注意力计算的影响。

tokens计算的过程

kvCache.gif 引用自: Transformers KV Caching Explained

参考

  1. medium.com/@joaolages/…
  2. blog.csdn.net/wlxsp/artic…
  3. blog.csdn.net/qq_35054222…