arxiv.org/abs/2012.07…
稀疏注意力机制(ProbSparse Self-attention)
Efficient Self-attention Mechanism
经典的自注意力机制(Vaswani et al. 2017)是基于三元组输入定义的,即:查询(query)、键(key)和值(value),它执行缩放点积,如公式
A ( Q , K , V ) = Softmax ( Q K T / d ) V \mathcal{A}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{Softmax}(\mathbf{Q}\mathbf{K}^T/\sqrt{d})\mathbf{V} A ( Q , K , V ) = Softmax ( Q K T / d ) V
所示,其中 Q ∈ R L q × d \mathbf{Q} \in \mathbb{R}^{L_q \times d} Q ∈ R L q × d 、K ∈ R L k × d \mathbf{K} \in \mathbb{R}^{L_k \times d} K ∈ R L k × d 和 V ∈ R L v × d \mathbf{V} \in \mathbb{R}^{L_v \times d} V ∈ R L v × d ,而 d d d 是输入维度。
为了进一步讨论自注意力机制,让 q i , k i , v i \mathbf{q}_i, \mathbf{k}_i, \mathbf{v}_i q i , k i , v i 分别代表 Q , K , V \mathbf{Q}, \mathbf{K}, \mathbf{V} Q , K , V 中的第 i i i 行。
根据 (Tsai et al. 2019) 的定义,第 i i i 个查询的注意力被定义为概率形式下的核平滑:
A ( q i , K , V ) = ∑ j k ( q i , k j ) ∑ l k ( q i , k l ) v j = E p ( k j ∣ q i ) [ v j ] ( 1 ) \mathcal{A}(\mathbf{q}_i, \mathbf{K}, \mathbf{V}) = \sum_{j} \frac{k(\mathbf{q}_i, \mathbf{k}_j)}{\sum_{l} k(\mathbf{q}_i, \mathbf{k}_l)} \mathbf{v}_j = \mathbb{E}_{p(\mathbf{k}_j|\mathbf{q}_i)}[\mathbf{v}_j](1) A ( q i , K , V ) = j ∑ ∑ l k ( q i , k l ) k ( q i , k j ) v j = E p ( k j ∣ q i ) [ v j ] ( 1 )
其中 p ( k j ∣ q i ) = k ( q i , k j ) / ∑ l k ( q i , k l ) p(\mathbf{k}_j|\mathbf{q}_i) = k(\mathbf{q}_i, \mathbf{k}_j)/\sum_{l} k(\mathbf{q}_i, \mathbf{k}_l) p ( k j ∣ q i ) = k ( q i , k j ) / ∑ l k ( q i , k l ) ,且 k ( q i , k j ) k(\mathbf{q}_i, \mathbf{k}_j) k ( q i , k j ) 选择非对称指数核 exp ( q i k j T / d ) \text{exp}(\mathbf{q}_i \mathbf{k}_j^T/\sqrt{d}) exp ( q i k j T / d ) 。
自注意力结合了values,并基于计算概率 p ( k j ∣ q i ) p(\mathbf{k}_j|\mathbf{q}_i) p ( k j ∣ q i ) 获得输出。它需要进行二次时间的点积计算和 O ( L q L k ) \mathcal{O}(L_q L_k) O ( L q L k ) 的内存使用,这是在提升预测能力时的主要缺陷。
一些先前的尝试揭示了自注意力概率分布可能具有稀疏性,并且他们设计了“选择性”计数策略来计算所有 p ( k j ∣ q i ) p(\mathbf{k}_j|\mathbf{q}_i) p ( k j ∣ q i ) ,而不显著影响性能。稀疏Transformer(Child et al. 2019)结合了行输出和列输入,其中稀疏性来自于分离的空间相关性。LogSparse Transformer(Li et al. 2019)注意到自注意力中的周期性模式,并通过指数步长强制每个单元格关注其前一个。Longformer(Beltagy, Peters, and Cohan 2020)扩展了先前的两个工作,以更复杂的稀疏配置。然而,它们受限于从以下启发式方法进行的理论分析,并以相同的策略处理每个多头自注意力,限制了它们的进一步改进。
为了激励我们的方法,我们首先对经典自注意力的学习注意力模式进行定性评估。“稀疏”自注意力得分形成长尾分布,即,少数点积对对主要注意力贡献,而其他的产生微不足道的注意力。然后,下一个问题是如何区分它们?
Query Sparsity Measurement
从公式 (1) 出发,第 i i i 个查询对所有键的注意力被定义为一个概率 p ( k j ∣ q i ) p(k_j|q_i) p ( k j ∣ q i ) ,输出是这些概率与值 v v v 的组合。主导的点积对促进了相应查询的注意力概率分布远离均匀分布。如果 p ( k j ∣ q i ) p(k_j|q_i) p ( k j ∣ q i ) 接近于均匀分布 q ( k j ∣ q i ) = 1 / L K q(k_j|q_i) = 1/L_K q ( k j ∣ q i ) = 1/ L K ,则自注意力成为简单的值 V V V 的和,对输入过程而言是冗余的。本质上可以通过分布 p p p 和 q q q 的“相似性”来区分“重要的”查询。我们通过 Kullback-Leibler 散度来测量“相似性”:
K L ( q ∥ p ) = ln ∑ l = 1 L K e q i k l ⊤ / d − 1 L K ∑ j = 1 L K q i k j ⊤ / d − ln L K KL(q\|p) = \ln \sum_{l=1}^{L_K} e^{q_i k_l^\top / \sqrt{d}} - \frac{1}{L_K} \sum_{j=1}^{L_K} q_i k_j^\top / \sqrt{d} - \ln L_K K L ( q ∥ p ) = ln l = 1 ∑ L K e q i k l ⊤ / d − L K 1 j = 1 ∑ L K q i k j ⊤ / d − ln L K
去掉常数项后,我们定义第 i i i 个查询的稀疏性测度 为:
M ( q i , K ) = ln ∑ j = 1 L K e q i k j ⊤ d − 1 L K ∑ j = 1 L K q i k j ⊤ d ( 2 ) M(q_i, \mathbf{K}) = \ln \sum_{j=1}^{L_K} e^{\frac{{q_i k_j^\top}}{\sqrt{d}}} - \frac{1}{L_K} \sum_{j=1}^{L_K} \frac{q_i k_j^\top}{\sqrt{d}} \quad (2) M ( q i , K ) = ln j = 1 ∑ L K e d q i k j ⊤ − L K 1 j = 1 ∑ L K d q i k j ⊤ ( 2 )
其中,第一个项是 q i q_i q i 对所有键的 Log-Sum-Exp (LSE),第二个项是它们的算术平均值。如果第 i i i 个查询获得更大的 M ( q i , K ) M(q_i, \mathbf{K}) M ( q i , K ) ,则其 attention probability p 更具区别性,且更有可能包含长尾 self-attention 分布的头部区域中的主要点积对。如下图红框区域:
ProbSparse Self-attention
ProbSparse 自注意力 基于所提出的度量方法,我们通过允许每个键仅关注于 u u u 个主要query来实现 ProbSparse 自注意力:
A ( Q , K , V ) = Softmax ( Q ‾ K ⊤ d ) V ( 3 ) \mathcal{A}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{Softmax}\left(\frac{\overline{\mathbf{Q}}\mathbf{K}^\top}{\sqrt{d}}\right)\mathbf{V} \quad (3) A ( Q , K , V ) = Softmax ( d Q K ⊤ ) V ( 3 )
其中,Q ‾ \overline{\mathbf{Q}} Q 是与 q \mathbf{q} q 尺寸相同的稀疏矩阵,它仅包含在稀疏性度量 M ( q , K ) M(\mathbf{q}, \mathbf{K}) M ( q , K ) 下的前 u u u 个查询。通过常数采样因子 c c c 控制,我们设置 u = c ⋅ ln L Q u = c \cdot \ln L_Q u = c ⋅ ln L Q ,这使得 ProbSparse 自注意力仅需为每个查询-键查找计算 O ( ln L Q ) \mathcal{O}(\ln L_Q) O ( ln L Q ) 次点积,且层的内存使用维持在 O ( L K ln L Q ) \mathcal{O}(L_K \ln L_Q) O ( L K ln L Q ) 。在多头机制的视角下,这种注意力为每个头生成不同的稀疏查询-键对,从而避免了严重的信息丢失。
然而,遍历所有查询以进行度量 M ( q i , K ) M(q_i, \mathbf{K}) M ( q i , K ) 需要计算每个点积对,即 O ( L Q L K ) \mathcal{O}(L_Q L_K) O ( L Q L K ) 的平方复杂度,此外 LSE 操作可能会引发潜在的数值稳定性问题。受此启发,我们提出了一种经验近似方法,以有效地获取查询稀疏性度量。
基于引理1(这里跳过),我们提出了最大均值测度:
M ‾ ( q i , K ) = max j { q i k j ⊤ d } − 1 L K ∑ j = 1 L K { q i k j ⊤ d } \overline{M}(q_i, \mathbf{K}) = \max_j \left\{\frac{q_i k_j^\top}{\sqrt{d}}\right\} - \frac{1}{L_K} \sum_{j=1}^{L_K} \left\{\frac{q_i k_j^\top}{\sqrt{d}}\right\} M ( q i , K ) = j max { d q i k j ⊤ } − L K 1 j = 1 ∑ L K { d q i k j ⊤ }
Top-u 的范围大致在命题1 (这里跳过) 的边界放宽中成立。在长尾分布下,我们只需要随机采样 U = L K ln L Q U = L_K \ln L_Q U = L K ln L Q 点积对来计算 M ‾ ( q i , K ) \overline{M}(q_i, \mathbf{K}) M ( q i , K ) ,即把其他对填充为零。然后,我们从中选择稀疏的前 u u u 个作为 Q ‾ \overline{\mathbf{Q}} Q 。M ‾ ( q i , K ) \overline{M}(q_i, \mathbf{K}) M ( q i , K ) 中的最大运算符对零值不太敏感且数值稳定。在实践中,查询和键的输入长度通常在自注意力计算中是相等的,即 L Q = L K = L L_Q = L_K = L L Q = L K = L ,因此总的 ProbSparse 自注意力时间复杂度和空间复杂度均为 O ( L ln L ) \mathcal{O}(L \ln L) O ( L ln L ) 。
Encoder:引入自注意力蒸馏机制(Self-attention Distilling)
编码器:在内存使用限制下,允许处理更长序列的输入。
该编码器旨在提取长序列输入的鲁棒长程依赖性。在输入表示之后,第 t t t 个序列输入 X t \mathbf{X}^t X t 被形成为一个矩阵 X en t ∈ R L x × d model \mathbf{X}_{\text{en}}^t \in \mathbb{R}^{L_x \times d_{\text{model}}} X en t ∈ R L x × d model 。我们在图3中给出编码器的草图以供说明。
自注意力蒸馏 作为 ProbSparse 自注意力机制的自然结果,编码器的特征图存在值 V V V 的冗余组合。我们使用蒸馏操作来突出具有主导特征的优越特征,并在下一层中生成一个集中的自注意力特征图。它锐化裁剪了输入的时间维度,查看图3中注意力块的 n n n -头加权矩阵(重叠的红色方框)。受到稀疏卷积的启发(Yu, Koltun, and Funkhouser 2017; Gupta and Rush 2017),我们的“蒸馏”过程从第 j j j 层转移到第 ( j + 1 ) (j+1) ( j + 1 ) 层为:
X j + 1 t = MaxPool ( ELU ( Conv1d ( [ X j t ] A B ) ) ) \mathbf{X}^{t}_{j+1} = \text{MaxPool}\left(\text{ELU}\left(\text{Conv1d}([\mathbf{X}_j^t]_{AB})\right)\right) X j + 1 t = MaxPool ( ELU ( Conv1d ([ X j t ] A B ) ) )
其中 [ ⋅ ] A B [\cdot]_{AB} [ ⋅ ] A B 表示注意力块。它包含了多头 ProbSparse 自注意力和基本操作,其中 Conv1d(⋅ \cdot ⋅ ) 在时间维度上执行一维卷积滤波(内核宽度=3),并使用 ELU(⋅ \cdot ⋅ ) 激活函数(Clevert, Unterthiner, and Hochreiter 2016)。我们在层堆叠之后添加一个跨度为2的最大池化层,将样本 X t \mathbf{X}^t X t 向下采样到其一半切片,这减少了整体内存使用至 O ( ( 2 − ϵ ) L log L ) \mathcal{O}((2-\epsilon)L\log L) O (( 2 − ϵ ) L log L ) ,其中 ϵ \epsilon ϵ 是一个小数。为了增强蒸馏操作的鲁棒性,我们构建主要堆叠的副本,用半量输入,并通过一次丢弃一层的方式逐步减少自注意力蒸馏层的数量,如图2中的金字塔,使得输出尺寸对齐。因此,我们连接所有堆叠的输出,并获得编码器的最终隐藏表示。
Decoder:一次性生成长序列输出
解码器:通过一次前向过程生成长序列输出
我们在图2中使用了一个标准的解码器结构(Vaswani et al. 2017),它由两个相同的多头注意力层堆叠而成。然而,我们采用生成推理以缓解长预测中的速度下降。我们将以下向量输入到解码器中:
X de t = Concat ( X token t , X 0 t ) ∈ R ( L token + L y ) × d model \mathbf{X}_{\text{de}}^t = \text{Concat}(\mathbf{X}_{\text{token}}^t, \mathbf{X}_0^t) \in \mathbb{R}^{(L_{\text{token}}+L_y)\times d_{\text{model}}} X de t = Concat ( X token t , X 0 t ) ∈ R ( L token + L y ) × d model
其中,X token t ∈ R L token × d model \mathbf{X}_{\text{token}}^t \in \mathbb{R}^{L_{\text{token}} \times d_{\text{model}}} X token t ∈ R L token × d model 是起始标记,X 0 t ∈ R L y × d model \mathbf{X}_0^t \in \mathbb{R}^{L_y \times d_{\text{model}}} X 0 t ∈ R L y × d model 是作为目标序列的占位符(设定为标量0)。在 ProbSparse 自注意力计算中应用了掩码多头注意力,通过将掩码点积设置为 − ∞ -\infty − ∞ 。这防止了每个位置关注即将到来的位置,从而避免了自回归。一个全连接层获取最终输出,其输出大小 d y d_y d y 取决于我们执行的是单变量预测还是多变量预测。
生成推理 起始标记在NLP的“动态解码”(Devlin et al. 2018)中被有效应用,并且我们以生成方式扩展它。我们不是选择特定标记作为标志,而是从输入序列中采样一个长为 L token L_{\text{token}} L token 的序列,比如输出序列之前的一个较早的切片。以预测168个点(实验部分的7天温度预测)为例,我们将目标序列之前已知的5天作为“起始标记”,并将生成式推理解码器输入 X de = { X 5 d , X 0 } \mathbf{X}_{\text{de}} = \{\mathbf{X}_{5d}, \mathbf{X}_0\} X de = { X 5 d , X 0 } 。X 0 \mathbf{X}_0 X 0 包含目标序列的时间戳,即目标周的上下文。然后,我们提出的解码器通过一次前向过程来预测输出,而不是在传统的编码-解码架构中耗时的“动态解码”。详细的性能比较在计算效率部分给出。
损失函数 我们选择均方误差(MSE)损失函数用于目标序列的预测,损失通过整个模型的解码器输出进行反向传播。