[整理] 聊聊 Transformer

423 阅读14分钟
原文链接: zhuanlan.zhihu.com

Attention is all you need 是一篇将 Attention 思想发挥到极致的论文,出自 Google。这篇论文中提出一个全新的模型,叫 Transformer,抛弃了以往深度学习任务里面使用到的 CNN 和 RNN (其实也不完全是,还是用到了一维卷积。。。打脸)。这个模型广泛应用于NLP领域,例如机器翻译,问答系统,文本摘要和语音识别等等方向。可以说是非常牛逼。

今天,我打算尝试写一下我对 Transformer 的理解。其中内容参考了很多大神的总结。

Transformer 整体框架

如上图所示,咋一看,Transformer 的架构是不是有点复杂。。。没事,下面慢慢讲。。。

和经典的 seq2seq 模型一样,Transformer模型中也采用了 encoer-decoder 架构。上图的左半边用 NX 框出来的,就代表一层 encoder,其中论文里面的 encoder 一共有6层这样的结构。上图的右半边用 NX 框出来的,则代表一层 decoder,同样也有6层。

定义输入序列首先经过 word embedding,再和 positional encoding相加后,输入到 encoder 中。输出序列经过的处理和输入序列一样,然后输入到 decoder。

最后,decoder 的输出经过一个线性层,再接 Softmax。

于上便是 Transformer 的整体框架,下面先来介绍 encoder 和 decoder。

Encoder

encoder由 6 层相同的层组成,每一层分别由两部分组成:

  • 第一部分是 multi-head self-attention
  • 第二部分是 position-wise feed-forward network,是一个全连接层

两个部分,都有一个残差连接(residual connection),然后接着一个 Layer Normalization。

Decoder

和 encoder 类似,decoder 也是由6个相同的层组成,每一个层包括以下3个部分:

  • 第一个部分是 multi-head self-attention mechanism
  • 第二部分是 multi-head context-attention mechanism
  • 第三部分是一个 position-wise feed-forward network

和 encoder 一样,上面三个部分的每一个部分,都有一个残差连接,后接一个 Layer Normalization

decoder 和 encoder 不同的地方在 multi-head context-attention mechanism

Attention

我在以前的文章中讲过,Attention 如果用一句话来描述,那就是 encoder 层的输出经过加权平均后再输入到 decoder 层中。它主要应用在 seq2seq 模型中,这个加权可以用矩阵来表示,也叫 Attention 矩阵。它表示对于某个时刻的输出 y,它在输入 x 上各个部分的注意力。这个注意力就是我们刚才说到的加权。

Attention 又分为很多种,其中两种比较典型的有加性 Attention 和乘性 Attention。加性 Attention 对于输入的隐状态 h_t 和输出的隐状态 s_t 直接做 concat 操作,得到 [s_t; h_t] ,乘性 Attention 则是对输入和输出做 dot 操作。

在 Google 这篇论文中,使用对 Attention 模型是乘性 Attention。

我在之前讲 ESIM 模型的文章里面写过一个 soft-align-attention,大家可以参考体会一下。

Self-Attention

上面我们说attention机制的时候,都会说到两个隐状态,分别是 h_is_t。前者是输入序列第 i个位置产生的隐状态,后者是输出序列在第 t 个位置产生的隐状态。所谓 self-attention实际上就是,输出序列就是输入序列。因而自己计算自己的attention 得分。

Context-Attention

context-attention 是 encoder 和 decoder 之间的 attention,是两个不同序列之间的attention,与来源于自身的 self-attention 相区别。

不管是哪种 attention,我们在计算 attention 权重的时候,可以选择很多方式,常用的方法有

  • additive attention
  • local-base
  • general
  • dot-product
  • scaled dot-product

Transformer模型采用的是最后一种:scaled dot-product attention。

Scaled Dot-Product Attention

那么什么是 scaled dot-product attention 呢?

Google 在论文中对 Attention 机制这么来描述:

An attention function can be described as a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility of the query with the corresponding key.

通过 query 和 key 的相似性程度来确定 value 的权重分布。论文中的公式长下面这个样子:


Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt{d_k}})V

看到 Q,K,V 会不会有点晕,没事,后面会解释。

scaled dot-product attention 和 dot-product attention 唯一的区别就是,scaled dot-product attention 有一个缩放因子, 叫\frac{1}{\sqrt{d_k}}d_k 表示 Key 的维度,默认用 64。

论文里对于 d_k 的作用这么来解释:对于 d_k 很大的时候,点积得到的结果维度很大,使得结果处于softmax函数梯度很小的区域。这时候除以一个缩放因子,可以一定程度上减缓这种情况。

scaled dot-product attention 的结构图如下所示。

现在来说下 K、Q、V 分别代表什么:

  • 在 encoder 的 self-attention 中,Q、K、V 都来自同一个地方,它们是上一层 encoder 的输出。对于第一层 encoder,它们就是 word embedding 和 positional encoding 相加得到的输入。
  • 在 decoder 的 self-attention 中,Q、K、V 也是自于同一个地方,它们是上一层 decoder 的输出。对于第一层 decoder,同样也是 word embedding 和 positional encoding 相加得到的输入。但是对于 decoder,我们不希望它能获得下一个 time step (即将来的信息,不想让他看到它要预测的信息),因此我们需要进行 sequence masking。
  • 在 encoder-decoder attention 中,Q 来自于 decoder 的上一层的输出,K 和 V 来自于 encoder 的输出,K 和 V 是一样的。
  • Q、K、V 的维度都是一样的,分别用 d_Qd_Kd_V 来表示

目前可能描述有有点抽象,不容易理解。结合一些应用来说,比如,如果是在自动问答任务中的话,Q 可以代表答案的词向量序列,取 K = V 为问题的词向量序列,那么输出就是所谓的 Aligned Question Embedding。

Google 论文的主要贡献之一是它表明了内部注意力在机器翻译 (甚至是一般的Seq2Seq任务)的序列编码上是相当重要的,而之前关于 Seq2Seq 的研究基本都只是把注意力机制用在解码端。

Scaled Dot-Product Attention 实现

import torch
import torch.nn as nn
import torch.functional as F
import numpy as np

class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention mechanism."""

    def __init__(self, attention_dropout=0.0):
        super(ScaledDotProductAttention, self).__init__()
        self.dropout = nn.Dropout(attention_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, scale=None, attn_mask=None):
        """
        前向传播.
        Args:
        	q: Queries张量,形状为[B, L_q, D_q]
        	k: Keys张量,形状为[B, L_k, D_k]
        	v: Values张量,形状为[B, L_v, D_v],一般来说就是k
        	scale: 缩放因子,一个浮点标量
        	attn_mask: Masking张量,形状为[B, L_q, L_k]

        Returns:
        	上下文张量和attention张量
        """
        attention = torch.bmm(q, k.transpose(1, 2))
        if scale:
        	attention = attention * scale
        if attn_mask:
        	# 给需要mask的地方设置一个负无穷
        	attention = attention.masked_fill_(attn_mask, -np.inf)
	# 计算softmax
        attention = self.softmax(attention)
	# 添加dropout
        attention = self.dropout(attention)
	# 和V做点积
        context = torch.bmm(attention, v)
        return context, attention

Multi-head attention

理解了 Scaled dot-product attention,Multi-head attention 也很容易理解啦。论文提到,他们发现将 Q、K、V 通过一个线性映射之后,分成 h 份,对每一份进行 scaled dot-product attention 效果更好。然后,把各个部分的结果合并起来,再次经过线性映射,得到最终的输出。这就是所谓的 multi-head attention。上面的超参数 h 就是 heads 的数量。论文默认是 8。

multi-head attention 的结构图如下所示。

值得注意的是,上面所说的分成 h 份是在 d_Qd_Kd_V的维度上进行切分。因此进入到scaled dot-product attention 的 d_K 实际上等于未进入之前的 \frac{D_K}{h}

Multi-head attention 的公式如下:

MultiHead(Q, K, V) = Concat(head_1, ..., head_h)W^0

其中,

head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)

在论文里面, d_{model} = 512,h = 8,所以在 scaled dot-product attention 里面的

d_Q = d_K = d_V = d_{model} / h = 512 / 8 = 64

可以看出,所谓 Multi-Head,就是只多做几次同样的事情,同时参数不共享,然后把结果拼接。

Multi-head attention 实现

class MultiHeadAttention(nn.Module):

    def __init__(self, model_dim=512, num_heads=8, dropout=0.0):
        super(MultiHeadAttention, self).__init__()

        self.dim_per_head = model_dim // num_heads
        self.num_heads = num_heads
        self.linear_k = nn.Linear(model_dim, self.dim_per_head * num_heads)
        self.linear_v = nn.Linear(model_dim, self.dim_per_head * num_heads)
        self.linear_q = nn.Linear(model_dim, self.dim_per_head * num_heads)

        self.dot_product_attention = ScaledDotProductAttention(dropout)
        self.linear_final = nn.Linear(model_dim, model_dim)
        self.dropout = nn.Dropout(dropout)
	
        # multi-head attention之后需要做layer norm
        self.layer_norm = nn.LayerNorm(model_dim)

    def forward(self, key, value, query, attn_mask=None):
	# 残差连接
        residual = query
        dim_per_head = self.dim_per_head
        num_heads = self.num_heads
        batch_size = key.size(0)

        # linear projection
        key = self.linear_k(key)
        value = self.linear_v(value)
        query = self.linear_q(query)

        # split by heads
        key = key.view(batch_size * num_heads, -1, dim_per_head)
        value = value.view(batch_size * num_heads, -1, dim_per_head)
        query = query.view(batch_size * num_heads, -1, dim_per_head)

        if attn_mask:
            attn_mask = attn_mask.repeat(num_heads, 1, 1)

        # scaled dot product attention
        scale = (key.size(-1)) ** -0.5
        context, attention = self.dot_product_attention(
          query, key, value, scale, attn_mask)

        # concat heads
        context = context.view(batch_size, -1, dim_per_head * num_heads)

        # final linear projection
        output = self.linear_final(context)

        # dropout
        output = self.dropout(output)

        # add residual and norm layer
        output = self.layer_norm(residual + output)

        return output, attention

上面代码中出现的 Residual connection 我在之前一篇文章中讲过,这里不再赘述,只解释 Layer normalization。

Layer normalization

Normalization有很多种,但是它们都有一个共同的目的,那就是把输入转化成均值为0方差为1的数据。我们在把数据送入激活函数之前进行normalization(归一化),因为我们不希望输入数据落在激活函数的饱和区。

说到 normalization,那就肯定得提到 Batch Normalization。

BN的主要思想就是:在每一层的每一批数据上进行归一化。我们可能会对输入数据进行归一化,但是经过该网络层的作用后,我们的数据已经不再是归一化的了。随着这种情况的发展,数据的偏差越来越大,我的反向传播需要考虑到这些大的偏差,这就迫使我们只能使用较小的学习率来防止梯度消失或者梯度爆炸。

BN的具体做法就是对每一小批数据,在批这个方向上做归一化。如下图所示:

可以看到,右半边求均值是沿着数据 batch N 的方向进行的

Batch normalization的计算公式如下:

BN(x_i) = \alpha\times\frac{x_i - \mu_b}{\sqrt{\sigma_B^2 + \epsilon}} + \beta

那么什么是 Layer normalization 呢?它也是归一化数据的一种方式,不过 LN 是在每一个样本上计算均值和方差,而不是BN那种在批方向计算均值和方差

下面是 LN 的示意图:

和上面的 BN 示意图一比较就可以看出二者的区别啦!

下面看一下 LN 的公式:

LN(x_i) = \alpha\times\frac{x_i - \mu_L}{\sqrt{\sigma_L^2 + \epsilon}} + \beta

Mask

mask 表示掩码,它对某些值进行掩盖,使其在参数更新时不产生效果。Transformer 模型里面涉及两种 mask,分别是 padding mask 和 sequence mask。

其中,padding mask 在所有的 scaled dot-product attention 里面都需要用到,而 sequence mask 只有在 decoder 的 self-attention 里面用到。

Padding Mask

什么是 padding mask 呢?因为每个批次输入序列长度是不一样的也就是说,我们要对输入序列进行对齐。具体来说,就是给在较短的序列后面填充 0。因为这些填充的位置,其实是没什么意义的,所以我们的attention机制不应该把注意力放在这些位置上,所以我们需要进行一些处理。

具体的做法是,把这些位置的值加上一个非常大的负数(负无穷),这样的话,经过 softmax,这些位置的概率就会接近0!

而我们的 padding mask 实际上是一个张量,每个值都是一个Boolean,值为 false 的地方就是我们要进行处理的地方。

实现:

def padding_mask(seq_k, seq_q):
    # seq_k 和 seq_q 的形状都是 [B,L]
    len_q = seq_q.size(1)
    # `PAD` is 0
    pad_mask = seq_k.eq(0)
    pad_mask = pad_mask.unsqueeze(1).expand(-1, len_q, -1)  # shape [B, L_q, L_k]
    return pad_mask

Sequence mask

文章前面也提到,sequence mask 是为了使得 decoder 不能看见未来的信息。也就是对于一个序列,在 time_step 为 t 的时刻,我们的解码输出应该只能依赖于 t 时刻之前的输出,而不能依赖 t 之后的输出。因此我们需要想一个办法,把 t 之后的信息给隐藏起来。

那么具体怎么做呢?也很简单:产生一个上三角矩阵,上三角的值全为1,下三角的值权威0,对角线也是0。把这个矩阵作用在每一个序列上,就可以达到我们的目的啦。

具体的代码实现如下:

def sequence_mask(seq):
    batch_size, seq_len = seq.size()
    mask = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8),
                    diagonal=1)
    mask = mask.unsqueeze(0).expand(batch_size, -1, -1)  # [B, L, L]
    return mask

效果如下,

  • 对于 decoder 的 self-attention,里面使用到的 scaled dot-product attention,同时需要padding mask 和 sequence mask 作为 attn_mask,具体实现就是两个mask相加作为attn_mask。
  • 其他情况,attn_mask 一律等于 padding mask。

Positional Embedding

现在的 Transformer 架构还没有提取序列顺序的信息,这个信息对于序列而言非常重要,如果缺失了这个信息,可能我们的结果就是:所有词语都对了,但是无法组成有意义的语句。

为了解决这个问题。论文使用了 Positional Embedding:对序列中的词语出现的位置进行编码。

在实现的时候使用正余弦函数。公式如下:

PE(pos, 2i) = sin(pos/10000^{2i/d_{model}})

PE(pos, 2i+1) = cos(pos/10000^{2i/d_{model}})

其中,pos 是指词语在序列中的位置。可以看出,在偶数位置,使用正弦编码,在奇数位置,使用余弦编码

从编码公式中可以看出,给定词语的 pos,我们可以把它编码成一个 d_{model} 的向量。也就是说,位置编码的每一个维度对应正弦曲线,波长构成了从 2\pi10000\times2\pi 的等比数列。

上面的位置编码是绝对位置编码。但是词语的相对位置也非常重要。这就是论文为什么要使用三角函数的原因!

正弦函数能够表达相对位置信息,主要数学依据是以下两个公式:

sin(\alpha + \beta) = sin\alpha cos\beta + cos\alpha sin \beta \\ cos(\alpha + \beta) = cos\alpha cos\beta - sin\alpha sin \beta

上面的公式说明,对于词汇之间的位置偏移 k, PE(pos + k) 可以表示成 PE(pos)PE(k)组合的形式,相当于有了可以表达相对位置的能力。

具体实现如下:

class PositionalEncoding(nn.Module):
    
    def __init__(self, d_model, max_seq_len):
        """初始化。
        Args:
            d_model: 一个标量。模型的维度,论文默认是512
            max_seq_len: 一个标量。文本序列的最大长度
        """
        super(PositionalEncoding, self).__init__()
        
        # 根据论文给的公式,构造出PE矩阵
        position_encoding = np.array([
          [pos / np.power(10000, 2.0 * (j // 2) / d_model) for j in range(d_model)]
          for pos in range(max_seq_len)])
        # 偶数列使用sin,奇数列使用cos
        position_encoding[:, 0::2] = np.sin(position_encoding[:, 0::2])
        position_encoding[:, 1::2] = np.cos(position_encoding[:, 1::2])

        # 在PE矩阵的第一行,加上一行全是0的向量,代表这`PAD`的positional encoding
        # 在word embedding中也经常会加上`UNK`,代表位置单词的word embedding,两者十分类似
        # 那么为什么需要这个额外的PAD的编码呢?很简单,因为文本序列的长度不一,我们需要对齐,
        # 短的序列我们使用0在结尾补全,我们也需要这些补全位置的编码,也就是`PAD`对应的位置编码
        pad_row = torch.zeros([1, d_model])
        position_encoding = torch.cat((pad_row, position_encoding))
        
        # 嵌入操作,+1是因为增加了`PAD`这个补全位置的编码,
        # Word embedding中如果词典增加`UNK`,我们也需要+1。看吧,两者十分相似
        self.position_encoding = nn.Embedding(max_seq_len + 1, d_model)
        self.position_encoding.weight = nn.Parameter(position_encoding,
                                                     requires_grad=False)
    def forward(self, input_len):
        """神经网络的前向传播。

        Args:
          input_len: 一个张量,形状为[BATCH_SIZE, 1]。每一个张量的值代表这一批文本序列中对应的长度。

        Returns:
          返回这一批序列的位置编码,进行了对齐。
        """
        
        # 找出这一批序列的最大长度
        max_len = torch.max(input_len)
        tensor = torch.cuda.LongTensor if input_len.is_cuda else torch.LongTensor
        # 对每一个序列的位置进行对齐,在原序列位置的后面补上0
        # 这里range从1开始也是因为要避开PAD(0)的位置
        input_pos = tensor(
          [list(range(1, len + 1)) + [0] * (max_len - len) for len in input_len])
        return self.position_encoding(input_pos)

Position-wise Feed-Forward network

这是一个全连接网络,包含两个线性变换和一个非线性函数(实际上就是ReLU)。公式如下

FFN = max(0, xW_1 + b_1)W_2 + b_2

这个线性变换在不同的位置都表现地一样,并且在不同的层之间使用不同的参数。

这里实现上用到了两个一维卷积。

实现如下:

class PositionalWiseFeedForward(nn.Module):

    def __init__(self, model_dim=512, ffn_dim=2048, dropout=0.0):
        super(PositionalWiseFeedForward, self).__init__()
        self.w1 = nn.Conv1d(model_dim, ffn_dim, 1)
        self.w2 = nn.Conv1d(ffn_dim, model_dim, 1)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(model_dim)

    def forward(self, x):
        output = x.transpose(1, 2)
        output = self.w2(F.relu(self.w1(output)))
        output = self.dropout(output.transpose(1, 2))

        # add residual and norm layer
        output = self.layer_norm(x + output)
        return output

Transformer的实现

现在可以开始完成 Transformer 模型的构建了,encoder 端和 decoder 端分别都有 6 层,实现如下,首先是 encoder 端,

class EncoderLayer(nn.Module):
	"""Encoder的一层。"""

    def __init__(self, model_dim=512, num_heads=8, ffn_dim=2048, dropout=0.0):
        super(EncoderLayer, self).__init__()

        self.attention = MultiHeadAttention(model_dim, num_heads, dropout)
        self.feed_forward = PositionalWiseFeedForward(model_dim, ffn_dim, dropout)

    def forward(self, inputs, attn_mask=None):

        # self attention
        context, attention = self.attention(inputs, inputs, inputs, padding_mask)

        # feed forward network
        output = self.feed_forward(context)

        return output, attention


class Encoder(nn.Module):
	"""多层EncoderLayer组成Encoder。"""

    def __init__(self,
               vocab_size,
               max_seq_len,
               num_layers=6,
               model_dim=512,
               num_heads=8,
               ffn_dim=2048,
               dropout=0.0):
        super(Encoder, self).__init__()

        self.encoder_layers = nn.ModuleList(
          [EncoderLayer(model_dim, num_heads, ffn_dim, dropout) for _ in
           range(num_layers)])

        self.seq_embedding = nn.Embedding(vocab_size + 1, model_dim, padding_idx=0)
        self.pos_embedding = PositionalEncoding(model_dim, max_seq_len)

    def forward(self, inputs, inputs_len):
        output = self.seq_embedding(inputs)
        output += self.pos_embedding(inputs_len)

        self_attention_mask = padding_mask(inputs, inputs)

        attentions = []
        for encoder in self.encoder_layers:
            output, attention = encoder(output, self_attention_mask)
            attentions.append(attention)

        return output, attentions

然后是 Decoder 端,

class DecoderLayer(nn.Module):

    def __init__(self, model_dim, num_heads=8, ffn_dim=2048, dropout=0.0):
        super(DecoderLayer, self).__init__()

        self.attention = MultiHeadAttention(model_dim, num_heads, dropout)
        self.feed_forward = PositionalWiseFeedForward(model_dim, ffn_dim, dropout)

    def forward(self,
              dec_inputs,
              enc_outputs,
              self_attn_mask=None,
              context_attn_mask=None):
        # self attention, all inputs are decoder inputs
        dec_output, self_attention = self.attention(
          dec_inputs, dec_inputs, dec_inputs, self_attn_mask)

        # context attention
        # query is decoder's outputs, key and value are encoder's inputs
        dec_output, context_attention = self.attention(
          enc_outputs, enc_outputs, dec_output, context_attn_mask)

        # decoder's output, or context
        dec_output = self.feed_forward(dec_output)

        return dec_output, self_attention, context_attention


class Decoder(nn.Module):

    def __init__(self,
               vocab_size,
               max_seq_len,
               num_layers=6,
               model_dim=512,
               num_heads=8,
               ffn_dim=2048,
               dropout=0.0):
        super(Decoder, self).__init__()

        self.num_layers = num_layers

        self.decoder_layers = nn.ModuleList(
          [DecoderLayer(model_dim, num_heads, ffn_dim, dropout) for _ in
           range(num_layers)])

        self.seq_embedding = nn.Embedding(vocab_size + 1, model_dim, padding_idx=0)
        self.pos_embedding = PositionalEncoding(model_dim, max_seq_len)

    def forward(self, inputs, inputs_len, enc_output, context_attn_mask=None):
        output = self.seq_embedding(inputs)
        output += self.pos_embedding(inputs_len)

        self_attention_padding_mask = padding_mask(inputs, inputs)
        seq_mask = sequence_mask(inputs)
        self_attn_mask = torch.gt((self_attention_padding_mask + seq_mask), 0)

        self_attentions = []
        context_attentions = []
        for decoder in self.decoder_layers:
            output, self_attn, context_attn = decoder(
            output, enc_output, self_attn_mask, context_attn_mask)
            self_attentions.append(self_attn)
            context_attentions.append(context_attn)

        return output, self_attentions, context_attentions

组合一下,就是 Transformer 模型。

class Transformer(nn.Module):

    def __init__(self,
               src_vocab_size,
               src_max_len,
               tgt_vocab_size,
               tgt_max_len,
               num_layers=6,
               model_dim=512,
               num_heads=8,
               ffn_dim=2048,
               dropout=0.2):
        super(Transformer, self).__init__()

        self.encoder = Encoder(src_vocab_size, src_max_len, num_layers, model_dim,
                               num_heads, ffn_dim, dropout)
        self.decoder = Decoder(tgt_vocab_size, tgt_max_len, num_layers, model_dim,
                               num_heads, ffn_dim, dropout)

        self.linear = nn.Linear(model_dim, tgt_vocab_size, bias=False)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, src_seq, src_len, tgt_seq, tgt_len):
        context_attn_mask = padding_mask(tgt_seq, src_seq)

        output, enc_self_attn = self.encoder(src_seq, src_len)

        output, dec_self_attn, ctx_attn = self.decoder(
          tgt_seq, tgt_len, output, context_attn_mask)

        output = self.linear(output)
        output = self.softmax(output)

        return output, enc_self_attn, dec_self_attn, ctx_attn

That's all !

Github 地址: pengshuang/Transformer