[Transformer] Remember to make the positional encoding layer adapt to the input dimensions


I started with someone else's positional encoding module, which looked like this:

import math

import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Initialize the PE (positional encoding) with shape (max_len, d_model)
        pe = torch.zeros(max_len, d_model)
        # Initialize a position tensor [[0], [1], [2], [3], ...] of shape (max_len, 1)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

        # This is the term inside sin and cos, rewritten via exp and log
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        # Compute PE(pos, 2i)
        pe[:, 0::2] = torch.sin(position * div_term)
        # Compute PE(pos, 2i+1)
        pe[:, 1::2] = torch.cos(position * div_term)

        # For convenience, a batch dimension could be unsqueezed at the front
        # pe: [max_len, d_model] -> [batch_size=1, max_len, d_model]
        # pe = pe.unsqueeze(0)
        # But since I don't need that batch_size, no extra dimension is inserted here

        # register_buffer stores a tensor that takes no gradient updates
        # but is still saved together with the model's state_dict
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: [enc_seq_len, enc_features_size]
        # pe: [max_len, d_model]
        # print('self.pe.size(): {}'.format(self.pe.size()))
        # pe = torch.unsqueeze(pe, 1)
        x = x + self.pe[: x.size(0)].requires_grad_(False)
        return self.dropout(x)

The problem is that x = x + self.pe[: x.size(0)].requires_grad_(False) does not adapt to the dimensions of x, so the layer only works when x has shape [enc_seq_len, enc_features_size].
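To make the failure mode concrete, here is a minimal sketch (the sizes are made up, and I assume the PositionalEncoding class above is in scope): a 2D input broadcasts as intended, but a 3D [enc_seq_len, batch_size, d_model] input silently broadcasts into the wrong shape.

import torch

d_model, seq_len, batch_size = 512, 10, 1
pos_enc = PositionalEncoding(d_model)

# 2D input [enc_seq_len, d_model]: pe[:seq_len] is [seq_len, d_model], so the add is fine
x_2d = torch.zeros(seq_len, d_model)
print(pos_enc(x_2d).shape)   # torch.Size([10, 512])

# 3D input [enc_seq_len, batch_size=1, d_model]: pe[:seq_len] is still [seq_len, d_model],
# which broadcasts against [seq_len, 1, d_model] into [seq_len, seq_len, d_model] -- silently wrong
x_3d = torch.zeros(seq_len, batch_size, d_model)
print(pos_enc(x_3d).shape)   # torch.Size([10, 10, 512]) instead of torch.Size([10, 1, 512])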

So I went to see how other people write it:

towardsdatascience.com/how-to-make…

Their positional encoding layer adapts very well:

import math

import torch
from torch import nn, Tensor

class PositionalEncoder(nn.Module):
    """
    The authors of the original transformer paper describe very succinctly what 
    the positional encoding layer does and why it is needed:
    
    "Since our model contains no recurrence and no convolution, in order for the 
    model to make use of the order of the sequence, we must inject some 
    information about the relative or absolute position of the tokens in the 
    sequence." (Vaswani et al, 2017)
    Adapted from: 
    https://pytorch.org/tutorials/beginner/transformer_tutorial.html
    """

    def __init__(
        self, 
        dropout: float=0.1, 
        max_seq_len: int=5000, 
        d_model: int=512,
        batch_first: bool=False
        ):

        """
        Parameters:
            dropout: the dropout rate
            max_seq_len: the maximum length of the input sequences
            d_model: The dimension of the output of sub-layers in the model 
                     (Vaswani et al, 2017)
        """

        super().__init__()

        self.d_model = d_model
        
        self.dropout = nn.Dropout(p=dropout)

        self.batch_first = batch_first

        # adapted from PyTorch tutorial
        position = torch.arange(max_seq_len).unsqueeze(1)
        
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        
        if self.batch_first:
            pe = torch.zeros(1, max_seq_len, d_model)
            
            pe[0, :, 0::2] = torch.sin(position * div_term)
            
            pe[0, :, 1::2] = torch.cos(position * div_term)
        else:
            pe = torch.zeros(max_seq_len, 1, d_model)
        
            pe[:, 0, 0::2] = torch.sin(position * div_term)
        
            pe[:, 0, 1::2] = torch.cos(position * div_term)
        
        self.register_buffer('pe', pe)
        
    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [batch_size, enc_seq_len, dim_val] or 
               [enc_seq_len, batch_size, dim_val]
        """
        if self.batch_first:
            x = x + self.pe[:,:x.size(1)]
        else:
            x = x + self.pe[:x.size(0)]

        return self.dropout(x)
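A quick sanity check (the sizes below are just illustrative, and I assume the PositionalEncoder class above is in scope) confirms that it broadcasts correctly in both layouts:

import torch

enc_batch_first = PositionalEncoder(d_model=512, batch_first=True)
enc_seq_first = PositionalEncoder(d_model=512, batch_first=False)

x_bf = torch.zeros(4, 10, 512)   # [batch_size, enc_seq_len, d_model]
x_sf = torch.zeros(10, 4, 512)   # [enc_seq_len, batch_size, d_model]

print(enc_batch_first(x_bf).shape)   # torch.Size([4, 10, 512])
print(enc_seq_first(x_sf).shape)     # torch.Size([10, 4, 512])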

Usage:

from torch import nn
import positional_encoder as pe

class TransformerTS(nn.Module):
    def __init__(self,
                 enc_features_size,
                 dec_features_size,
                 d_model=512,
                 nhead=8,
                 num_encoder_layers=6,
                 num_decoder_layers=6,
                 dim_feedforward=2048,
                 dropout=0.1,
                 activation='relu',
                 custom_encoder=None,
                 custom_decoder=None,
                 batch_first=False):
        super(TransformerTS, self).__init__()

        self.transform = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            custom_encoder=custom_encoder,
            custom_decoder=custom_decoder,
            batch_first=batch_first
        )
        self.positional_encoding_layer = pe.PositionalEncoder(
            d_model=d_model,
            dropout=dropout,
            batch_first=batch_first
        )
        self.enc_input_fc = nn.Linear(enc_features_size, d_model)
        self.dec_input_fc = nn.Linear(dec_features_size, d_model)
        self.out_fc = nn.Linear(d_model, dec_features_size)

This just shows how to create the positional encoding layer.

As for the forward pass, that varies from person to person; I simply feed the input through a linear layer on the way in and another linear layer on the way out.

The positional_encoding_layer then sits right after the input embedding:

def forward(self, enc_input, dec_input, src_mask, tgt_mask):

    # print('enc_input.size(): {}'.format(enc_input.size()))

    # embed_encoder_input: [enc_seq_len, 1, enc_features_size] -> [enc_seq_len, 1, d_model]
    embed_encoder_input = self.enc_input_fc(enc_input)

    # print('embed_encoder_input.size(): {}'.format(embed_encoder_input.size()))

    embed_encoder_input = self.positional_encoding_layer(embed_encoder_input)

    # print('embed_encoder_input.size(): {}'.format(embed_encoder_input.size()))

    # embed_decoder_input: [dec_seq_len, 1, dec_features_size] -> [dec_seq_len, 1, d_model]
    embed_decoder_input = self.dec_input_fc(dec_input)

    # x: [dec_seq_len, 1, d_model]
    x = self.transform(src=embed_encoder_input,
                       tgt=embed_decoder_input,
                       src_mask=src_mask,
                       tgt_mask=tgt_mask)

    # x: [dec_seq_len, 1, d_model] -> [dec_seq_len, 1, dec_features_size]
    x = self.out_fc(x)

    return x
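For completeness, here is a rough sketch of how I call the model end to end (the feature sizes, the sequence lengths and the causal-mask construction are my own assumptions, and the forward above is assumed to be a method of TransformerTS):

import torch

model = TransformerTS(enc_features_size=8, dec_features_size=1)

enc_seq_len, dec_seq_len = 48, 12
enc_input = torch.randn(enc_seq_len, 1, 8)   # [enc_seq_len, batch_size=1, enc_features_size]
dec_input = torch.randn(dec_seq_len, 1, 1)   # [dec_seq_len, batch_size=1, dec_features_size]

# causal mask: -inf above the diagonal so each decoder position only attends to earlier positions
tgt_mask = torch.triu(torch.full((dec_seq_len, dec_seq_len), float('-inf')), diagonal=1)

out = model(enc_input, dec_input, src_mask=None, tgt_mask=tgt_mask)
print(out.shape)   # torch.Size([12, 1, 1])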