transformer实战——Positional Encoding到底是什么样的?

136 阅读5分钟

1位置编码的理论部分

1.1 Positional Encoding的知识点

  1. 位置编码是根据公式一次性计算出来的,它不是训练参数,也不是需要反向传播优化的内容。它就像一本查找表,里面存储了从位置 0 到 max_len - 1 的所有位置的编码。
  2. 位置编码需要参数如下
    • 一个参数 d_model,因为要和Embedding之后得到的词向量相加(不是拼接,就是单纯的对应位置上的数值进行加和),所以需要长度一致,均为d_model。
    • 一个参数max_len,如果模型要处理的句子长度最长是 500 个词。位置编码需要为这 500 个词的每一个位置(从第 0 个词到第 499 个词)都生成一个独特且有规律的编码。max_len 就决定了需要预先计算多少个这样的位置编码。
  3. 需要使用位置嵌入的原因也很简单,因为 Transformer 摈弃了 RNN 的结构,因此需要一个东西来标记各个字之间的时序 or 位置关系,而这个东西,就是位置嵌入,通过注入每个字位置信息的方式,增强了模型的输入(其实说白了就是将位置嵌入和字嵌入相加,然后作为输入)。

1.2 Positional Encoding 的公式

pe公式如下

image.png

  1. 其中pos表示词在句子中的绝对位置(例如,第0个词,第1个词,等等)。
  2. 其中i 指的是位置编码向量内部的维度序号
    • 对于位置编码向量的偶数维度(0, 2, 4, ...):使用 sin 函数。
    • 对于位置编码向量的奇数维度(1, 3, 5, ...):使用 cos 函数。
    • i 实际上是在遍历位置编码向量的维度对(一个偶数维度和一个奇数维度)
    • 即如以下公式 image.png
  3. 为什么需要i
    • i使得每个位置编码向量的不同维度能够捕捉到不同频率的周期信息
    • i 如何影响频率,如下图所示 image.png
    • 这意味着位置编码向量的靠前维度(低索引维度,如第 0、1 维)会捕获到高频率的周期信息。这些维度对位置变化的敏感度更高,适合捕捉细微的位置差异,因为频率高 → 周期短 → 曲线变化剧烈, 当 pos 值发生微小变化时(例如从 pos 变为 pos+1),这些维度的位置编码值会发生明显的、可区分的变化。因此,它们对位置变化“敏感”,擅长捕捉序列中的局部、细微的位置信息
    • 位置编码向量的靠后维度会捕获到低频率的周期信息。这些维度对位置变化的敏感度较低,适合捕捉大的位置跨度。频率低,周期长,曲线变化平缓。它们对 pos 的微小变化不敏感,但能捕捉到更长距离的位置关系,适合表示全局的位置信息。

1.3 Positional Encoding的呈现

下面画一下位置嵌入,纵向观察,可见随着 embedding_dimension序号增大,位置嵌入函数的周期变化越来越平缓

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import math

def get_positional_encoding(max_seq_len, embed_dim):
    # 初始化一个positional encoding
    # embed_dim: 字嵌入的维度
    # max_seq_len: 最大的序列长度
    positional_encoding = np.array([
        [pos / np.power(10000, 2 * i / embed_dim) for i in range(embed_dim)]
        if pos != 0 else np.zeros(embed_dim) for pos in range(max_seq_len)])
    
    positional_encoding[1:, 0::2] = np.sin(positional_encoding[1:, 0::2])  # dim 2i 偶数
    positional_encoding[1:, 1::2] = np.cos(positional_encoding[1:, 1::2])  # dim 2i+1 奇数
    return positional_encoding

positional_encoding = get_positional_encoding(max_seq_len=100, embed_dim=16)
plt.figure(figsize=(10,10))
sns.heatmap(positional_encoding)
plt.title("Sinusoidal Function")
plt.xlabel("hidden dimension")
plt.ylabel("sequence length")
Text(95.72222222222221, 0.5, 'sequence length')

image.png

2 Positional Encoding的实践部分

2.1公式的转化

由于pytorch无法支持axa^x ,但是有torch.exp()计算exe^x ,所以我们把上述公式转换为如下形式

image.png

2.2 具体实施

import torch
max_len = 5000
d_model = 512
# 初始化为全0
pe = torch.zeros(max_len, d_model)
pe
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
# 创建 [0, 1, 2, ..., max_len-1] 的位置索引,但是dtype=torch.float,所以是以浮点数和科学计数法的形式展现,完全可以把它们当作整数来使用。
position = torch.arange(0, max_len, dtype=torch.float)
# 并转成 [max_len, 1] 的形状,表示每一行对应一个位置。
position = position.unsqueeze(1)
position
tensor([[0.0000e+00],
        [1.0000e+00],
        [2.0000e+00],
        ...,
        [4.9970e+03],
        [4.9980e+03],
        [4.9990e+03]])
# arange(0, 6, 2)->[0, 2, 4]  即对应维度 0, 2, 4
# 转公式为以e为底,512维度的对应的分母
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
div_term
tensor([1.0000e+00, 9.6466e-01, 9.3057e-01, 8.9769e-01, 8.6596e-01, 8.3536e-01,
        8.0584e-01, 7.7737e-01, 7.4989e-01, 7.2339e-01, 6.9783e-01, 6.7317e-01,
        6.4938e-01, 6.2643e-01, 6.0430e-01, 5.8294e-01, 5.6234e-01, 5.4247e-01,
        5.2330e-01, 5.0481e-01, 4.8697e-01, 4.6976e-01, 4.5316e-01, 4.3714e-01,
        4.2170e-01, 4.0679e-01, 3.9242e-01, 3.7855e-01, 3.6517e-01, 3.5227e-01,
        3.3982e-01, 3.2781e-01, 3.1623e-01, 3.0505e-01, 2.9427e-01, 2.8387e-01,
        2.7384e-01, 2.6416e-01, 2.5483e-01, 2.4582e-01, 2.3714e-01, 2.2876e-01,
        2.2067e-01, 2.1288e-01, 2.0535e-01, 1.9810e-01, 1.9110e-01, 1.8434e-01,
        1.7783e-01, 1.7154e-01, 1.6548e-01, 1.5963e-01, 1.5399e-01, 1.4855e-01,
        1.4330e-01, 1.3824e-01, 1.3335e-01, 1.2864e-01, 1.2409e-01, 1.1971e-01,
        1.1548e-01, 1.1140e-01, 1.0746e-01, 1.0366e-01, 1.0000e-01, 9.6466e-02,
        9.3057e-02, 8.9769e-02, 8.6596e-02, 8.3536e-02, 8.0584e-02, 7.7736e-02,
        7.4989e-02, 7.2339e-02, 6.9783e-02, 6.7317e-02, 6.4938e-02, 6.2643e-02,
        6.0430e-02, 5.8294e-02, 5.6234e-02, 5.4247e-02, 5.2330e-02, 5.0481e-02,
        4.8697e-02, 4.6976e-02, 4.5316e-02, 4.3714e-02, 4.2170e-02, 4.0679e-02,
        3.9242e-02, 3.7855e-02, 3.6517e-02, 3.5227e-02, 3.3982e-02, 3.2781e-02,
        3.1623e-02, 3.0505e-02, 2.9427e-02, 2.8387e-02, 2.7384e-02, 2.6416e-02,
        2.5483e-02, 2.4582e-02, 2.3714e-02, 2.2876e-02, 2.2067e-02, 2.1288e-02,
        2.0535e-02, 1.9810e-02, 1.9110e-02, 1.8434e-02, 1.7783e-02, 1.7154e-02,
        1.6548e-02, 1.5963e-02, 1.5399e-02, 1.4855e-02, 1.4330e-02, 1.3824e-02,
        1.3335e-02, 1.2864e-02, 1.2409e-02, 1.1971e-02, 1.1548e-02, 1.1140e-02,
        1.0746e-02, 1.0366e-02, 1.0000e-02, 9.6466e-03, 9.3057e-03, 8.9769e-03,
        8.6596e-03, 8.3536e-03, 8.0584e-03, 7.7736e-03, 7.4989e-03, 7.2339e-03,
        6.9783e-03, 6.7317e-03, 6.4938e-03, 6.2643e-03, 6.0430e-03, 5.8294e-03,
        5.6234e-03, 5.4247e-03, 5.2330e-03, 5.0481e-03, 4.8697e-03, 4.6976e-03,
        4.5316e-03, 4.3714e-03, 4.2170e-03, 4.0679e-03, 3.9242e-03, 3.7855e-03,
        3.6517e-03, 3.5227e-03, 3.3982e-03, 3.2781e-03, 3.1623e-03, 3.0505e-03,
        2.9427e-03, 2.8387e-03, 2.7384e-03, 2.6416e-03, 2.5483e-03, 2.4582e-03,
        2.3714e-03, 2.2876e-03, 2.2067e-03, 2.1288e-03, 2.0535e-03, 1.9810e-03,
        1.9110e-03, 1.8434e-03, 1.7783e-03, 1.7154e-03, 1.6548e-03, 1.5963e-03,
        1.5399e-03, 1.4855e-03, 1.4330e-03, 1.3824e-03, 1.3335e-03, 1.2864e-03,
        1.2409e-03, 1.1971e-03, 1.1548e-03, 1.1140e-03, 1.0746e-03, 1.0366e-03,
        1.0000e-03, 9.6466e-04, 9.3057e-04, 8.9769e-04, 8.6596e-04, 8.3536e-04,
        8.0584e-04, 7.7736e-04, 7.4989e-04, 7.2339e-04, 6.9783e-04, 6.7317e-04,
        6.4938e-04, 6.2643e-04, 6.0430e-04, 5.8294e-04, 5.6234e-04, 5.4247e-04,
        5.2330e-04, 5.0481e-04, 4.8697e-04, 4.6976e-04, 4.5316e-04, 4.3714e-04,
        4.2170e-04, 4.0679e-04, 3.9242e-04, 3.7855e-04, 3.6517e-04, 3.5227e-04,
        3.3982e-04, 3.2781e-04, 3.1623e-04, 3.0505e-04, 2.9427e-04, 2.8387e-04,
        2.7384e-04, 2.6416e-04, 2.5483e-04, 2.4582e-04, 2.3714e-04, 2.2876e-04,
        2.2067e-04, 2.1288e-04, 2.0535e-04, 1.9810e-04, 1.9110e-04, 1.8434e-04,
        1.7783e-04, 1.7154e-04, 1.6548e-04, 1.5963e-04, 1.5399e-04, 1.4855e-04,
        1.4330e-04, 1.3824e-04, 1.3335e-04, 1.2864e-04, 1.2409e-04, 1.1971e-04,
        1.1548e-04, 1.1140e-04, 1.0746e-04, 1.0366e-04])
# 对于pe的每行,求从第0列开始,每隔2列的值
pe[:, 0::2] = torch.sin(position * div_term)  
# 对于pe的每行,求从第1列开始,每隔2列的值
pe[:, 1::2] = torch.cos(position * div_term)
# pe 的形状是 [max_len, d_model]
pe
tensor([[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,          0.0000e+00,  1.0000e+00],
        [ 8.4147e-01,  5.4030e-01,  8.2186e-01,  ...,  1.0000e+00,          1.0366e-04,  1.0000e+00],
        [ 9.0930e-01, -4.1615e-01,  9.3641e-01,  ...,  1.0000e+00,          2.0733e-04,  1.0000e+00],
        ...,
        [ 9.5625e-01, -2.9254e-01,  9.3594e-01,  ...,  8.5926e-01,          4.9515e-01,  8.6881e-01],
        [ 2.7050e-01, -9.6272e-01,  8.2251e-01,  ...,  8.5920e-01,          4.9524e-01,  8.6876e-01],
        [-6.6395e-01, -7.4778e-01,  1.4615e-03,  ...,  8.5915e-01,          4.9533e-01,  8.6871e-01]])
# unsqueeze(0) 的作用是在张量的第 0 个维度(最前面)添加一个大小为 1 的新维度。
# 执行这一步后,pe 的形状从 [max_len, d_model] 变为 [1, max_len, d_model]
pe = pe.unsqueeze(0)
# transpose(0, 1) 的作用是交换张量的第 0 个维度和第 1 个维度。
# 操作后,pe 的形状从 [1, max_len, d_model] 变为 [max_len, 1, d_model]
pe = pe.transpose(0, 1)
# 最终的形状 [max_len, 1, d_model] 是为了与 Transformer 模型的典型输入形状 [seq_len, batch_size, d_model] 进行兼容。

参考

  1. Transformer 中的 Positional Encoding
  2. Transformer的PyTorch实现