深度学习点点滴滴— Transformer 中位置编码(上)

1,041 阅读4分钟

携手创作,共同成长!这是我参与「掘金日新计划 · 8 月更文挑战」的第22天,点击查看活动详情

前言

Transformer 架构是由 Vaswani 等人提出的,是一种新颖的基于自注意序机制的序列到序列的架构,支持并行训练,同时具有很强的学习能力,凭借其良好的泛化能力,在 NLP 和 CV 等领域中,已经成为成为首选的基础模型。

大多数框架都提供了对 Transformer 的实现,省去了我们基于 paper 一步一步实现 Transformer 的时间和精力,可以轻轻松松地拿框架提供 transformer 模型去尝试、试验。不过这样一来也让我们原理模型结构的细节,让我们逐渐失去一种能力、这种能力就是我们通过深入研究、剖析有一天写出自己的模型的能力。

为什么需要词序

首先我们先来回答第一个问题,也就是为什么需要词序,词序是一句话能够正确表达其含义不可缺少的部分,不同词序对于一堆词可能表达出意思大相径庭。在 RNN 结构中,本身天生就是支持词序的,因为每一次推理都需要包含上一个隐含状态,在 RNN 中逐次

带着问题去看位置编码

在开始前我们想问两个问题?

  • 为什么需要位置编码
  • 为什么对于词嵌向量可以通过相加形式融入位置编码
  • 那么 transformer 又是如何从从加入位置编码后从中学习到位置编码信息的呢

希望能够在文章中找到以上问题的答案

位置编码

在开始介绍 transformer 论文中的位置编码方式前,我们可以先来自己想一想如果让我们自己去设计一个有效位置编码应该如何设计呢?

为了让模型可识别出每一个词在句子中位置,可以通过添加位置信息到序列中每一个词,这个能够提供位置信息就是位置编码。

通常我们会想到给每个词向量都添加有序数,例如给第一位置上词添加 1,给第 2 词添加 2 依次类推作为位置编码。这样做的问题就是可能会添加一个比较大值,从而破坏原有词向量的含义。而且当遇到句子长度超过训练集中最长的句子,模型就变得束手无措,从而影响了模型的泛化能力。

可能还会想到,让位置编码就都落在了 [0,1] 范围内,但是这样做依然存在问题,就是不同长度句子的位置编码的步长是不一致的。

  • 在一个句子中,每个词对应位置编码是唯一的
  • 两个词之间的距离大小和句子的长度无关
  • 不需要做额外的工作,就可以推广到更长句子
  • 确定性

在论文中,作者提出位置编码可以说是简单有效的方式,不得不为之感叹,感叹作者的精巧设计。位置编码并不是一个简单的数字,而是一个 dmodeld_{model} 的向量,维度和 embedding 向量保持一致。位置编码向量包含位置信息,可以将位置编码向量通过相加方式融合到 embedding 向量,并且这些位置编码是可以被模型所感知识别。

假设一个词在矩阵位置为 tt,那么用 ptRd\vec{p}_t \in \mathbb{R}^d 这个 dd 为向量维度,通常和 embedding 向量维度一致。那么 f:NRdf : \mathbb{N} \rightarrow \mathbb{R}^d

pt=f(t)(i)\vec{p}_t = f(t)^{(i)}
sin(wkt)  if  i=2kcos(wkt)  if  i=2k+1\sin(w_k\cdot t)\;if\; i=2k\\ \cos(w_k\cdot t) \; if\; i = 2k + 1
wk=1100002k/dw_k = \frac{1}{10000^{2k/d}}

以上就是完整的公式

pt=[sin(w1t)cos(w1t)sin(w2t)cos(w2t)sin(wd/2t)cos(wd/2t)]\vec{p}_t = \begin{bmatrix} \sin(w_1\cdot t)\\ \cos(w_1\cdot t)\\ \\ \sin(w_2\cdot t)\\ \cos(w_2\cdot t)\\ \vdots\\ \sin(w_{d/2}\cdot t)\\ \cos(w_{d/2}\cdot t)\\ \end{bmatrix}

可以看出对于位置来说不同维度上频率是不同的,这样上面 f(t)(i)f(t)^{(i)} 输出向量就是每一个位置向量都是对应一个位置编码,然后接下来我们再去看一看是如何实现了相对位置。

位置编码

看到上面公式可能会有这样疑惑,就是正弦和余弦的组合是怎么来表示一个位置的呢?其实很简单,假设用二进制格式表示一个数字,那会是怎样的?

屏幕快照 2022-08-17 上午11.12.37.png

我们都知道在计算机中,用二进制来表示一个数字,也就是将数字表示为 1 和 0 按照一定规则这和,那么可以 sin 和 cos 交替。我们列(同一颜色)来看0 和 1 交替出现频率从红色到橙色在逐渐减少

positional_encoding_001.png

对于浮点数来说,使用二进制值将是一种空间浪费,可以使用正弦函数来表示连续的浮点数,可以将正弦函数理解为交替的比特值。