状态空间模型在自然语言处理中的演进与应用

45 阅读12分钟

理解状态空间模型

状态空间模型(SSMs)使用一阶微分方程来表示动态系统,为分析和控制随时间变化的状态提供了结构化框架。

考虑道路上行驶的汽车作为一个动态系统。当我们向该系统施加特定输入(如踩油门)时,会改变汽车的当前状态(例如发动机燃烧的油量),从而导致汽车以特定速度移动。

状态空间模型还建模"跳跃连接",表示输入对输出的直接影响。在汽车的具体情况下,这种直接馈通(D)为零,但我们在模型中保留它,因为系统通常可以(并且确实)具有直接的输入到输出依赖关系。

状态空间表示(SSR)的数学方程为:

  • 新状态x(t)的方程:x(t) = A·x(t-1) + B·u(t)
  • 输出y(t)的方程:y(t) = C·x(t) + D·u(t)

这两个方程形成了我们系统的状态空间表示。通过向动态和观测方程引入噪声项,我们可以对状态变量和输入之间的概率依赖关系进行建模。

自然语言处理中的状态空间模型

状态空间模型在时间序列分析中早已确立,作为可训练的序列模型已被使用数十年。大约在2020年,它们高效处理长序列的能力推动了将其适应自然语言处理(NLP)的重大进展。

将SSMs应用于自然语言处理将输入重新定义为token,将状态重新定义为上下文表示,将输出重新定义为预测的下一个token。

HiPPO:具有最优多项式投影的循环记忆

序列模型面临的主要挑战是捕获长序列中相距较远的两个输入之间的依赖关系。

Albert Gu及其在斯坦福大学的同事试图通过引入HiPPO("高阶多项式投影算子"的缩写)来解决这个问题。这个数学框架旨在将历史信息压缩成固定大小的表示。

HiPPO通过构建一组在特定加权函数下数学正交的多项式基来工作。加权函数w(t)使用两种变体之一来衡量历史信息的重要性:

  1. 变换HiPPO矩阵变体:变换矩阵优先处理最新输入,并随时间不断改变系统响应。序列历史中存储的信息重要性随时间衰减。

  2. 平稳HiPPO矩阵变体:平稳矩阵是时不变的,并以一致的重要性考虑所有过去数据。

Gu及其同事将这两种变体应用于三种不同的多项式族,分别称为Leg、Lag和Cheb。Leg、Lag和Cheb之间的区别在于信息保留量,这由与每组多项式相关的加权函数w(t)的变化及其正交性质决定。

HiPPO矩阵是通过推导在实时将输入信号投影到指定多项式基上的微分算子获得的。这些算子确保状态的正交性,同时保留定义的加权函数。

结合循环、卷积和连续时间模型的线性状态空间层

HiPPO的发明者与其他斯坦福研究人员合作开发了使用HiPPO框架的结构化状态空间序列模型。该模型在将SSMs应用于序列建模任务方面取得了重大进展。

根据作者的说法,理想的序列建模算法应具备以下能力:

  1. 可并行化训练,如卷积神经网络(CNNs)所实现
  2. 有状态推理,如循环神经网络(RNNs)所提供
  3. 时间尺度适应,如神经微分方程(NDEs)所实现

除了这些特性外,模型还应能够以计算高效的方式处理长程依赖关系。

离散化

我们可以使用数值方法对连续SSR方程进行离散化。

将连续信号转换为离散信号的过程称为"离散化"。我们用来测量速度的时间间隔称为时间尺度Δt,也称为"步长"或"离散化参数"。

在《结合循环、卷积和连续时间模型与线性状态空间层》中,作者探索了几种离散化状态空间模型的方法,最终选择了广义双线性变换(GBT),它有效地平衡了精度(通过避免过采样)和稳定性(通过避免欠采样)。

离散状态方程在GBT下给出为: x(t) = (I - Δt/2·A)⁻¹·[(I + Δt/2·A)·x(t-1) + Δt·B·u(t)]

广义双线性变换应用中的一个关键决策是参数α的选择,它控制保留连续时间系统特性与确保离散域稳定性之间的平衡。作者选择α=0.5,因为它平衡了精度和数值稳定性。

现在我们有了SSR方程的离散版本,我们可以将它们应用于自然语言生成任务,其中:

  • u(t)是我们馈入模型的输入token
  • x(t)是上下文,即序列至今历史的表示
  • y(t)是输出,预测的下一个token

SSMs作为序列模型的三大支柱

可并行化训练

为了应对这一挑战,作者引入了SSMs的卷积表示,这使得这些模型能够像CNNs和Transformers一样并行处理序列。

作者的想法是将SSM表示为具有特定核k的卷积操作,该核源自状态空间参数,使模型能够高效计算长序列上的输出。

有状态推理

由于其状态方程,SSMs实现了有状态推理。它们固有地维护包含序列上下文的状态,使它们比基于Transformer的模型更具计算效率。

为了处理长程依赖关系,《结合循环、卷积和连续时间模型与线性状态空间层》的作者使用HiPPO-LegS(HiPPO-Leg的平稳形式)公式来参数化A。

时间尺度适应

时间尺度适应是指序列模型捕获输入序列不同部分中输入token的依赖关系的能力。

模型捕获序列内依赖关系的能力取决于其上下文表示。SSMs将上下文表示为矩阵A。因此,SSM基于新输入通过状态方程更新状态的能力使模型能够适应序列内的上下文依赖关系,使其能够处理长程和短程依赖关系。

线性状态空间层(LSSLs)

Gu及其同事在论文中介绍了线性状态空间层(LSSL),该层利用了状态空间表示方程的离散循环和卷积形式。该层被集成到深度学习架构中,以引入对长程依赖关系和结构化序列表示的高效处理。

为了拥有计算高效的模型,我们似乎需要卷积和循环表示的属性。Gu及其同事设计了一种"两全其美"的方法,在训练期间使用卷积表示,在推理期间使用循环表示。

在其论文中,Gu及其合作者将LSSL架构描述为"涉及堆叠LSSL层的深度神经网络,这些层与归一化层和残差连接相连"。类似于Transformer架构中的注意力层,每个LSSL层前面都有一个归一化层,后面跟着一个GeLU激活函数。

使用状态结构化空间高效建模长序列

LSSL模型在序列数据上表现令人印象深刻,但由于计算复杂性和内存瓶颈而未得到广泛采用。

在论文《使用状态结构化空间高效建模长序列》中,Gu与密切合作者Karan Goel和Christopher Ré一起改进了LSSL,以降低训练过程的计算复杂性和提高准确性。

状态矩阵A的改进

在LSSL中,状态乘以矩阵A以产生状态的更新版本。矩阵A用于乘法的最计算高效形式将是对角矩阵。不幸的是,HiPPO矩阵不能重新形成为对角矩阵,因为它没有完整的特征向量集。

然而,作者能够将矩阵分解为对角加低秩分解(DPLR)。对角矩阵仅在主对角线上有非零条目,这使得乘法过程更高效,每个向量元素只需要一次乘法。低秩矩阵可以表示为两个小得多的矩阵的乘积。

原始LSSL架构需要O(N²L)操作,其中N是状态维度,L是序列长度。将矩阵A转换为其对角加低秩(DPLR)形式后,循环和卷积形式的计算复杂度都降低了:

  • 对于循环形式,DLPR形式只有O(NL)矩阵-向量乘法
  • 对于卷积形式,卷积核减少到只需要O(N log L + L log L)操作

训练实现的改进

在解决LSSL的计算复杂性之后,作者发现了另一个重大改进,即使矩阵A(部分)可学习。在LSSL中,矩阵是固定的,在训练过程中不更新。相反,矩阵B和C负责SSM块的更新和可学习性。

保持矩阵A固定确保计算效率,但限制了模型捕获序列中复杂动态和潜在模式的能力。完全可学习的矩阵A提供了适应任意动态的灵活性。然而,它带来了权衡:更多参数需要优化,训练速度更慢,以及推理期间更高的计算成本。

为了平衡这些竞争需求,修改后的LSSL——被称为S4——采用了部分可学习的A。通过保持A的DPLR结构,模型保留了计算效率,而引入可学习参数增强了其捕获更丰富的、领域特定行为的能力。

此外,《使用状态结构化空间高效建模长序列》介绍了实现双向状态空间模型的技术。这些模型可以向前和向后两个方向处理序列,捕获来自过去和未来上下文的依赖关系。

用于序列建模的简化状态空间层

在《用于序列建模的简化状态空间层》中,Jimmy Smith、Andrew Warrington和Scott Linderman提出了对S4架构的多项改进,以在保持相同计算复杂性的同时提高性能。

虽然S4相对于原始LSSL的改进主要集中于降低模型的计算复杂性,但S5旨在简化架构,使其更高效、更易于实现,同时保持或提高性能。

使用并行关联扫描

并行扫描,也称为并行关联扫描,是一种允许通过预计算序列中每个位置的累积操作(在这种情况下是乘积)的算法,以便可以在处理步骤中选择它们,而不是一次处理一个。

使用并行关联扫描,Smith及其同事能够并行化循环SSMs的训练过程,消除了使用卷积表示的需要。

因此,S5层仅在时域中运行,而不是同时拥有卷积和频域。这是一个重要的改进,因为它允许每层的时间复杂度为O(N log L)而不是O(NL),利用序列长度上的并行计算同时减少内存开销。

允许多输入多输出

LSSL和S4是单输入单输出(SISO)模型。允许多输入多输出(MIMO)在计算上不可行,因为LSSL和S4内部的计算是在一次只有一个输入的假设下设计的。

Smith及其合作者离散化了MIMO SSM方程,而不是SISO SSM方程。使用相同的SSR方程,他们扩展了离散化过程以处理m维输入和n维输出。假设状态有N维,这一变化使B成为N×m矩阵而不是N×1,C成为n×N矩阵而不是1×N。

S5对MIMO的支持使其能够处理多维数据,例如多变量和多通道时间序列数据,同时处理多个序列,并产生多个输出。这通过允许多个序列同时处理而不是拥有m个SSM副本来减少计算开销。

对角化参数化

正如我们上面讨论的,HiPPO-LegS不能对角化。然而,并行扫描方法需要对角矩阵A。通过实验,Smith及其同事发现他们可以将HiPPO-LegS矩阵表示为正规加低秩(NLPR)矩阵,其中正规分量被称为HiPPO-N,可以对角化。

他们证明,通过证明HiPPO-N和HiPPO-LegS产生相同的动态,移除低秩项并初始化HiPPO-N矩阵具有相似的结果。(论文附录中给出了证明。)然而,如果他们使用来自DPLR近似的对角矩阵,该近似将产生与原始结构非常不同的动态。

使用HiPPO-N矩阵的对角化版本通过消除将HiPPO-LegS矩阵转换为其DPLR近似的需要,降低了模型的计算复杂性。

类似于使用结构化参数化矩阵A降低了计算开销,S5使用矩阵B和C的低秩表示,进一步减少了参数数量。

结论与展望

状态空间模型(SSMs)作为序列到序列模型的演进突出了它们在NLP领域日益增长的重要性,特别是对于需要建模长期依赖关系的任务。诸如LSSL、S4和S5等创新通过提高计算效率、可扩展性和表达能力推动了该领域的发展。

尽管S5模型取得了进步,但它仍然缺乏上下文感知能力。S5可以在时域中高效训练和推理,并保留长程依赖关系的信息,但它不像Transformers通过注意力机制那样明确过滤或关注序列的特定部分。

因此,关键下一步是将一种机制纳入SSMs,使它们能够专注于状态的最相关部分,而不是统一处理整个状态。这就是Mamba模型架构解决的问题,我们将在本系列即将到来的第二部分中探讨。