Transformer位置感知前馈网络

239 阅读1分钟

简介

PositionWiseFeedForward类定义了一个位置感知的前馈神经网络,由两个带有ReLU激活函数的线性层组成。在Transformer模型的上下文中,这个前馈网络分别且相同地应用于每个位置。它有助于转换由注意力机制在Transformer中学习到的特征,作为注意力输出的额外处理步骤。

全部代码

class PositionWiseFeedForward(nn.Module):  
    def __init__(self, d_model, d_ff):  
        super(PositionWiseFeedForward, self).__init__()  
        self.fc1 = nn.Linear(d_model, d_ff)  
        self.fc2 = nn.Linear(d_ff, d_model)  
        self.relu = nn.ReLU()  
  
    def forward(self, x):  
        return self.fc2(self.relu(self.fc1(x)))  

类定义/初始化

  1. d_model: 模型输入和输出的维度。
  2. d_ff: 前馈网络中内层的维度。
  3. self.fc1和self.fc2: 两个全连接(线性)层,输入和输出维度由d_model和d_ff定义。
  4. self.relu: ReLU(修正线性单元)激活函数,在两个线性层之间引入非线性。
class PositionWiseFeedForward(nn.Module):  
    def __init__(self, d_model, d_ff):  

前向方法

  1. x: 前馈网络的输入。
  2. self.fc1(x): 输入首先通过第一个线性层(fc1)。
  3. self.relu(...): fc1的输出然后通过ReLU激活函数。ReLU将所有负值替换为零,为模型引入非线性。
  4. self.fc2(...): 激活的输出然后通过第二个线性层(fc2),产生最终输出。
def forward(self, x):  
    return self.fc2(self.relu(self.fc1(x)))