transformer各层向量维度变化及注意力机制代码实现

372 阅读1分钟

transformer的维度变化主要发生在embedding层、multi-head attention层和前馈网络(FFN)层。
假设在进入embedding层前,张量的输入维度是XRbatchsizeseqlenX\in R^{batchsize*seqlen}。其中,batch_size是批量大小,即一次训练的样本数;seq_len是序列长度(统一长度,不够的padding,多的进行截断)

embedding layer

embedding层的维度是ERVdmodelE\in R^{V*d_{model}},其中,V是词汇表大小,dmodeld_{model}是词嵌入维度。通过索引查找,输出可以变为XRbatchsizeseqlendmodelX\in R^{batchsize*seqlen*d_{model}}。这里贴上ChatGPT关于索引查找的矩阵乘法形式表示:

image.png

image.png

multi-head attention layer(重要)

输入的张量维度是XRbatchsizeseqlendmodelX\in R^{batchsize*seqlen*d_{model}}
将其线性变换为Q、K、V之后,张量维度不变。
分割成h个头后,张量变为:XRbatchsizeseqlenhdkX\in R^{batchsize*seqlen*h*d_k} ,其中dk=dmodel/hd_k=d_{model}/h
计算注意力:XRbatchsizeseqlendkhX\in R^{batchsize*seqlen*d_k*h}
头拼接:XRbatchsizeseqlendmodelX\in R^{batchsize*seqlen*d_{model}}

前馈网络层

XRbatchsizeseqlendffX\in R^{batchsize*seqlen*d_{ff}}
非线性激活后再映射回:XRbatchsizeseqlendmodelX\in R^{batchsize*seqlen*d_{model}}

注意力机制代码实现

def self_attention(query,key,value,dropout=None,mask=None):
    d_k=query.size(-1)
    scores=torch.matmul(query,key.transpose(-2,-1))/math.sqrt(d_k)
    if mask is not None:
        mask.cuda()
        scores=scores.masked_fill(mask==0,-1e9)
    self_attn=F.softmax(scores,dim=-1)
    if dropout is not None:
        self_attn=dropout(self_attn)
    return torch.matmul(self_attn,value),self_attn

class MultiHeadAttention(nn.Module):
    def __init__(self,head,d_model,dropout=0.1,mask=None):
        super(MultiHeadAttention,self).__init__()
        assert d_model % head == 0 
        self.head=head
        self.d_model=d_model
        self.d_k=d_model // head
        self.linear_query=nn.Linear(d_model,d_model)
        self.linear_key=nn.Linear(d_model,d_model)
        self.linear_value=nn.Linear(d_model,d_model)
        self.linear_out=nn.Linear(d_model,d_model)
        self.attn=None
        self.dropout=nn.Dropout(dropout)
    def forward(self,query,key,value,mask=None):
        if mask is not None:
            mask=mask.unsqueeze(1)
        batch=query.size(0)
        query=self.linear_query(query).view(batch,-1,self.head,self.d_k).transpose(1,2)
        key=self.linear_key(key).view(batch,-1,self.head,self.d_k).transpose(1,2)
        value=self.linear_value(value).view(batch,-1,self.head,self.d_k).transpose(1,2)
        x,self.attn=self_attention(query,key,value,dropout=self.dropout,mask=None)
        x=x.transpose(1,2).contiguous().view(batch,-1,self.head,self.d_k)
        return self.linear_out(x)