[论文翻译]DuoAttention: EFFICIENT LONG-CONTEXT LLM INFERENCE WITH RETRIEVAL AND STRE

216 阅读21分钟

Doc2X | 高效 PDF 转换助手 Doc2X 提供精准的 PDF转Docx、PDF转Markdown、PDF转HTML 功能,支持 多栏识别、表格解析、公式编辑,满足多样化文档处理需求。 Doc2X | Efficient PDF Conversion Assistant Doc2X offers precise PDF to Docx, Markdown, and HTML conversion, supporting multi-column recognition, table parsing, and formula editing for versatile document handling. 👉 了解 Doc2X 的独特功能 | Explore Doc2X Features

原文链接:arxiv.org/pdf/2410.10…

DuoAttention: EFFICIENT LONG-CONTEXT LLM INFERENCE WITH RETRIEVAL AND STREAMING HEADS

DuoAttention: 带有检索和流式头的高效长上下文LLM推理

Guangxuan Xiao 1  {}^{1 * }\; Jiaming Tang 1  {}^{1}\; Jingwei Zuo 2  {}^{2}\; Junxian Guo 1,3{}^{1,3}

肖广轩 1  {}^{1 * }\; 唐嘉明 1  {}^{1}\; 左静伟 2  {}^{2}\; 郭俊贤 1,3{}^{1,3}

Shang Yang 1  {}^{1}\; Haotian Tang 1  {}^{1}\; Yao Fu 4  {}^{4}\; Song Han 1,5{}^{1,5}

杨尚 1  {}^{1}\; 唐浩天 1  {}^{1}\; 付尧 4  {}^{4}\; 韩松 1,5{}^{1,5}

1\frac{1}{} MIT 2{}^{2} Tsinghua University 3{}^{3} SJTU 4{}^{4} University of Edinburgh 5{}^{5} NVIDIA

1\frac{1}{} 麻省理工学院 2{}^{2} 清华大学 3{}^{3} 上海交通大学 4{}^{4} 爱丁堡大学 5{}^{5} 英伟达

github.com/mit-han-lab…

ABSTRACT

摘要

Deploying long-context large language models (LLMs) is essential but poses significant computational and memory challenges. Caching all Key and Value (KV) states across all attention heads consumes substantial memory. Existing KV cache pruning methods either damage the long-context capabilities of LLMs or offer only limited efficiency improvements. In this paper, we identify that only a fraction of attention heads, a.k.a, Retrieval Heads, are critical for processing long contexts and require full attention across all tokens. In contrast, all other heads, which primarily focus on recent tokens and attention sinks-referred to as Streaming Heads-do not require full attention. Based on this insight, we introduce DuoAttention, a framework that only applies a full KV cache to retrieval heads while using a light-weight, constant-length KV cache for streaming heads, which reduces both LLM's decoding and pre-filling memory and latency without compromising its long-context abilities. DuoAttention uses a lightweight, optimization-based algorithm with synthetic data to identify retrieval heads accurately. Our method significantly reduces long-context inference memory by up to 2.55×{2.55} \times for MHA and 1.67×{1.67} \times for GQA models while speeding up decoding by up to 2.18×{2.18} \times and 1.50×{1.50} \times and accelerating pre-filling by up to 1.73×{1.73} \times and 1.63×{1.63} \times for MHA and GQA models, respectively, with minimal accuracy loss compared to full attention. Notably, combined with quantization, DuoAttention enables Llama-3-8B decoding with 3.3 million context length on a single A100 GPU. Code is provided in the link

部署长上下文大语言模型(LLMs)是必要的,但面临着显著的计算和内存挑战。缓存所有注意力头中的键和值(KV)状态会消耗大量内存。现有的KV缓存剪枝方法要么损害LLMs的长上下文能力,要么只能提供有限的效率提升。在本文中,我们发现只有一小部分注意力头,即检索头,对于处理长上下文是关键的,并且需要对所有标记进行全注意力处理。相比之下,其他主要关注最近标记和注意力汇聚的头,即流式头,不需要全注意力。基于这一洞察,我们提出了DuoAttention框架,该框架仅对检索头应用全KV缓存,而对流式头使用轻量级的固定长度KV缓存,从而在不损害其长上下文能力的情况下,减少LLM的解码和预填充内存及延迟。DuoAttention使用一种轻量级的基于优化的算法,通过合成数据准确识别检索头。我们的方法显著减少了长上下文推理内存,MHA模型减少了高达2.55×{2.55} \times,GQA模型减少了高达1.67×{1.67} \times,同时解码速度分别提高了高达2.18×{2.18} \times1.50×{1.50} \times,预填充速度分别提高了高达1.73×{1.73} \times1.63×{1.63} \times,与全注意力相比,精度损失最小。值得注意的是,结合量化技术,DuoAttention使得Llama-3-8B在单个A100 GPU上能够进行330万上下文长度的解码。代码在链接中提供。

1 INTRODUCTION

1 引言

Large language models (LLMs) (Touvron et al. 2023a b; OpenAI, 2023; Black et al., 2022) are at the forefront of the AI revolution, powering advanced applications such as multi-round dialogues (Schulman et al. 2022; Taori et al., 2023; Chiang et al., 2023), long document summarization (Goyal & Durrett, 2020, Zhang et al., 2023a), and tasks involving mixed modalities like visual and video understanding (Liu et al., 2023b; Lin et al., 2023). These applications often require processing extensive numbers of contextual tokens; for instance, summarizing the entire Harry Potter series could involve analyzing approximately one million tokens. The challenge intensifies with visual language models (VLMs),where a single 224×224{224} \times {224} image corresponds to 256 tokens (Liu et al., 2023b), and a three-minute video at 24 FPS generates around 1.1 million tokens.

大型语言模型(LLMs)(Touvron et al. 2023a, b; OpenAI, 2023; Black et al., 2022)处于人工智能革命的前沿,推动了多轮对话(Schulman et al. 2022; Taori et al., 2023; Chiang et al., 2023)、长文档摘要(Goyal & Durrett, 2020, Zhang et al., 2023a)以及涉及混合模态的任务,如视觉和视频理解(Liu et al., 2023b; Lin et al., 2023)等高级应用。这些应用通常需要处理大量上下文标记;例如,总结整个《哈利·波特》系列可能涉及分析约一百万个标记。随着视觉语言模型(VLMs)的出现,这一挑战加剧,其中单个 224×224{224} \times {224} 图像对应于 256 个标记(Liu et al., 2023b),而三分钟的视频以 24 FPS 生成约 110 万个标记。

A critical issue in deploying LLMs in such applications is the long-context inference problem. The full attention mechanism demands that all tokens attend to every previous token for accurate representation, resulting in linearly increasing decoding and quadratically increasing pre-filling latency as the sequence length grows. Additionally, the Key-Value (KV) Cache technique, which stores keys and values from all preceding tokens, causes memory usage to scale linearly with context length. As sequences lengthen, memory is increasingly consumed by the KV cache, placing a significant computational burden on the attention mechanism. For instance, in the Llama-3-8B (Dubey et al. 2024) model architecture, serving with FP16 KV cache for 1 million tokens would require at least 137  GB{137}\mathrm{\;{GB}} of memory-exceeding the capacity of a single 80  GB{80}\mathrm{\;{GB}} GPU. Additionally,the latencies

在这些应用中部署 LLMs 的一个关键问题是长上下文推理问题。完整注意力机制要求所有标记都参与每个先前的标记以进行准确表示,导致随着序列长度的增加,解码延迟线性增加,预填充延迟二次增加。此外,Key-Value(KV)缓存技术存储所有先前标记的键和值,导致内存使用量随上下文长度线性扩展。随着序列变长,KV 缓存消耗的内存越来越多,给注意力机制带来了巨大的计算负担。例如,在 Llama-3-8B(Dubey et al. 2024)模型架构中,使用 FP16 KV 缓存处理 100 万个标记将至少需要 137  GB{137}\mathrm{\;{GB}} 的内存,超过单个 80  GB{80}\mathrm{\;{GB}} GPU 的容量。此外,延迟


*Part of the work done during an internship at NVIDIA.

*在NVIDIA实习期间完成的工作的一部分。


Figure 1: Visualization of attention maps in the Llama-2-7B model for the sentence "The best fruit is orange. What is the best fruit? Orange." shows the distinct roles of retrieval heads (e.g., Layer 15, Head 12) and streaming heads (e.g., Layer 10, Head 1). On the left, retrieval heads capture contextually relevant tokens such as "best," "fruit," and "orange," which are crucial for processing long-context information and, therefore, require a full KV cache. In the middle, streaming heads primarily focus on initial and recent tokens without emphasizing past contextual relevance. On the right, the impact of limiting attention to the sink and recent tokens on long-context passkey retrieval accuracy is shown: modifying retrieval heads severely damages performance, while constraining streaming heads has minimal impacts.

图1:在Llama-2-7B模型中,句子“最好的水果是橙子。最好的水果是什么?橙子。”的注意力图可视化展示了检索头(例如,第15层,第12头)和流式头(例如,第10层,第1头)的不同角色。在左侧,检索头捕获了上下文相关的标记,如“最好”、“水果”和“橙子”,这些标记对于处理长上下文信息至关重要,因此需要完整的KV缓存。在中间,流式头主要关注初始和最近的标记,而不强调过去的上下文相关性。在右侧,显示了将注意力限制在接收器和最近标记上对长上下文通行证检索准确性的影响:修改检索头会严重损害性能,而约束流式头的影响最小。

of pre-filling and decoding with such large contexts are significant, posing substantial challenges to the effective use of LLMs in long-context scenarios.

在如此大的上下文中进行预填充和解码的开销是巨大的,给在长上下文场景中有效使用LLM带来了重大挑战。

Despite numerous efforts to overcome the challenges of attention mechanisms in long-context inference, significant computational and memory issues persist. Architectural modifications, such as Grouped-Query Attention (GQA)(Ainslie et al. 2023), require model pre-training and fail to reduce computational costs. Linear Attention methods (Gu & Dao, 2023; Poli et al. 2023), while less demanding in terms of computation and memory, often underperform in long-context scenarios compared to Transformer models. Approximative attention methods,such as H2O{\mathrm{H}}_{2}\mathrm{O} (Zhang et al. 2023b), StreamingLLM (Xiao et al. 2023b), TOVA (Oren et al. 2024), and FastGen (Ge et al. 2024), often compromise accuracy in long-context applications and are incompatible with essential KV cache optimization techniques like GQA. KV cache quantization (Liu et al., 2024, Hooper et al., 2024), although useful, does not reduce the computation time of the attention mechanism. System-level optimizations, including FlashAttention (Dao et al., 2022; Dao, 2023), FlashDecoding (Hong et al. 2024), and PagedAttention (Kwon et al. 2023), while effective, do not reduce the KV cache size and still require significant computation for extended contexts. These limitations emphasize the need for further advancements to deploy models that handle million-level context lengths.

尽管在克服长上下文推理中注意力机制的挑战方面做出了许多努力,但显著的计算和内存问题仍然存在。架构修改,如分组查询注意力(GQA)(Ainslie等人,2023年),需要模型预训练并且无法降低计算成本。线性注意力方法(Gu & Dao,2023年;Poli等人,2023年)虽然在计算和内存方面要求较低,但在长上下文场景中通常表现不如Transformer模型。近似注意力方法,如H2O{\mathrm{H}}_{2}\mathrm{O}(Zhang等人,2023b),StreamingLLM(Xiao等人,2023b),TOVA(Oren等人,2024年)和FastGen(Ge等人,2024年),在长上下文应用中常常牺牲准确性,并且与GQA等关键的KV缓存优化技术不兼容。KV缓存量化(Liu等人,2024年,Hooper等人,2024年)虽然有用,但不会减少注意力机制的计算时间。系统级优化,包括FlashAttention(Dao等人,2022年;Dao,2023年),FlashDecoding(Hong等人,2024年)和PagedAttention(Kwon等人,2023年),虽然有效,但不会减少KV缓存的大小,并且仍然需要大量计算来处理扩展的上下文。这些限制强调了进一步发展的必要性,以部署能够处理百万级上下文长度的模型。

In this paper, we introduce a key observation that attention heads in LLMs can be categorized into two distinct types: Retrieval Heads (Wu et al. 2024) and Streaming Heads, as shown in Figure 1. Retrieval Heads, which represent only a fraction of the total, are crucial for processing long contexts and require full attention across all tokens. In contrast, the majority of attention heads, termed Streaming Heads, primarily focus on recent tokens and attention sinks (Xiao et al. 2023b), and can operate effectively with a reduced KV cache that includes only recent tokens and attention sinks.

在本文中,我们提出了一项关键观察结果,即大型语言模型(LLMs)中的注意力头可以分为两种不同类型:检索头(Wu et al. 2024)和流式头,如图1所示。检索头仅占总数的少数,但对于处理长上下文至关重要,需要对所有标记进行全面关注。相比之下,大多数注意力头被称为流式头,主要关注最近的标记和注意力汇聚点(Xiao et al. 2023b),并且可以通过包含最近标记和注意力汇聚点的简化KV缓存有效运作。

Building on the dichotomy of retrieval and streaming heads, we propose DuoAttention, a general, straightforward, and easily integrated approach that significantly accelerates both LLM's decoding and pre-filling and reduces memory footprints, particularly in long-context scenarios. The core innovation of DuoAttention is a lightweight, optimization-based procedure that identifies non-compressible retrieval heads using synthetic datasets. Unlike existing methods that rely on attention pattern profiling (Wu et al. 2024; Ge et al. 2024; Tang et al. 2024a), DuoAttention directly measures output deviation resulting from token dropping, achieving higher compression rates and improved deployment efficiency. DuoAttention is designed with simplicity and efficiency in mind: each Transformer layer has two KV caches- a full KV cache for crucial retrieval heads and a constant KV cache for streaming heads, which stores only attention sinks and recent tokens. This design allows DuoAttention to dramatically reduce memory usage and improve decoding speed in models like Llama-2/3 and Mistral,achieving up to 2.55×{2.55} \times for MHA and 1.67×{1.67} \times for GQA models while speeding up decoding by up to 2.18×{2.18} \times and 1.50×{1.50} \times and accelerating pre-filling by up to 1.73×{1.73} \times and 1.63×{1.63} \times for MHA and GQA models, respectively, with minimal accuracy loss compared to full attention.

基于检索头和流式头的二分法,我们提出了DuoAttention,这是一种通用、简单且易于集成的方法,能够显著加速LLM的解码和预填充过程,并减少内存占用,特别是在长上下文场景中。DuoAttention的核心创新是一种轻量级的基于优化的程序,用于识别不可压缩的检索头,使用合成数据集进行操作。与依赖于注意力模式分析的现有方法(Wu et al. 2024; Ge et al. 2024; Tang et al. 2024a)不同,DuoAttention直接测量由于标记丢弃导致的输出偏差,从而实现更高的压缩率和改进的部署效率。DuoAttention的设计注重简单性和效率:每个Transformer层有两个KV缓存——一个用于关键检索头的完整KV缓存和一个用于流式头的常量KV缓存,后者仅存储注意力汇聚点和最近的标记。这种设计使得DuoAttention能够大幅减少内存使用并提高解码速度,例如在Llama-2/3和Mistral模型中,分别实现了MHA模型的2.55×{2.55} \times和GQA模型的1.67×{1.67} \times的压缩率,同时解码速度分别提高了2.18×{2.18} \times1.50×{1.50} \times,预填充速度分别提高了1.73×{1.73} \times1.63×{1.63} \times,与全注意力相比,精度损失最小。

Moreover, DuoAttention is fully compatible with important optimization techniques like GQA and quantization. We show that when combined with 8-bit weight 4-bit KV cache quantization,

此外,DuoAttention 完全兼容重要的优化技术,如 GQA 和量化。我们展示了当与 8 位权重 4 位 KV 缓存量化结合时,

Figure 2: Overview of DuoAttention: (1) In the retrieval head identification phase, we assign a trainable gate value, α\alpha ,to each attention head,which blends the outputs of full attention and streaming attention. The training objective is to optimize these values to minimize the deviation from the full attention model's output, while simultaneously applying a regularization loss to encourage lower gate values. This training phase is efficient, requiring only the gate values to be trainable-leaving all other model parameters frozen-thus allowing it to be completed within several hours on an 8 GPU node. (2) During deployment, these gate values are binarized to classify heads as either retrieval or streaming based on a threshold τ\tau . Retrieval heads,identified by a gate value above the threshold, use full attention, caching the KV pairs for all tokens. In contrast, streaming heads cache only the KV pairs of recent tokens and attention sinks.

图 2:DuoAttention 概述:(1) 在检索头识别阶段,我们为每个注意力头分配一个可训练的门值 α\alpha,该值混合了全注意力和流式注意力的输出。训练目标是优化这些值以最小化与全注意力模型输出的偏差,同时应用正则化损失以鼓励较低的门值。此训练阶段是高效的,只需要门值是可训练的——所有其他模型参数保持冻结——因此允许它在 8 GPU 节点上在几个小时内完成。(2) 在部署期间,这些门值被二值化,根据阈值 τ\tau 将头分类为检索或流式。检索头,由高于阈值的门值识别,使用全注意力,缓存所有令牌的 KV 对。相比之下,流式头仅缓存最近令牌和注意力汇的 KV 对。

DuoAttention enables a Llama-3-8B model to handle up to 3.3 million contextual tokens measured on a single A100 GPU,achieving a 6.4×{6.4} \times capacity increase compared to standard full attention FP16 deployments. DuoAttention paves the way for deploying LLMs in applications requiring million-level context handling.

DuoAttention 使 Llama-3-8B 模型能够在单个 A100 GPU 上处理多达 330 万个上下文令牌,与标准全注意力 FP16 部署相比,实现了 6.4×{6.4} \times 容量增加。DuoAttention 为在需要百万级上下文处理的应用中部署 LLM 铺平了道路。

2 DUOATTENTION

2 DUOATTENTION

2.1 RETRIEVAL AND STREAMING HEADS

2.1 检索头和流式头

Retrieval Heads In Transformer-based LLMs, attention heads exhibit distinct and consistent patterns, reflecting their specialized functionalities (Clark et al., 2019, Xiao et al., 2023b, Wu et al., 2024). Figure 1 visualizes two types of attention heads in the Llama-2-7B-32K-Instruct model using the sentence "The best fruit is orange. What is the best fruit? Orange". The left panel highlights an attention head that emphasizes relevant tokens during decoding; for instance, the first occurrence of "best fruit" is accentuated while decoding the second "best fruit," and the initial "orange" is highlighted when inferring the second "orange." These attention heads, which we term Retrieval Heads, are crucial for context processing as they capture contextually relevant tokens. Compressing the KV cache for retrieval heads would lead to the loss of vital contextual information, and thus they require full attention across all tokens.

基于Transformer的大型语言模型(LLMs)中的注意力头表现出不同的、一致的模式,反映了它们的专业功能(Clark et al., 2019, Xiao et al., 2023b, Wu et al., 2024)。图1使用句子“最好的水果是橙子。最好的水果是什么?橙子”可视化了Llama-2-7B-32K-Instruct模型中的两种注意力头。左侧面板突出显示了一个在解码过程中强调相关标记的注意力头;例如,在解码第二个“最好的水果”时,第一个“最好的水果”被强调,而在推断第二个“橙子”时,初始的“橙子”被突出显示。这些我们称之为检索头的注意力头对于上下文处理至关重要,因为它们捕捉了上下文相关的标记。压缩检索头的KV缓存会导致重要上下文信息的丢失,因此它们需要对所有标记进行全注意力处理。

Streaming Heads In contrast, the attention head depicted in the middle panel of Figure 1 primarily attends to recent tokens and attention sinks (Xiao et al. 2023b), without highlighting earlier relevant tokens in the context. We refer to these as Streaming Heads. Compressing the KV cache for Streaming Heads is feasible because dropping the unattended middle tokens does not significantly alter the attention output. Therefore, streaming heads can be optimized by retaining only the KV states of attention sinks and recent tokens, without compromising the model's ability to manage long contexts.

流式头相比之下,图1中间面板中描述的注意力头主要关注最近的标记和注意力汇聚点(Xiao et al. 2023b),而不强调上下文中较早的相关标记。我们称这些为流式头。压缩流式头的KV缓存是可行的,因为丢弃未被关注的中间标记不会显著改变注意力输出。因此,流式头可以通过仅保留注意力汇聚点和最近标记的KV状态来优化,而不会损害模型管理长上下文的能力。

Impact of Token Pruning on Retrieval and Streaming Heads The right panel of Figure 1 shows a preliminary passkey retrieval experiment, showing that the model's performance drops significantly when the middle tokens in the KV cache of retrieval heads are pruned, i.e., replaced with streaming attention. In contrast, removing the middle tokens for streaming heads has no significant impact on passkey retrieval accuracy. This observation indicates that we can enhance computational efficiency without sacrificing the model's long-context capabilities: By dropping middle tokens for streaming heads while keeping full attention for retrieval heads, we reduce the memory demands of streaming heads to O(1)O\left( 1\right) ,thereby improving the efficiency of processing long contexts.

令牌剪枝对检索和流式处理头的影响 图1的右侧面板展示了一个初步的通行密钥检索实验,显示当检索头的KV缓存中的中间令牌被剪枝时,即被流式注意力替换时,模型的性能显著下降。相比之下,移除流式处理头的中间令牌对通行密钥检索准确性没有显著影响。这一观察表明,我们可以在不牺牲模型长上下文能力的情况下提高计算效率:通过为流式处理头丢弃中间令牌,同时为检索头保持全注意力,我们将流式处理头的内存需求减少到 O(1)O\left( 1\right),从而提高了处理长上下文的效率。

Figure 3: Example from the synthetic dataset used to identify retrieval heads. We embed ten 32-word passkeys within a long text and ask the model to recall these passkeys. Distillation loss is calculated solely on the passkeys.

图3:用于识别检索头的合成数据集中的示例。我们在长文本中嵌入了十个32词的通行密钥,并要求模型回忆这些通行密钥。蒸馏损失仅在通行密钥上计算。

Figure 4: Optimized gate values of four LLMs. Llama-2-7B uses MHA with 32 heads per layer, while Mistral and Llama-3 models use GQA with 8 heads per layer. Retrieval heads have higher scores. MHA models have a lower ratio of retrieval heads compared to GQA models.

图4:四个LLM的优化门值。Llama-2-7B使用每层32个头的MHA,而Mistral和Llama-3模型使用每层8个头的GQA。检索头的得分较高。与GQA模型相比,MHA模型的检索头比例较低。

2.2 Optimization-Based Identification of Retrieval Heads

2.2 基于优化的检索头识别

Definition of Retrieval Heads Section 2.1 qualitatively defines retrieval and streaming heads, but for precise identification, we need a concrete and quantitative definition. In this paper, we define "retrieval heads" as the attention heads that:

检索头的定义 第2.1节定性地定义了检索和流式处理头,但为了精确识别,我们需要一个具体且量化的定义。在本文中,我们将“检索头”定义为那些:

significantly alter model outputs when restricted to recent tokens and attention sinks.

当限制为最近令牌和注意力汇时,显著改变模型输出的注意力头。

We use this criterion to distinguish retrieval heads from streaming heads. Note that this definition differs from existing works (Ge et al., 2024; Wu et al., 2024; Tang et al., 2024a) that rely solely on attention scores to identify retrieval heads, which overlook 1) the end-to-end impact of compressing the KV cache for specific attention heads, 2) the role of value states, and 3) the variability of attention distributions across layers and heads. In contrast, our definition directly measures output deviation, allowing us to identify attention heads crucial for long-context processing, even when they are not apparent in attention scores. We support this argument with ablation studies presented in Section 3.5

我们使用这一标准来区分检索头和流式头。需要注意的是,这一定义与现有工作(Ge et al., 2024; Wu et al., 2024; Tang et al., 2024a)有所不同,后者仅依赖注意力分数来识别检索头,这忽略了1) 压缩KV缓存对特定注意力头的端到端影响,2) 值状态的作用,以及3) 注意力分布在不同层和头之间的可变性。相比之下,我们的定义直接测量输出偏差,使我们能够识别对长上下文处理至关重要的注意力头,即使在注意力分数中不明显时也能识别。我们在第3.5节中通过消融研究支持这一论点。

Optimization-based Identification We employ an optimization-based approach to identify retrieval heads, drawing inspiration from prior work in CNN filter pruning (Liu et al. 2017), as illustrated in Figure 2. First,we assign a gate value αi,j{\alpha }_{i,j} ,to each key-value (KV) head in the LLM. This value intuitively represents the importance of the jj -th KV head in layer ii for processing long-context information. Note that in models using GQA, one KV head can be associated with multiple attention heads, and our method accounts for the KV cache compression of an entire group of attention heads.

基于优化的识别方法 我们采用基于优化的方法来识别检索头,借鉴了CNN滤波器剪枝的先前工作(Liu et al. 2017),如图2所示。首先,我们为LLM中的每个键值(KV)头分配一个门值αi,j{\alpha }_{i,j},该值直观地表示第jj层中的第ii个KV头在处理长上下文信息中的重要性。需要注意的是,在使用GQA的模型中,一个KV头可以与多个注意力头相关联,我们的方法考虑了整个注意力头组的KV缓存压缩。

Our optimization-based identification method directly assesses the impact of compressing the KV cache with only sink and recent tokens for each KV head. We begin by initializing the gate value αi,j[0,1]{\alpha }_{i,j} \in \left\lbrack {0,1}\right\rbrack for each head at 1,assuming that all heads initially serve as retrieval heads. These gate values are then optimized, with the LLM's parameters remaining fixed, limiting the number of trainable parameters to N×HN \times H and preventing the impact to the model’s original abilities.

我们的基于优化的识别方法直接评估仅使用sink和最近令牌压缩每个KV头的KV缓存的影响。我们首先将每个头的门值αi,j[0,1]{\alpha }_{i,j} \in \left\lbrack {0,1}\right\rbrack初始化为1,假设所有头最初都作为检索头。然后优化这些门值,同时保持LLM的参数固定,将可训练参数的数量限制为N×HN \times H,并防止对模型原始能力的影响。

During the forward pass, we combine the outputs of full and streaming attention (which attends only to sink and recent tokens) for each KV head, using the gate value as the mixing weight:

在前向传递过程中,我们使用门值作为混合权重,将每个KV头的完整和流式注意力(仅关注接收和最近的标记)的输出结合起来:

where the attention calculations are defined as:

其中注意力计算定义为:

where Mcausal {\mathbf{M}}_{\text{causal }} is the causal attention mask (a lower triangular matrix),and Mstreaming {\mathbf{M}}_{\text{streaming }} represents a Λ\Lambda -like mask (Han et al. 2023,Xiao et al. 2023b) that attends only to recent and initial tokens.

其中 Mcausal {\mathbf{M}}_{\text{causal }} 是因果注意力掩码(一个下三角矩阵),而 Mstreaming {\mathbf{M}}_{\text{streaming }} 表示一个 Λ\Lambda -like 掩码(Han et al. 2023, Xiao et al. 2023b),它仅关注最近的和初始的标记。

Synthetic Dataset for Identifying Retrieval Heads However, relying solely on natural language modeling objectives is insufficient for identifying retrieval heads because the supervision signal in

用于识别检索头的合成数据集 然而,仅依赖自然语言建模目标不足以识别检索头,因为在

—— 更多内容请到Doc2X翻译查看—— —— For more content, please visit Doc2X for translations ——