本文已参与「新人创作礼」活动,一起开启掘金创作之路。
参考链接和文献: jalammar.github.io/illustrated… blog.csdn.net/qq_28168421…
1. transformer的宏观结构解析
transformer的一种典型的seq2seq结构,常用于序列到序列的应用中。以机器翻译为例,最顶层的视图为:
将结构放大一些,可以看到是input流经encoder部分再流经decoder部分最后得到output:
而encoder部分是由6层EncoderLayer组成,decoder部分也是由6层DecoderLayer组成,最后一层EncoderLayer会影响每一层DecoderLayer的输入:
每层EncoderLayer的结构都是相同的,由一层self-attention和一层FeedForward(以下简称fc)层组成:
其中self-attention能够使得在编码某个词时,看到并结合句子里面每个词的信息。self-attention的原理解释见blog.csdn.net/THUChina/ar…
DecoderLayer部分也有类似与EncoderLayer中的self-attention层和fc层,但是在这两者之间还有一个encoder-decoder attention层,用于获取原始输入句子中的相关部分。如下图所示:
2. 数据的流动
2.1. Encoder部分
对于一个序列(比如一个句子),首先要对序列中的每个元素(例如每个词)进行embedding,得到一个embedding向量(具体embedding的方法见后文):
在transformer里面,默认每个词的embedding向量的维度是512维。也即每个词会得到一个512维的向量表示。
于是self-attention的输入就是一个包含多个词的embedding向量(用x表示)的list。输入可以限制一个最长长度,通常是训练集里面最长的句子长度。而self-attention的输出也是一个list,里面包含了每个词经过该层self-attention后的representation vector(用z表示):
然后每个词的z向量,又经过一层全连接层。
此外,分别对每个encoder层中的self-attention层和fc层的输入和输出通过残差层连接,并且对输出再做一个layernorm。也即每个子层的输出可以表示为:。其中代表该子层的输入,代表该子层原始的实现函数(self-attention或者fc):
同时,为了方便残差连接的实现,每个encoder层的输入输出维度保持一致(512维),这样就不需要在残差连接里面对进行升维或者降维。
2.2. Decoder部分
注意,在Encoder部分中,无论输入序列中元素个数有多少,前向过程都是只进行一次并行地推理,这也是self-attention相比RNN/LSTM的优势;而Decoder部分中,必须进行串行地循环推理,逐个生成新的输出元素(单词)。
注意,decoder层的第一个self-attention的计算过程,与encoder层中的self-attention的计算过程(query,key,value的计算都是依赖于上一层的输出,具体计算过程参见blog.csdn.net/THUChina/ar… ),而第二个self-attention中的key和value来自encoder的输出计算得到,query来自上一个子层的输出计算得到。可以这么形象地理解,第二个self-attention层的功能是利用解码器已经预测出的信息作为query,去编码器提取的各种特征中,查找相关信息并融合到当前特征中,来完成预测。
3. 代码实现(更新中)
3.0. 工具类和工具函数
将一个模块复制N份,注意要深拷贝:
def clones(module, N):
"Produce N identical layers."
return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
layernorm计算:
class LayerNorm(nn.Module):
"Construct a layernorm module (See citation for details)."
def __init__(self, features, eps=1e-6):
super(LayerNorm, self).__init__()
self.a_2 = nn.Parameter(torch.ones(features))
self.b_2 = nn.Parameter(torch.zeros(features))
self.eps = eps
def forward(self, x):
mean = x.mean(-1, keepdim=True)
std = x.std(-1, keepdim=True)
return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
确保decoder计算只依赖历史信息的mask:
def subsequent_mask(size):
"Mask out subsequent positions."
attn_shape = (1, size, size)
subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
return torch.from_numpy(subsequent_mask) == 0
3.1. Attention计算
def attention(query, key, value, mask=None, dropout=None):
"Compute 'Scaled Dot Product Attention'"
d_k = query.size(-1)
scores = torch.matmul(query, key.transpose(-2, -1)) \
/ math.sqrt(d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
p_attn = F.softmax(scores, dim = -1)
if dropout is not None:
p_attn = dropout(p_attn)
return torch.matmul(p_attn, value), p_attn
3.2. EncoderLayer+残差连接+layernorm
每层EncoderLayer的结构都是相同的,由一层self-attention和一层FeedForward(以下简称fc)层组成,其中self-attention和fc统一称作sublayer
class Encoder(nn.Module):
class SublayerConnection(nn.Module):# 其实就是残差+layernorm
"""
A residual connection followed by a layer norm.
Note for code simplicity the norm is first as opposed to last.
Comments: That is to say, the norm op is executed when the last output is being feeded to the current layer instea of being generated from last layer.
"""
def __init__(self, size, dropout):
super(SublayerConnection, self).__init__()
self.norm = LayerNorm(size)
self.dropout = nn.Dropout(dropout)
def forward(self, x, sublayer):#这里是sublayer是self-attn或者fc,对应下文的lambda表达式
"Apply residual connection to any sublayer with the same size."
"""x+Dropout(SubLayer(LayerNorm(x)))。先对上一层的输出x做layernorm,然后进行sublayer计算(attention或者Feedforward),然后dropout,然后残差连接"""
return x + self.dropout(sublayer(self.norm(x)))#真正的self-attn或者fc的计算发生在这里
class EncoderLayer(nn.Module):
"Encoder is made up of self-attn and feed forward (defined below)"
def __init__(self, size, self_attn, feed_forward, dropout):
super(EncoderLayer, self).__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.sublayer = clones(SublayerConnection(size, dropout), 2)#这里的sublayer包含了self-attn或fc(实际的计算过程在上面SublayerConnection.forward中)以及对应的残差连接和layernorm
self.size = size
def forward(self, x, mask):
"Follow Figure 1 (left) for connections."
"""Outputs of Multi-Head Attentin """
x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
return self.sublayer[1](x, self.feed_forward)
3.3. DecoderLayer+残差连接+layernorm
class DecoderLayer(nn.Module):
"Decoder is made of self-attn, src-attn, and feed forward (defined below)"
def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
super(DecoderLayer, self).__init__()
self.size = size
self.self_attn = self_attn
self.src_attn = src_attn
self.feed_forward = feed_forward
self.sublayer = clones(SublayerConnection(size, dropout), 3)
def forward(self, x, memory, src_mask, tgt_mask):
"Follow Figure 1 (right) for connections."
m = memory
x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
return self.sublayer[2](x, self.feed_forward)
3.4. Encoder部分和Decoder部分
class Encoder(nn.Module):
"Core encoder is a stack of N layers"
def __init__(self, layer, N):
super(Encoder, self).__init__()
self.layers = clones(layer, N)
self.norm = LayerNorm(layer.size)
def forward(self, x, mask):
"Pass the input (and mask) through each layer in turn."
for layer in self.layers:
x = layer(x, mask)
return self.norm(x)
class Decoder(nn.Module):
"Generic N layer decoder with masking."
def __init__(self, layer, N):
super(Decoder, self).__init__()
self.layers = clones(layer, N)
self.norm = LayerNorm(layer.size)
def forward(self, x, memory, src_mask, tgt_mask):
for layer in self.layers:
x = layer(x, memory, src_mask, tgt_mask)
return self.norm(x)