Transformer位置编码

357 阅读1分钟

简介

位置编码用于注入输入序列中每个标记的位置信息。它使用不同频率的正弦和余弦函数生成位置编码。

PositionalEncoding类添加了有关序列中标记位置的信息。由于Transformer模型缺乏对令牌顺序的固有知识(由于其自注意力机制),此类帮助模型考虑序列中令牌的位置。所选的正弦函数允许模型轻松学习注意相对位置,因为它们为序列中的每个位置产生独特且平滑的编码。

全部代码

class PositionalEncoding(nn.Module):  
    def __init__(self, d_model, max_seq_length):  
        super(PositionalEncoding, self).__init__()  
  
        pe = torch.zeros(max_seq_length, d_model)  
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)  
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))  
  
        pe[:, 0::2] = torch.sin(position * div_term)  
        pe[:, 1::2] = torch.cos(position * div_term)  
  
        self.register_buffer('pe', pe.unsqueeze(0))  
  
    def forward(self, x):  
        return x + self.pe[:, :x.size(1)]  

类定义/初始化

  1. d_model: 模型输入的维度。
  2. max_seq_length: 预先计算位置编码的最大序列长度。
  3. pe: 用零填充的张量,将填充位置编码。
  4. position: 包含序列中每个位置的位置索引的张量。
  5. div_term: 用于以特定方式缩放位置索引的项。
  6. 正弦函数应用于pe的偶数索引,余弦函数应用于奇数索引。
  7. 最后,将pe注册为缓冲区,这意味着它将是模块状态的一部分,但不会被视为可训练参数。
class PositionalEncoding(nn.Module):  
    def __init__(self, d_model, max_seq_length):  
        super(PositionalEncoding, self).__init__()  

前向方法

前向方法简单地将位置编码添加到输入x。它使用pe的前x.size(1)个元素,以确保位置编码与x的实际序列长度相匹配。

def forward(self, x):  
    return x + self.pe[:, :x.size(1)]