Transformer: Principles and Code Practice



References: jalammar.github.io/illustrated… blog.csdn.net/qq_28168421…

1. The Macro Structure of the Transformer

The Transformer is a typical seq2seq architecture, commonly used in sequence-to-sequence applications. Taking machine translation as an example, the top-level view looks like this:

Zooming in a little, the input flows through the encoder, then through the decoder, and finally produces the output.

The encoder consists of a stack of 6 EncoderLayers, and the decoder likewise consists of 6 DecoderLayers; the output of the last EncoderLayer feeds into every DecoderLayer.

Every EncoderLayer has the same structure: a self-attention layer followed by a FeedForward layer (fc for short).

Self-attention lets the encoding of each word see and incorporate information from every other word in the sentence. For an explanation of how self-attention works, see blog.csdn.net/THUChina/ar…

A DecoderLayer also has a self-attention layer and an fc layer like the EncoderLayer, but between them sits an additional encoder-decoder attention layer, which attends to the relevant parts of the original input sentence.

2. Data Flow

2.1. The Encoder

For a sequence (say, a sentence), each element (each word) is first embedded into an embedding vector (the embedding method is covered later). In the Transformer, each word's embedding vector is 512-dimensional by default, i.e. every word is represented by a 512-dimensional vector.
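As a minimal sketch of this embedding step (the vocabulary size of 10000 and the token ids are made-up assumptions, purely for illustration):

```python
import torch
import torch.nn as nn

d_model = 512         # embedding dimension used by the Transformer
vocab_size = 10000    # hypothetical vocabulary size, for illustration only

embed = nn.Embedding(vocab_size, d_model)
tokens = torch.tensor([[3, 14, 159]])    # a batch with one 3-word sentence
x = embed(tokens)                        # (batch, seq_len, d_model)
print(x.shape)                           # torch.Size([1, 3, 512])
```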

The input to self-attention is therefore a list of word embedding vectors (denoted x). The input length can be capped at a maximum, usually the length of the longest sentence in the training set. The output of self-attention is also a list, containing each word's representation vector after that self-attention layer (denoted z). Each word's z vector then passes through a fully connected layer.

In addition, in each encoder layer the input and output of both the self-attention layer and the fc layer are connected by a residual connection, and a LayerNorm is applied to the output. That is, each sub-layer's output can be written as LayerNorm(x + SubLayer(x)), where x is the sub-layer's input and SubLayer is the sub-layer's underlying function (self-attention or fc).

To keep the residual connections simple, every encoder layer keeps its input and output at the same dimensionality (512), so the residual connection never needs to project x up or down.
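The LayerNorm(x + SubLayer(x)) pattern can be sketched as follows, using a plain Linear layer as a stand-in for the real sub-layer (the stand-in and the batch/sequence sizes are assumptions for illustration):

```python
import torch
import torch.nn as nn

d_model = 512
sublayer = nn.Linear(d_model, d_model)   # stand-in for self-attention or fc
norm = nn.LayerNorm(d_model)

x = torch.randn(2, 7, d_model)           # (batch, seq_len, d_model)
out = norm(x + sublayer(x))              # LayerNorm(x + SubLayer(x))
print(out.shape)                         # same shape as x: no up/down projection needed
```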

2.2. The Decoder

Note that on the encoder side, no matter how many elements the input sequence has, the forward pass runs only once, in parallel; this is an advantage of self-attention over RNN/LSTM. On the decoder side, inference must run serially in a loop, generating one new output element (word) at a time.
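The serial decoding loop can be sketched like this; `decoder_step` here is a hypothetical placeholder for one full decoder forward pass, not code from this post:

```python
# greedy decoding sketch: one decoder forward pass per generated token,
# each time re-feeding everything produced so far
def greedy_decode(decoder_step, start_id, end_id, max_len=10):
    ys = [start_id]
    for _ in range(max_len):
        next_id = decoder_step(ys)   # serial: depends on all previous outputs
        ys.append(next_id)
        if next_id == end_id:
            break
    return ys

# toy decoder_step that just counts up, to show the loop structure
print(greedy_decode(lambda ys: ys[-1] + 1, start_id=0, end_id=3))  # [0, 1, 2, 3]
```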

Note that the computation in the decoder's first self-attention layer is the same as in the encoder's self-attention (query, key, and value are all computed from the previous layer's output; see blog.csdn.net/THUChina/ar… for details), except that a mask hides future positions. In the second attention layer, the keys and values are computed from the encoder's output, while the queries come from the previous sub-layer's output. Intuitively, this second layer uses what the decoder has already predicted as queries to look up relevant information among the features extracted by the encoder, and fuses it into the current features to make the next prediction.
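A single-head sketch of this encoder-decoder attention (the dimensions are arbitrary assumptions): keys and values come from the encoder output, queries from the decoder side:

```python
import math
import torch

d_k = 64                             # per-head dimension (illustrative)
enc_out = torch.randn(1, 10, d_k)    # encoder output -> keys and values
dec_x = torch.randn(1, 3, d_k)       # decoder sub-layer output -> queries

scores = dec_x @ enc_out.transpose(-2, -1) / math.sqrt(d_k)  # (1, 3, 10)
weights = scores.softmax(dim=-1)     # each decoder position attends over all 10 source positions
context = weights @ enc_out          # (1, 3, 64): source information fused per decoder position
```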

3. Code Implementation (work in progress)

3.0. Utility Classes and Functions

Copy a module N times; note that this must be a deep copy:

# imports used throughout the code in this post
import copy
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

def clones(module, N):
    "Produce N identical layers."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
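A quick check that the copies are independent (repeating the helper so the snippet runs on its own):

```python
import copy
import torch.nn as nn

def clones(module, N):
    "Produce N identical layers."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

layers = clones(nn.Linear(512, 512), 6)
print(len(layers))                                                 # 6
# deep copies: each layer owns its own parameter storage
print(layers[0].weight.data_ptr() != layers[1].weight.data_ptr())  # True
```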

The LayerNorm computation:

class LayerNorm(nn.Module):
    "Construct a layernorm module (See citation for details)."
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
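Applied to a toy input (the shapes are arbitrary), the mean along the feature dimension goes to zero; the class is repeated here so the snippet runs on its own:

```python
import torch
import torch.nn as nn

class LayerNorm(nn.Module):
    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2

x = torch.randn(2, 4, 8) * 5 + 3      # shifted, scaled toy input
y = LayerNorm(8)(x)
print(y.mean(-1).abs().max())         # ~0: normalized along the feature dim
```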

A mask that ensures the decoder's computation depends only on past positions:

def subsequent_mask(size):
    "Mask out subsequent positions."
    attn_shape = (1, size, size)
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(subsequent_mask) == 0
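For size 3 the mask looks like this (the function is repeated so the snippet runs on its own): position i may only attend to positions j ≤ i:

```python
import numpy as np
import torch

def subsequent_mask(size):
    "Mask out subsequent positions."
    attn_shape = (1, size, size)
    mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(mask) == 0

m = subsequent_mask(3)
print(m[0].long())
# tensor([[1, 0, 0],
#         [1, 1, 0],
#         [1, 1, 1]])
```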

3.1. Computing Attention

def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) \
             / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim = -1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn
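Calling it as self-attention on a toy batch (the shapes are arbitrary assumptions); each row of the attention matrix is a probability distribution:

```python
import math
import torch
import torch.nn.functional as F

def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn

q = k = v = torch.randn(1, 5, 64)   # self-attention: q, k, v all from the same input
out, p_attn = attention(q, k, v)
print(out.shape, p_attn.shape)      # torch.Size([1, 5, 64]) torch.Size([1, 5, 5])
```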

3.2. EncoderLayer + Residual Connection + LayerNorm

Every EncoderLayer has the same structure: a self-attention layer followed by a FeedForward layer (fc for short); self-attention and fc are collectively called sublayers.


class SublayerConnection(nn.Module):  # i.e. residual connection + layernorm
    """
    A residual connection followed by a layer norm.
    Note for code simplicity the norm is first as opposed to last.
    That is, LayerNorm is applied to the input being fed into each sub-layer
    rather than to the sub-layer's output (pre-norm instead of post-norm).
    """
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):  # sublayer is self-attn or fc, passed in via the lambdas below
        "Apply residual connection to any sublayer with the same size."
        # x + Dropout(SubLayer(LayerNorm(x))): first layer-normalize the previous
        # layer's output x, then run the sublayer (attention or feed-forward),
        # then dropout, then add the residual connection
        return x + self.dropout(sublayer(self.norm(x)))  # the actual self-attn/fc computation happens here

class EncoderLayer(nn.Module):
    "Encoder is made up of self-attn and feed forward (defined below)"
    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)  # each clone wraps one sub-layer (self-attn or fc) with its residual connection and layernorm; the actual computation happens in SublayerConnection.forward
        self.size = size

    def forward(self, x, mask):
        "Follow Figure 1 (left) for connections."
        """Outputs of Multi-Head Attentin """
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
        return self.sublayer[1](x, self.feed_forward)

3.3. DecoderLayer + Residual Connection + LayerNorm

class DecoderLayer(nn.Module):
    "Decoder is made of self-attn, src-attn, and feed forward (defined below)"
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, memory, src_mask, tgt_mask):
        "Follow Figure 1 (right) for connections."
        m = memory  # encoder stack output: keys/values for the src-attn sub-layer
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
        x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
        return self.sublayer[2](x, self.feed_forward)

3.4. The Encoder and Decoder Stacks

class Encoder(nn.Module):
    "Core encoder is a stack of N layers"
    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        "Pass the input (and mask) through each layer in turn."
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

class Decoder(nn.Module):
    "Generic N layer decoder with masking."
    def __init__(self, layer, N):
        super(Decoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, memory, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, memory, src_mask, tgt_mask)
        return self.norm(x)
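Wiring the pieces together as a sketch: the single-head `ToyAttention` and the plain `Linear` feed-forward below are simplified stand-ins of my own (the real model uses MultiHeadedAttention and a two-layer position-wise feed-forward network, which this post has not defined yet):

```python
import copy
import math
import torch
import torch.nn as nn

def clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

class LayerNorm(nn.Module):
    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps
    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2

class SublayerConnection(nn.Module):
    def __init__(self, size, dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))

class ToyAttention(nn.Module):
    "Single-head stand-in for the real MultiHeadedAttention."
    def forward(self, query, key, value, mask=None):
        d_k = query.size(-1)
        scores = query @ key.transpose(-2, -1) / math.sqrt(d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        return scores.softmax(dim=-1) @ value

class EncoderLayer(nn.Module):
    def __init__(self, size, self_attn, feed_forward, dropout):
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size
    def forward(self, x, mask):
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
        return self.sublayer[1](x, self.feed_forward)

class Encoder(nn.Module):
    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)
    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

d_model = 512
layer = EncoderLayer(d_model, ToyAttention(), nn.Linear(d_model, d_model), dropout=0.1)
enc = Encoder(layer, 6)

x = torch.randn(2, 10, d_model)   # a batch of 2 sequences of 10 "words"
out = enc(x, mask=None)
print(out.shape)                  # input and output dims match: torch.Size([2, 10, 512])
```

Note how the whole 10-word batch goes through the 6-layer stack in one parallel forward pass, matching the encoder-side behavior described in section 2.2.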