DuoAttention:高效处理长上下文推理的 AI 框架,让 LLMs 如虎添翼!

287 阅读6分钟

❤️ 如果你也关注大模型与 AI 的发展现状,且对大模型应用开发非常感兴趣,我会快速跟你分享最新的感兴趣的 AI 应用和热点信息,也会不定期分享自己的想法和开源实例,欢迎关注我哦!

🥦 微信公众号|搜一搜:蚝油菜花 🥦


🚀 快速阅读

  1. DuoAttention 通过区分“检索头”和“流式头”两种注意力头,优化模型的内存使用和计算速度。
  2. DuoAttention 能在保持模型准确性的同时,减少内存消耗和提高解码及预填充的速度。
  3. 结合量化技术,DuoAttention 能在单个 GPU 上实现高达 330 万 token 的上下文推理。

正文(附运行示例)

DuoAttention 是什么

duoattention_method1.jpg

DuoAttention 是新型的框架,由 MIT 韩松团队提出,用在提高大型语言模型(LLMs)在处理长上下文时的推理效率。基于区分“检索头”和“流式头”两种注意力头,优化模型的内存使用和计算速度。检索头负责处理长距离依赖,需要完整的键值(KV)缓存,流式头关注最近 token 和注意力汇聚点,只需固定长度的 KV 缓存。两种注意力头让 DuoAttention 在保持模型准确性的同时,减少内存消耗和提高解码及预填充的速度。结合量化技术,DuoAttention 能在单个 GPU 上实现高达 330 万 token 的上下文推理,是处理长文本信息的有效方案。

DuoAttention 的主要功能

  • 提高长上下文推理效率:基于优化大型语言模型(LLMs)的注意力机制,DuoAttention 显著提升模型处理长上下文数据的能力。
  • 减少内存消耗:区分需要完整 KV 缓存的检索头和只需固定长度 KV 缓存的流式头,减少模型运行时的内存占用。
  • 加速解码和预填充过程:DuoAttention 优化模型的解码速度和预填充(Pre-filling)速度,提高 LLMs 的响应时间和处理效率至关重要。
  • 保持模型准确性:在减少内存消耗和提高效率的同时,DuoAttention 能保持模型在处理长短上下文任务时的准确性。

DuoAttention 的技术原理

  • 注意力头的区分:DuoAttention 将 LLMs 中的注意力头分为检索头和流式头。检索头负责捕捉上下文中的关键信息,对所有 token 进行完整注意力处理;流式头主要处理近期 token 和注意力汇聚点,不需要存储全部历史 KV 状态。
  • 检索头的 KV 缓存优化:为检索头保留完整的 KV 缓存,确保能捕捉到长距离依赖信息。
  • 流式头的轻量级 KV 缓存:流式头用固定长度的 KV 缓存,减少对内存的需求,支持模型高效处理长序列数据。
  • 检索头的自动识别:DuoAttention 用基于优化的算法和合成数据集训练模型,自动识别出哪些头是检索头,在推理时为分配适当的 KV 缓存策略。
  • 合成数据集:设计合成数据集和密码召回任务,DuoAttention 能确定哪些注意力头在保留或丢弃 KV 缓存后对模型输出有显著影响,优化模型的长上下文处理能力。

如何运行 DuoAttention

环境设置

训练和评估环境
conda create -yn duo python=3.10
conda activate duo

conda install -y git
conda install -y nvidia/label/cuda-12.4.0::cuda-toolkit
conda install -y nvidia::cuda-cudart-dev
conda install -y pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia

pip install transformers accelerate sentencepiece datasets wandb accelerate sentencepiece datasets wandb zstandard matplotlib huggingface_hub
pip install tensor_parallel

pip install ninja packaging
pip install flash-attn --no-build-isolation

# LongBench评估
pip install seaborn rouge_score einops pandas

pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/

# 安装DuoAttention
pip install -e .

# 安装Block Sparse Streaming Attention
git clone git@github.com:mit-han-lab/Block-Sparse-Attention.git
cd Block-Sparse-Attention
python setup.py install
演示环境
conda create -yn duo_demo python=3.10
conda activate duo_demo

# 安装DuoAttention
pip install -e .

conda install -y git
conda install -y nvidia/label/cuda-12.4.0::cuda-toolkit
conda install -y nvidia::cuda-cudart-dev

# 安装QServe
git clone git@github.com:mit-han-lab/qserve.git
cd qserve
pip install -e .
pip install ninja packaging
pip install flash-attn==2.4.1 --no-build-isolation
cd kernels
python setup.py install

# 安装FlashInfer
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/
pip install tensor_parallel

数据集

下载数据集:

mkdir -p datasets
cd datasets

wget https://huggingface.co/datasets/togethercomputer/Long-Data-Collections/resolve/main/fine-tune/booksum.jsonl.zst

模型

下载 DuoAttention 支持的模型:

mkdir -p models
cd models

# DuoAttention目前支持的评估模型
huggingface-cli download togethercomputer/Llama-2-7B-32K-Instruct --local-dir Llama-2-7B-32K-Instruct
huggingface-cli download gradientai/Llama-3-8B-Instruct-Gradient-1048k --local-dir Llama-3-8B-Instruct-Gradient-1048k
huggingface-cli download gradientai/Llama-3-8B-Instruct-Gradient-4194k --local-dir Llama-3-8B-Instruct-Gradient-4194k
huggingface-cli download mistralai/Mistral-7B-Instruct-v0.2 --local-dir Mistral-7B-Instruct-v0.2
huggingface-cli download mistralai/Mistral-7B-Instruct-v0.3 --local-dir Mistral-7B-Instruct-v0.3

# 使用SmoothQuant和QServe的W8A8KV4模型
huggingface-cli download mit-han-lab/Llama-3-8B-Instruct-Gradient-1048k-w8a8kv4-per-channel --local-dir Llama-3-8B-Instruct-Gradient-1048k-w8a8kv4-per-channel
huggingface-cli download mit-han-lab/Llama-3-8B-Instruct-Gradient-4194k-w8a8kv4-per-channel --local-dir Llama-3-8B-Instruct-Gradient-4194k-w8a8kv4-per-channel

快速开始 DuoAttention

我们提供了一个简单的单点击 patch,用于在 HuggingFace 模型上启用 DuoAttention 优化,包括 Llama 和 Mistral。attn_patterns目录中提供了五个长上下文模型的预训练检索头模式:Llama-2-7B-32K-InstructLlama-3-8B-Instruct-Gradient-1048kLlama-3-8B-Instruct-Gradient-4194kMistral-7B-Instruct-v0.2Mistral-7B-Instruct-v0.3Meta-Llama-3.1-8B-Instruct。如果您想训练自己的检索头模式,可以使用 scripts 目录中提供的训练脚本。以下是如何在Llama-3-8B-Instruct-Gradient-1048k模型上启用 DuoAttention 的示例。

from duo_attn.utils import load_attn_pattern, sparsify_attention_heads
from duo_attn.patch import enable_duo_attention_eval
import transformers
import torch

# 加载模型
model = transformers.AutoModelForCausalLM.from_pretrained(
    "models/Llama-3-8B-Instruct-Gradient-1048k",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    attn_implementation="eager",
)

# 加载注意力模式
attn_heads, sink_size, recent_size = load_attn_pattern(
    "attn_patterns/Llama-3-8B-Instruct-Gradient-1048k/lr=0.02-reg=0.05-ctx=1000_32000-multi_passkey10"
)

# 稀疏化注意力头
attn_heads, sparsity = sparsify_attention_heads(attn_heads, sparsity=0.5)

# 启用DuoAttention
enable_duo_attention_eval(
    model,
    attn_heads,
    sink_size=64,
    recent_size=256,
)

# 将模型移至GPU
model = model.cuda()

# 准备进行推理!

演示

设置环境后,您可以运行以下脚本以在Llama-3-8B-Instruct-Gradient-4194k模型上执行 W4A8KV4 与 DuoAttention 的演示。该演示旨在在单个 A100 GPU 上运行,并支持高达 330 万个 token 的上下文长度。

bash scripts/run_demo.sh

资源


❤️ 如果你也关注大模型与 AI 的发展现状,且对大模型应用开发非常感兴趣,我会快速跟你分享最新的感兴趣的 AI 应用和热点信息,也会不定期分享自己的想法和开源实例,欢迎关注我哦!

🥦 微信公众号|搜一搜:蚝油菜花 🥦