NVIDIA提出Hymba网络,Mamba+Attenation解决transformer的长序列低效难题

283 阅读9分钟

背景:transformer与mamba架构的缺陷

Transformers 成为大型语言模型(Large Language Models,LLMs)的首选架构。Transformers 不仅在处理复杂的语言任务上表现出色,还具备优越的并行处理能力和通过键值(Key-Value,KV)缓存实现的长期记忆能力。然而,随着输入序列长度的增加,Transformers 的计算复杂度和内存需求呈二次方增长。

为了应对 Transformers 在处理长序列时效率低下的问题,研究人员开发了状态空间模型(State Space Models,SSMs)如 Mamba 和 Mamba-2。这些模型的计算复杂度与序列长度呈线性关系,使得它们能更加高效的处理长序列。然而,SSM 模型效率高但记忆能力有限,任务性能较差。

为了充分结合 Transformers 和 SSMs 的优势,同时克服它们各自的不足,NVIDIA 提出了 Hymba 架构。

image.png

论文信息

Hymba简介

Himba 的核心创新在于将注意力头和 SSM 头集成在同一层中,实现了对相同输入的并行和互补处理。这种混合头方法综合了注意力机制的高分辨率记忆能力与 SSM 的高效上下文总结能力,从而增强了模型在处理各种信息流和记忆访问模式时的能力。

Hymba 还引入了一组可学习的元 tokens。这些元 tokens 被预先添加到输入序列的前面,并与所有后续 tokens 进行交互,即使在使用滑动窗口注意力时也参与计算。这些元 tokens 充当了世界知识的压缩表示,能够引导注意力机制更加关注对任务性能有意义的 tokens。 元 tokens 通过有效分配注意力权重,避免了模型过度关注无关信息。

深入理解混合头模块

为了更好地理解 Hymba 的混合头模块,我们可以将其类比为人脑的记忆功能。混合头模块是 Hymba 模型的基本单元,负责处理输入数据并生成输出。在混合头模块中,存在并行的注意力头和 SSM 头,分别处理相同的输入信息。

image.png

如上图所示,在模块的左侧是输入的 tokens。这些 tokens 是模型接收的原始数据。此外,序列的开头还添加了可学习的元 tokens,稍后将详细解释它们的具体作用。接下来,一个线性输入投影层将输入数据进行投影转换。

在混合头模块的核心部分,注意力头和 SSM 头并行工作。来自输入投影层的输出被分配给各个头,每个头接收它所需的输入。例如,对于注意力头,我们会得到 K(键)、V(值)和 Q(查询)矩阵,这些矩阵通过线性投影层中的专用权重矩阵计算得出。这些矩阵用于计算注意力分数,决定模型关注哪些部分的信息。与此同时,SSM 头也会接收它们所需的输入,并通过状态空间模型处理这些信息,进行上下文的高效总结。

各个头在接收到输入后,会独立地处理这些输入并生成输出。然而,研究人员发现 SSM 头的输出幅度始终大于注意力头的输出。这意味着直接将它们的输出结合起来可能会导致模型不稳定。为了解决这个问题,Himba 首先对两种类型的头的输出进行归一化处理,确保它们在同一尺度上。然后,将它们的输出进行平均,最后通过一个线性输出投影层将组合后的结果传递给模型的下一个组件。这一过程确保了不同类型头的输出能够有效融合,提升模型的整体稳定性和性能。

Hymba vs 已有方法

已有的混合模型通常采用将注意力头和 SSM 头顺序堆叠的方式。这种方法存在一个问题:当某一层类型不适合特定任务时,整个模块的性能可能会受到影响,需要后续层来补偿。

基于 Transformer 中多头注意力机制,Himba 采用了并行头的设计。在 Transformer 中,不同的注意力头承担不同的角色,专注于处理不同的上下文信息,这种设计有效提升了模型的表现。Himba 的混合头模块借鉴了这一思想,通过将注意力头和 SSM 头并行结合,使得每个头能够以不同的方式处理相同的信息,从而充分发挥两种机制的优势。

Hymba 与人脑记忆过程的相似性

image.png

为了更直观地理解 Hymba 的混合头模块,我们可以将其类比为人脑的记忆功能。在人脑中,不同类型的记忆过程协同工作来存储和回忆信息。例如,可以将 Hymba 中的注意力头比作快照记忆。这些记忆是我们对特定时刻或事件的详细回忆,使我们在需要时能够记住精确的细节。Hymba 中的注意力头也以类似的方式工作,能够从输入序列中提供特定信息的高分辨率回忆。

另一方面,Hymba 中的状态空间模型(SSM)头更像是渐淡的记忆。这些记忆帮助我们总结过去事件的整体要旨,而不保留所有细节。SSM 头能够高效地总结更广泛的上下文,确保模型能够处理长序列,而无需承担维持大量细节造成的计算负担。

通过在同一层中结合这两种类型的记忆,Hymba 的混合头模块模仿了人脑在详细回忆与高效总结之间的平衡。这种设计使模型能够更有效地处理各种类型的信息流和记忆访问模式,提升了模型在不同任务中的表现。

深入理解元 Tokens

Hymba 中的元 tokens 类似于人脑中的元记忆。元记忆帮助我们识别在记忆中找到所需信息的位置。类似地,元 tokens 引导模型关注相关信息。这也有助于减轻注意力流失的问题,即某些 tokens(通常称为“sink tokens”)获得过高的注意力权重,从而保证模型能够更专注于重要的信息。

image.png

在 Hymba 模型的最后一层中,不同任务领域的提示会激活不同的元 tokens。具体来说,代码领域的提示激活了一组不同的元 tokens,而数学和文章领域的提示则激活了另一组元 tokens。不同的元 tokens 封装了不同的世界知识,可以用来引导注意力机制更加关注与任务相关的信息。

此外,元 tokens 还在缓解注意力流失方面发挥了重要作用。注意力流失指的是某些无关紧要的 tokens获得过高的注意力权重,导致模型无法有效关注重要信息。通过引入元 tokens,Hymba 模型能够重新分配注意力权重,使得模型更加专注于对任务性能有意义的 tokens。在推理时,元 tokens 是固定的,并出现在任何输入序列的开头,因此它们的计算可以离线完成。

Hymba 的整体架构

Himba 的整体架构是通过堆叠多个 Hymba 块构建而成的。每个 Hymba 块由以下几个部分组成:

image.png

  1. 归一化层:用于规范化输入数据,提升训练稳定性和收敛速度。
  2. 混合头模块:集成了注意力头和 SSM 头,实现对相同输入的并行和互补处理。
  3. 另一个归一化层:进一步规范化混合头模块的输出。
  4. 前馈网络(Feedforward Network,FFN) :用于对混合头模块的输出进行非线性变换,增强模型的表达能力。

在这些 Hymba 块中,只有第一个、中间和最后一个块使用全注意力(full attention),即处理所有的 tokens。这种设计确保了模型在处理输入序列时能够兼顾全局和局部信息。所有其他块则采用滑动窗口注意力(Sliding Window Attention,SWA)的技术。模型在滑动窗口内关注局部上下文,减少了注意力所需的 KV 缓存大小,同时由于三个全注意力块的存在,仍能获得全局注意力。这种设计在提高模型效率的同时,保证了其在处理长序列和复杂任务时的性能。

此外,对于滑动窗口注意力块,Himba 采用了跨层键值(Key-Value,KV)缓存共享策略。具体来说,相邻块之间共享它们的 KV 缓存,而不是每个块维护独立的 KV 缓存。这种共享策略有效减少了 KV 缓存的冗余和内存使用,提高了模型的整体效率。这一策略源于最近的研究成果:相邻层中的 KV 缓存具有高度相似性,从而进一步提升了模型的性能和资源利用率。

数学表示

下面我们将从数学公式的角度,详细讲解混合头模块的运算过程、每个部分的输入与输出。

1. 输入序列与元 Tokens

首先,模型接收一个输入序列 X=[x1,x2,,xn]X = [x_1, x_2, \dots, x_n]。为了增强模型的记忆能力,Himba 在输入序列前添加了一组可学习的元 tokens R=[r1,r2,,rm]R = [r_1, r_2, \dots, r_m]。组合后的输入序列记作:

X~=[R,X]=[r1,r2,,rm,x1,x2,,xn]\tilde{X} = [R, X] = [r_1, r_2, \dots, r_m, x_1, x_2, \dots, x_n]
2. 输入投影层(Input Projection Layer)

接下来,混合头模块通过一个线性输入投影层将输入序列 X~\tilde{X} 转换为适合注意力头和 SSM 头处理的格式。输入投影层的权重矩阵记作 Win_proj=[WQ,WK,WV,WSSM,WG]W_{in\_proj} = [W_Q, W_K, W_V, W_{SSM}, W_G],其中:

  • WQ,WK,WVW_Q, W_K, W_V:用于生成注意力机制中的查询(Query)、键(Key)和值(Value)矩阵。
  • WSSM,WGW_{SSM}, W_G:用于生成 SSM 头的输入特征和门控信号。

通过输入投影,得到以下输出:

Q=WQX~,K=WKX~,V=WVX~Q = W_Q \tilde{X}, \quad K = W_K \tilde{X}, \quad V = W_V \tilde{X}
Xssm=WSSMX~,G=WGX~X_{ssm} = W_{SSM} \tilde{X}, \quad G = W_G \tilde{X}
3. 注意力头(Attention Heads)

注意力头的输出 YattnY_{attn}Yattn​ 通过以下公式计算:

Yattn=softmax(QKT)V=MattnX~(1)Y_{attn} = \text{softmax}(QK^T) V = M_{attn} \tilde{X} \quad (1)

其中:

  • Mattn=softmax(QKT)VM_{attn} = \text{softmax}(QK^T) V 是注意力权重矩阵。
  • QKTQK^T 计算查询与键的相似度。
  • softmax\text{softmax} 函数将相似度转换为概率分布。
4. 状态空间模型头(SSM Heads)

SSM 头的输出 YssmY_{ssm} 通过以下公式计算:

αi,j=Ci(k=j+1iexp(AΔk))BjΔj(2)\alpha_{i,j} = C_i \left( \prod_{k=j+1}^i \exp(A \Delta_k) \right) B_j \Delta_j \quad (2)
Yssm=Gα(A,B,C,Δ)WSSMX~=MssmX~Y_{ssm} = G \odot \alpha(A, B, C, \Delta) W_{SSM} \tilde{X} = M_{ssm} \tilde{X}

其中:

  • αi,j\alpha_{i,j} 是 SSM 头的权重系数。
  • AA 是可学习矩阵,控制状态空间的动态。
  • B,C,ΔB, C, \Delta 是其他可学习参数,定义了 SSM 的特性。
  • GG 是输出门,通过 G=WGX~G = W_G \tilde{X} 计算得到,用于调节输出。
  • ⊙\odot⊙ 表示元素级别的逐位相乘。
5. 输出融合与归一化

研究人员发现,SSM 头的输出幅度 YssmY_{ssm} 通常大于注意力头的输出 YattnY_{attn}。为了确保两者的有效融合,Himba 采用了以下步骤:

归一化与重新缩放

norm(Yattn)=归一化(Yattn)norm(Yssm)=归一化(Yssm)\text{norm}(Y_{attn}) = \text{归一化}(Y_{attn}) \\ \text{norm}(Y_{ssm}) = \text{归一化}(Y_{ssm})

使用可学习向量 β1\beta_1β2\beta_2 对归一化后的输出进行重新缩放。

输出平均

Y=Wout_proj(β1norm(Yattn)+β2norm(Yssm))Y = W_{out\_proj} (\beta_1 \cdot \text{norm}(Y_{attn}) + \beta_2 \cdot \text{norm}(Y_{ssm}))

最终,通过一个线性输出投影层 Wout_projW_{out\_proj} 将两种头的输出进行融合,生成最终的输出 YY

6. 总结运算流程

综合上述步骤,Hymba 的混合头模块的运算流程如下:

  1. 输入处理

    • 原始输入序列 XX 添加元 tokens RR 形成 X~\tilde{X}
  2. 输入投影

    • 通过 Win_projW_{in\_proj}X~\tilde{X} 投影生成Q,K,V,Xssm,GQ, K, V, X_{ssm}, G
  3. 注意力计算

    • 计算注意力头的输出 YattnY_{attn}
  4. SSM 计算

    • 计算 SSM 头的输出 YssmY_{ssm}
  5. 输出融合

    • YattnY_{attn}YssmY_{ssm} 进行归一化和重新缩放。
    • 将两者的输出进行平均,并通过 Wout_projW_{out\_proj} 生成最终输出 YY

通过这种并行处理方式,Hymba 的混合头模块能够同时利用注意力机制的高分辨率记忆能力和 SSM 的高效上下文总结能力,从而提升模型的表现。

实验效果

image.png

上表比较了 Hymba 与SOTA的小型语言模型。除了 Llama-3.2–3B,所有模型的参数都少于 2B,其中 Hymba 模型有 1.5 B个参数。在右侧的列显示了平均性能, Hymba 在只使用 1.5 万亿 tokens来进行训练的前提下,超越了所有其他模型。