简介
位置编码用于注入输入序列中每个标记的位置信息。它使用不同频率的正弦和余弦函数生成位置编码。
PositionalEncoding类添加了有关序列中标记位置的信息。由于Transformer模型缺乏对令牌顺序的固有知识(由于其自注意力机制),此类帮助模型考虑序列中令牌的位置。所选的正弦函数允许模型轻松学习注意相对位置,因为它们为序列中的每个位置产生独特且平滑的编码。
全部代码
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_seq_length):
super(PositionalEncoding, self).__init__()
pe = torch.zeros(max_seq_length, d_model)
position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe.unsqueeze(0))
def forward(self, x):
return x + self.pe[:, :x.size(1)]
类定义/初始化
- d_model: 模型输入的维度。
- max_seq_length: 预先计算位置编码的最大序列长度。
- pe: 用零填充的张量,将填充位置编码。
- position: 包含序列中每个位置的位置索引的张量。
- div_term: 用于以特定方式缩放位置索引的项。
- 正弦函数应用于pe的偶数索引,余弦函数应用于奇数索引。
- 最后,将pe注册为缓冲区,这意味着它将是模块状态的一部分,但不会被视为可训练参数。
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_seq_length):
super(PositionalEncoding, self).__init__()
前向方法
前向方法简单地将位置编码添加到输入x。它使用pe的前x.size(1)个元素,以确保位置编码与x的实际序列长度相匹配。
def forward(self, x):
return x + self.pe[:, :x.size(1)]