Transformer in Practice: What Exactly Is the Model's Input?


1 What the input data looks like after input embedding

1.1 What the Transformer is trained to do, and the initial input sentences

Suppose the Transformer we are training translates German into English. The training data consists of two German sentences:

  1. ich mochte ein bier P
  2. ich mochte ein cola P

P is a padding token. Why padding is added will be discussed in more detail later; the short version is that every sentence in a batch must have the same length so the batch can be stored in a single tensor (see the quick illustration below). After that, we build the inputs step by step.
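As a quick aside (my own illustration, not part of the original walkthrough): PyTorch cannot build a single tensor from rows of different lengths, which is exactly the problem padding solves.

import torch
# A ragged nested list cannot be turned into one tensor:
# torch.LongTensor([[1, 2, 3, 4], [1, 2, 3]])   # raises ValueError: rows have different lengths
# After padding the shorter row with P (index 0), the batch stacks cleanly:
padded = torch.LongTensor([[1, 2, 3, 4], [1, 2, 3, 0]])
print(padded.shape)  # torch.Size([2, 4])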

1.2 Processing the sentences by hand

import torch
import torch.nn as nn
import torch.utils.data as Data

1.2.1 Step 1: define the sentences

As with traditional neural networks, the data is split into a training set and a test set. The training set contains both the input data and the ground-truth targets. The targets are mainly used for backpropagation: they are compared with the model's predictions, and the model's parameters are adjusted step by step so the model improves. The test set contains only input data.

# enc_input: the input fed to the encoder during training
# dec_input: the input fed to the decoder during training
# dec_output: the decoder's target output during training, i.e. the ground truth used to compute the loss
sentences = [
        # enc_input           dec_input         dec_output
        ['ich mochte ein bier P', 'S i want a beer .', 'i want a beer . E'],
        ['ich mochte ein cola P', 'S i want a coke .', 'i want a coke . E']
]
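A small sanity check on this layout (my addition, not part of the original code): dec_input is simply dec_output shifted one position to the right, with the start token S prepended and the end token E dropped. That is the usual teacher-forcing setup for training the decoder.

# dec_input == ['S'] + dec_output without its final 'E'
for enc, dec_in, dec_out in sentences:
    assert dec_in.split() == ['S'] + dec_out.split()[:-1]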

1.2.2 Step 2: build the vocabularies

# The vocabularies are built by hand here; the training data is so small that we simply type them in
# Build the German (source) vocabulary
# Padding Should be Zero
src_vocab = {'P' : 0, 'ich' : 1, 'mochte' : 2, 'ein' : 3, 'bier' : 4, 'cola' : 5}
src_vocab_size = len(src_vocab)  # size of the German vocabulary

# Build the English (target) vocabulary
tgt_vocab = {'P' : 0, 'i' : 1, 'want' : 2, 'a' : 3, 'beer' : 4, 'coke' : 5, 'S' : 6, 'E' : 7, '.' : 8}
tgt_vocab_size = len(tgt_vocab)  # size of the English vocabulary
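Why does the comment say "Padding Should be Zero"? Keeping P at index 0 makes it easy to locate padded positions later, for example when building the attention padding mask. A minimal sketch of the idea (my own, using an index sequence of the kind built in step 3):

# Positions equal to 0 are padding tokens
example = torch.LongTensor([[1, 2, 3, 4, 0]])  # 'ich mochte ein bier P'
print(example.eq(0))  # tensor([[False, False, False, False,  True]]) -- True marks padding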

1.2.3 Step 3: convert the sentences to index sequences

# Converts each sentence into a sequence of numbers (token indices)
def make_data(sentences):
    enc_inputs, dec_inputs, dec_outputs = [], [], []
    for i in range(len(sentences)):
      enc_input = [[src_vocab[n] for n in sentences[i][0].split()]] # [[1, 2, 3, 4, 0], [1, 2, 3, 5, 0]]
      dec_input = [[tgt_vocab[n] for n in sentences[i][1].split()]] # [[6, 1, 2, 3, 4, 8], [6, 1, 2, 3, 5, 8]]
      dec_output = [[tgt_vocab[n] for n in sentences[i][2].split()]] # [[1, 2, 3, 4, 8, 7], [1, 2, 3, 5, 8, 7]]

      enc_inputs.extend(enc_input)
      dec_inputs.extend(dec_input)
      dec_outputs.extend(dec_output)

    return torch.LongTensor(enc_inputs), torch.LongTensor(dec_inputs), torch.LongTensor(dec_outputs)
enc_inputs, dec_inputs, dec_outputs = make_data(sentences)

enc_inputs
tensor([[1, 2, 3, 4, 0],
        [1, 2, 3, 5, 0]])
dec_inputs
tensor([[6, 1, 2, 3, 4, 8],
        [6, 1, 2, 3, 5, 8]])
dec_outputs
tensor([[1, 2, 3, 4, 8, 7],
        [1, 2, 3, 5, 8, 7]])
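As a quick check that this conversion is reversible, we can map the first row of dec_outputs back to tokens. The inverse dictionary idx2word below is a helper of my own, not part of the original code:

# Inverse mapping for the target vocabulary: index -> word
idx2word = {idx: word for word, idx in tgt_vocab.items()}
print([idx2word[i] for i in dec_outputs[0].tolist()])  # ['i', 'want', 'a', 'beer', '.', 'E']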

1.2.4 Step 4: batch the data

# DataLoader(...) is a utility provided by PyTorch that automatically splits data into batches
# Here it packs the data into batches of size 2, which makes later processing easier
class MyDataSet(Data.Dataset):
  def __init__(self, enc_inputs, dec_inputs, dec_outputs):
    super(MyDataSet, self).__init__()
    self.enc_inputs = enc_inputs
    self.dec_inputs = dec_inputs
    self.dec_outputs = dec_outputs
  
  def __len__(self):
    return self.enc_inputs.shape[0]
  
  def __getitem__(self, idx):
    return self.enc_inputs[idx], self.dec_inputs[idx], self.dec_outputs[idx]
# 2 means each batch has size 2; True means the data is shuffled
loader = Data.DataLoader(MyDataSet(enc_inputs, dec_inputs, dec_outputs), 2, True)
# Since there are only two training samples and the batch size is 2, there is exactly one batch: both sentences are packed into a single batch.
for batch in loader:
    enc_in, dec_in, dec_out = batch
    print(enc_in.shape)
    print(dec_in.shape)
    print(dec_out.shape)
torch.Size([2, 5])
torch.Size([2, 6])
torch.Size([2, 6])
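For contrast (a sketch of my own, not in the original): with batch_size = 1 the same two sentences would be split into two batches of one sentence each.

# batch_size = 1 -> two batches, each holding a single sentence
loader_1 = Data.DataLoader(MyDataSet(enc_inputs, dec_inputs, dec_outputs), 1, True)
for enc_in_1, dec_in_1, dec_out_1 in loader_1:
    print(enc_in_1.shape)  # torch.Size([1, 5]), printed twice (once per batch)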

1.2.5 Step 5: turn token indices into word vectors

# Step 5: turn the data in each batch into word vectors, i.e. feed the "token indices" into the embedding layer to obtain "word vector" representations
d_model = 512  # length of the vector each word is mapped to; d_model=512 turns every word into a 512-dimensional vector
src_emb = nn.Embedding(src_vocab_size, d_model)  # embedding layer (lookup table)

# Inspect the embedding table (the vectors of all words)
print(src_emb.weight.shape)  # output: torch.Size([6, 512])
# Inspect the vector of one specific word (e.g. index = 3)
print(src_emb.weight[3])  # output: tensor([...512 values...])

# Internally, src_emb maintains a word-vector matrix of size [src_vocab_size, d_model]
# Passing, say, the indices [1, 3, 5] (all within this 6-word vocabulary) to src_emb:
#   index 1 -> looks up src_emb.weight[1] -> a 512-dim vector
#   index 3 -> looks up src_emb.weight[3] -> a 512-dim vector
#   index 5 -> looks up src_emb.weight[5] -> a 512-dim vector
# The results are stacked into one output tensor: embedded.shape = [1, 3, 512],
# i.e. 1 sentence (batch), 3 tokens, each represented by a 512-dimensional vector
torch.Size([6, 512])
tensor([-1.1477, -0.3596,  0.1453, -0.6287, -1.0734, -0.3935, -1.7570,  0.9578,
         0.0962, -0.8109, -1.0551, -0.1319,  0.4090, -0.0348, -0.0409, -0.8494,
        -0.9570,  0.1087, -0.8971, -1.7647,  0.0482,  0.5101, -0.6541,  0.3953,
         0.4137, -0.1816, -2.5110, -0.1869,  0.2335,  1.2381, -0.6446,  0.1736,
        -0.7933,  0.0947, -1.4920, -1.0351, -0.1584,  0.5599, -0.4506,  0.4465,
         0.1421, -0.4126, -1.1446,  0.6944,  0.3732, -0.7835,  1.3148,  0.9072,
        -0.5693,  1.3499, -0.2509,  1.0632, -0.4607, -0.4230,  1.3969, -0.8768,
        -0.2439,  0.2778, -0.8335, -0.8099,  0.2877, -1.0764,  1.3982,  0.2309,
        -0.1010, -0.5357, -0.9890,  0.0331, -0.3475,  1.5302, -0.0182, -0.1234,
         1.0533, -0.6214, -2.2774, -0.3835, -1.0880, -0.1183, -0.5182,  0.4344,
        -0.0451, -0.4558,  1.5517,  0.6882,  0.2474, -0.4487, -0.7173, -1.6592,
        -0.0667,  0.6661, -0.2130, -0.0576, -0.5461, -0.0718,  0.5827, -0.3371,
        -0.3426, -0.1747, -0.0688,  1.2108,  0.4492, -1.7266,  0.6804,  0.6227,
         1.3680, -0.0567, -0.8870, -0.6696,  0.1879,  0.1069,  0.3404, -0.0251,
         0.0833, -0.3172, -0.8319, -0.3646,  0.5962, -0.3971,  0.7568,  0.5467,
        -0.2688, -1.2597, -0.4271, -1.2136,  0.5975, -1.4806, -0.3089, -1.2802,
         1.2744, -1.0499,  0.9727, -1.4615,  0.3897,  0.3241, -0.0627,  0.5804,
         0.3460, -0.8262, -1.2245, -0.3822, -0.2669, -0.6048,  1.2361, -0.2094,
        -1.8484, -0.1826, -1.0867,  1.1297, -0.0669, -0.6465, -0.6890,  1.1864,
        -0.6725, -1.1461,  0.4714,  1.8949,  0.4749, -0.4418, -0.2213,  0.9531,
         1.5374,  1.2361, -0.2521,  0.2188, -2.2523,  0.3806,  0.4727,  0.1025,
        -0.5193,  0.5104,  1.4145,  0.6588, -2.7046, -1.4447,  1.7596,  0.5135,
        -1.3575, -0.7540, -0.1853,  0.6450,  0.3702, -1.1559, -0.6695, -1.0101,
         0.3161,  2.3780,  0.3959, -1.1254, -1.2363, -0.4276, -0.2829, -0.3608,
         0.5598,  0.7834,  0.2657,  0.6680, -0.3158,  0.2259, -1.5649, -0.0994,
        -0.1219,  0.7391,  1.6350, -0.3758, -1.1995, -0.6225, -0.5865,  0.9992,
         0.2492,  0.7824,  1.4608, -0.5715,  1.3410,  0.9098,  0.1645, -0.4211,
        -1.1391, -0.3085, -0.5740,  0.0536,  0.5054,  1.2486,  0.1797,  0.1772,
        -0.0077,  0.8421, -0.2567, -0.7981, -1.1647,  0.7046,  0.6566,  1.1096,
        -1.5317, -1.4469,  0.4324, -1.2202,  0.5087, -0.9350,  1.0088, -0.6522,
        -0.6852, -0.0676, -0.4069,  0.7516,  1.3314, -0.0385, -0.7060, -0.1177,
        -0.4652,  0.9235,  1.4318, -0.7186, -0.9823,  0.0983,  0.8116,  2.0150,
         0.7832, -0.7332,  0.2785,  0.8650,  0.0659,  1.9205,  0.6896, -1.7097,
        -0.5481, -0.8304, -0.8532, -0.9558, -1.2862,  0.1529,  0.1802, -0.3222,
         0.6690,  0.6723, -1.6274, -0.2916,  0.8425, -0.2246,  0.1587, -0.9713,
         2.2354,  0.8762,  0.4638,  0.5241,  0.4882,  1.5547, -0.2164, -1.5353,
        -0.2817, -0.4137,  0.2668, -1.8575, -0.7586,  0.3339, -1.1832, -0.7198,
        -0.7718, -1.1368, -1.4015,  0.7338, -0.2357,  0.5403,  0.2786,  0.4765,
         0.3082,  0.3069, -1.9527, -1.9057, -0.2212,  0.8627,  0.4699, -0.6659,
        -0.2734, -2.7393,  0.3431,  0.8503, -0.6571,  0.7144, -0.0308, -0.9923,
         0.0575, -1.1740, -0.6545, -1.7060,  0.7489, -0.3493, -1.3976,  0.6173,
         0.0171,  1.4756, -0.3271, -0.6145, -0.2688,  1.7774,  0.2857,  0.6992,
        -0.2516, -2.8353, -0.7691,  0.0602, -0.9851,  0.2797,  0.2746, -0.5686,
        -0.7922,  1.5187,  0.8221,  1.0336, -0.5249,  0.8031,  1.1510, -0.0794,
        -0.9719, -2.6015, -0.1009,  0.8683, -0.8088,  0.1053,  0.2380,  1.4279,
         0.9191,  0.4270, -0.6779, -0.8884, -0.4285,  0.4932,  0.9706, -0.0214,
        -0.1817, -0.0908,  0.8925, -2.6058, -0.2539, -0.4734,  1.9980, -2.4311,
        -0.0257, -0.9688,  0.2406, -0.5834,  1.2935, -0.7192, -0.7562,  2.0201,
         0.4671, -1.2815,  1.0332, -1.4040,  1.8302, -0.9846, -2.7982, -1.2169,
         1.1619,  0.0261, -0.4865, -0.2836,  0.6362,  2.0568, -1.0484, -0.4122,
         1.6375,  0.4231,  0.1885,  1.3461, -0.1934, -0.6927,  1.2993,  0.2502,
        -0.4953,  1.6585, -1.1208,  1.1018, -1.4452,  1.1016,  0.0979, -0.9619,
         0.1588,  0.3976,  0.5759, -0.1950, -1.0084,  0.0685, -0.8129, -0.4497,
        -1.1718,  0.3306,  0.4870, -1.2009,  0.3477, -0.2078, -0.8105,  0.3183,
        -0.8205,  1.2886,  1.2331,  0.2287,  0.5842,  0.2788, -0.5188, -2.8053,
        -0.6526,  0.3252, -0.6878, -0.3900,  0.8954,  0.6639,  0.9548,  0.3070,
        -1.2617, -0.3881,  0.5237,  0.3399,  0.1282, -1.0472, -0.0615,  1.5472,
         0.8922, -0.6291,  0.2759, -0.2442, -0.0764,  1.5333, -1.3812,  0.3255,
         1.3294,  0.5463,  1.0385,  0.6635,  1.2749,  0.2723,  0.3839, -0.7755,
        -1.1042,  1.3063, -1.5888, -0.3427, -1.1477,  0.9240, -1.2547, -1.3169,
         1.0847,  0.6008,  0.2272,  0.0049,  0.9528,  2.3677, -1.1434, -1.1723,
        -0.0300, -0.8270, -0.5220, -0.9690,  0.9291, -1.1758, -0.5836, -0.9810,
        -1.2323,  0.2687, -0.0971,  0.8811,  1.2048,  1.2465,  2.9661, -0.8329,
        -0.2704, -0.1134, -1.1699,  0.2722,  1.2309,  0.0519, -1.4415,  0.8898],
       grad_fn=<SelectBackward0>)
# In short: each word ID -> table lookup -> its corresponding vector -> a sequence of word vectors
# Since enc_in has shape torch.Size([2, 5]), the output has shape torch.Size([2, 5, 512])
enc_in = src_emb(enc_in)
print(enc_in.shape)
torch.Size([2, 5, 512])
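One more small check (my own addition): the embedding layer really is just a table lookup; feeding an index through src_emb returns exactly the corresponding row of src_emb.weight.

# Forward pass of nn.Embedding == row lookup in its weight matrix
idx = torch.LongTensor([3])
print(torch.equal(src_emb(idx)[0], src_emb.weight[3]))  # True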

1.2.6 What [2, 5, 512] means

  1. 2: the batch size, i.e. the number of independent sequences (sentences) processed together; here we have two sentences.

  2. 5: the length of each sequence, i.e. the number of words/tokens; each sentence has five tokens.

  3. 512: the embedding dimension (feature size) of each word; every word is represented by a 512-dimensional vector. (A short check of these dimensions follows this list.)
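A quick check of these three dimensions on the enc_in tensor from step 5 (a sketch of my own):

print(enc_in.shape)        # torch.Size([2, 5, 512]) -- 2 sentences, 5 tokens each, 512-dim vectors
print(enc_in[0].shape)     # torch.Size([5, 512])    -- all token vectors of the first sentence
print(enc_in[0, 2].shape)  # torch.Size([512])       -- the vector of the 3rd token ('ein')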

This is the form our original input sentences take after passing through the Input Embedding layer.

[Figure: annotated output after input embedding]

1.3 References

Transformer的PyTorch实现 (a PyTorch implementation of the Transformer)