本文同步发布微信公众号:阿黎投喂舍
导读
这应该算是一篇复习的文章,在Transformer几乎要统治各个领域的时候,回想Transformer竟然有一丝模糊。但是Transformer在深度学习领域的地位应该是毋庸置疑的,无论为NLP领域的Bert还是CV领域的Visual Transformer,我们都可以看到它的身影。因此深入了解Tranformer还是很有必要的,这篇文章,我们会以回顾的形式,介绍Transformer的结构以及实现的一些细节。
Bert系列相关文章推荐
What does BERT learn: 探究BERT为什么这么强
TinyBert:超细节应用模型蒸馏,有关蒸馏的疑问看他就够了
[论文分享] | RoBERTa:喂XLNet在吗,出来挨打了
Attention机制
Transformer是一个encode-decode结构,encoder和decoder的结构类似,都是由多个相同的layer拼接起来的。encoder每个layer分为两个sub-layer分别是attention层和全连接层。decoder中每个layer分为三个sub-layer,分别是两个attention层和一个全连接层。
- 单层attention attention层的结构如下图所示,计算方式为,这一步是为了计算Q和K之间的相关性,通过这步计算,上下文之间比较有相关性的单词之间的会根据相关性的权重计算出来。在Self-attention中,QKV的矩阵来源均是同一输入。
用torch实现单层attention如下。
class ScaledDotProductAttention(nn.Module):
''' Scaled Dot-Product Attention '''
de __init__(self, temperature, attn_dropout=0.1):
super().__init__()
self.temperature = temperature
self.dropout = nn.Dropout(attn_dropout)
def forward(self, q, k, v, mask=None):
attn = torch.matmul(q / self.temperature, k.transpose(2, 3))
if mask is not None:
attn = attn.masked_fill(mask == 0, -1e9)
attn = self.dropout(F.softmax(attn, dim=-1))
output = torch.matmul(attn, v)
return output, attn
- MultiHead Attention
Multihead attention的结构如下图所示,就是将单层attention的结果拼起来,这里不再详细解释。
Transformer结构详解
以上就是attention的结构,这个章节我们将为大家介绍Transformer的结构。首先为大家介绍position embedding。
我们在上文为大家介绍了Transformer是encoder-decoder结构,encoder的输入是不仅包含了word embedding还有position embedding。加入position embedding的原因是为了使用输入的文本的序列信息。position embedding的计算方式如下: position embedding
class PositionalEncoding(nn.Module):
def __init__(self, d_hid, n_position=200):
super(PositionalEncoding, self).__init__()
# Not a parameter
self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))
def _get_sinusoid_encoding_table(self, n_position, d_hid):
''' Sinusoid position encoding table '''
# TODO: make it with torch instead of numpy
def get_position_angle_vec(position):
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
def forward(self, x):
return x + self.pos_table[:, :x.size(1)].clone().detach()
1. Encoder
encoder的输入是position embedding和word embedding,他经过multihead attention和全连接网络,将结果输出给decoder。
class EncoderLayer(nn.Module):
''' Compose with two layers '''
def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
super(EncoderLayer, self).__init__()
self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
def forward(self, enc_input, slf_attn_mask=None):
enc_output, enc_slf_attn = self.slf_attn(
enc_input, enc_input, enc_input, mask=slf_attn_mask)
enc_output = self.pos_ffn(enc_output)
return enc_output, enc_slf_attn
2. Decoder
decoder有两层attention结构,第一层是self attention,第二层的输入是encoder的输出和self attention的输出。为了防止decoder学习到当前位置之后的东西,在self-attention步骤需要对输入进行mask。第二层attention的mask和encoder的mask相同。
class DecoderLayer(nn.Module):
''' Compose with three layers '''
def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
super(DecoderLayer, self).__init__()
self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
def forward(
self, dec_input, enc_output,
slf_attn_mask=None, dec_enc_attn_mask=None):
dec_output, dec_slf_attn = self.slf_attn(
dec_input, dec_input, dec_input, mask=slf_attn_mask)
dec_output, dec_enc_attn = self.enc_attn(
dec_output, enc_output, enc_output, mask=dec_enc_attn_mask)
dec_output = self.pos_ffn(dec_output)
return dec_output, dec_slf_attn, dec_enc_attn
结论和思考
这篇文章与其说是对Transformer的解读,不如说是对Transformer的源码解读,在阅读论文的时候有一些地方只是大概看了一下,但是当真的看源码才发现有很多细节当时没有考虑到。可能这就是纸上得来终觉浅吧。 下面是一些思考
- Decoder的输入为什么还要加上原始的输入?
- 为什么要加mask?可以不加吗?
Reference
- Rush A . The Annotated Transformer[C]// Proceedings of Workshop for NLP Open Source Software (NLP-OSS). 2018.
- Vaswani A, Shazeer N, Parmar N, et al. Attention is all you need[J]. arXiv preprint arXiv:1706.03762, 2017.