详解Transformer网络结构

0 阅读6分钟

1. Transformer 整体结构

image.png

Transformer 由 Encoder 和 Decoder 两个部分组成,Encoder 和 Decoder 都包含 6 个 block。Transformer 的工作流程大体如下:

第一步: 获取输入句子的每一个单词的表示向量 XX由单词的 Embedding(Embedding就是从原始数据提取出来的Feature) 和单词位置的 Embedding 相加得到。

第二步: 将得到的单词表示向量矩阵 (如上图所示,每一行是一个单词的表示 x) 传入 Encoder 中,经过 6 个 Encoder block 后可以得到句子所有单词的编码信息矩阵 C

第三步:将 Encoder 输出的编码信息矩阵 C传递到 Decoder 中,Decoder 依次会根据当前翻译过的单词 1~ i 翻译下一个单词 i+1,如下图所示。在使用的过程中,翻译到单词 i+1 的时候需要通过 Mask (掩盖)  操作遮盖住 i+1 之后的单词。

2. Transformer 的输入

Transformer 中单词的输入表示 x单词 Embedding 和位置 Embedding (Positional Encoding)相加得到。

2.1 单词 Embedding

单词的 Embedding 有很多种方式可以获取,例如可以采用 Word2Vec、Glove 等算法预训练得到,也可以在 Transformer 中训练得到。

2.2 位置 Embedding

Transformer 中除了单词的 Embedding,还需要使用位置 Embedding 表示单词出现在句子中的位置。因为 Transformer 不采用 RNN 的结构,而是使用全局信息,不能利用单词的顺序信息,而这部分信息对于 NLP 来说非常重要。 所以 Transformer 中使用位置 Embedding 保存单词在序列中的相对或绝对位置。

3. Self-Attention(自注意力机制)

上图是论文中 Transformer 的内部结构图,左侧为 Encoder block,右侧为 Decoder block。红色圈中的部分为 Multi-Head Attention,是由多个 Self-Attention组成的,可以看到 Encoder block 包含一个 Multi-Head Attention,而 Decoder block 包含两个 Multi-Head Attention (其中有一个用到 Masked)。Multi-Head Attention 上方还包括一个 Add & Norm 层,Add 表示残差连接用于防止网络退化,Norm 表示,用于对每一层的激活值进行归一化。

3.1 Self-Attention 结构

image.png

上图是 Self-Attention 的结构,在计算的时候需要用到矩阵Q(查询),K(键值),V(值) 。在实际中,Self-Attention 接收的是输入(单词的表示向量x组成的矩阵X) 或者上一个 Encoder block 的输出。而Q,K,V正是通过 Self-Attention 的输入进行线性变换得到的。

Self- Attention核心概念Q、K、V三剑客是什么?

Self-Attention的精髓在于三个矩阵:Query(查询)、Key(键)、Value(值)。计算Query和Key的相似度,返回相关Value。

(1)Query(Q),我想找什么?对应的就是查询问题。 类似搜索引擎中,你输入搜索词 → Query。

(2)Key(K),我能提供什么?对应的就是结果的索引。 类似搜索引擎中,网页的标题和关键词 → Key。

(3)Value(V) ,我实际包含的内容具体值。 类似搜索引擎中,网页的实际内容 → Value。

3.2 Q, K, V 的计算

Self-Attention 的输入用矩阵X进行表示,则可以使用线性变阵矩阵WQ,WK,WV计算得到Q,K,V。计算如下图所示,注意 X, Q, K, V 的每一行都表示一个单词。

56f8b257-42ac-4b9c-acd5-6648542bf95e.png

3.3 Self-Attention 的输出

得到矩阵 Q, K, V之后就可以计算出 Self-Attention 的输出了,计算的公式如下:

127c6138-76fc-4f38-8684-e3a99d2d7c0f.png

公式中计算矩阵QK每一行向量的内积,为了防止内积过大,因此除以dk的平方根。Q乘以K的转置后,得到的矩阵行列数都为 n,n 为句子单词数,这个矩阵可以表示单词之间的 attention 强度。下图为Q乘以 KT ,1234 表示的是句子中的单词。

c735f9ba-26c1-4ae2-88e8-40995a3ddfba.png

得到 QKT之后,使用 Softmax 计算每一个单词对于其他单词的 attention 系数,公式中的 Softmax 是对矩阵的每一行进行 Softmax,即每一行的和都变为 1.

9144b6d5-85f1-4bd8-8788-02a84c038988.png

得到 Softmax 矩阵之后可以和V相乘,得到最终的输出Z

1ce1c932-2fab-4333-8ea8-a840e7ab8217.png

上图中 Softmax 矩阵的第 1 行表示单词 1 与其他所有单词的 attention 系数,最终单词 1 的输出Z1  等于所有单词 i 的值 Vi 根据 attention 系数的比例加在一起得到,如下图所示:

ef2d9049-ee11-492f-aa45-226cd0f70746.png

3.4 Multi-Head Attention

Multi-Head Attention 是由多个 Self-Attention 组合形成的

30fc38a1-fd17-45fb-b0eb-37a254b31fab.png

从上图可以看到 Multi-Head Attention 包含多个 Self-Attention 层,首先将输入X分别传递到 h 个不同的 Self-Attention 中,计算得到 h 个输出矩阵Z。下图是 h=8 时候的情况,此时会得到 8 个输出矩阵Z

bb6475eb-1f5f-4f59-b874-96550e4d63c5.png

得到 8 个输出矩阵Z1  到 Z8 之后,Multi-Head Attention 将它们拼接在一起  (Concat) ,然后传入一个Linear层,得到 Multi-Head Attention 最终的输出Z

4167cace-9514-493d-804b-a816e80a6e6c.png

可以看到 Multi-Head Attention 输出的矩阵Z与其输入的矩阵X的维度是一样的。

4. Encoder 编码器结构

Transformer 的 Encoder block 结构,可以看到是由 Multi-Head Attention, Add & Norm, Feed Forward, Add & Norm 组成的。

4.1 Add & Norm

Add & Norm 层由 Add 和 Norm 两部分组成,其计算公式如下:

ed13ed4f-067e-4280-8973-e665bfb111a0.png

其中 X表示 Multi-Head Attention 或者 Feed Forward 的输入,MultiHeadAttention(X) 和 FeedForward(X) 表示输出 (输出与输入 X 维度是一样的,所以可以相加)。

Add指 X+MultiHeadAttention(X),是一种残差连接,通常用于解决多层网络训练的问题,可以让网络只关注当前差异的部分,在 ResNet 中经常用到。

Norm指 Layer Normalization,通常用于 RNN 结构,Layer Normalization 会将每一层神经元的输入都转成均值方差都一样的,这样可以加快收敛

4.2 Feed Forward

Feed Forward 层比较简单,是一个两层的全连接层,第一层的激活函数为 Relu,第二层不使用激活函数,对应的公式如下。

aa4fa4da-e128-46f9-ad44-fa6afa5b8ba3.png

X是输入,Feed Forward 最终得到的输出矩阵的维度与X一致。

4.3 组成 Encoder

第一个 Encoder block 的输入为句子单词的表示向量矩阵,后续 Encoder block 的输入是前一个 Encoder block 的输出,最后一个 Encoder block 输出的矩阵就是编码信息矩阵 C,这一矩阵后续会用到 Decoder 中。

  • 包含两个 Multi-Head Attention 层。
  • 第一个 Multi-Head Attention 层采用了 Masked 操作。
  • 第二个 Multi-Head Attention 层的K, V矩阵使用 Encoder 的编码信息矩阵C进行计算,而Q使用上一个 Decoder block 的输出计算。
  • 最后有一个 Softmax 层计算下一个翻译单词的概率。

5. Decoder 结构

5.1 Masked Multi-Head Attention

5.2 Multi-Head Attention

5.3 Softmax 预测输出单词

参考文献:

Transformer模型详解(图解最完整版)

从零理解Transformer:原理、架构与PyTorch逐行实现

Transformer各层网络结构详解!面试必备!(附代码实现)_transformer网络结构

细节拉满,全网最详细的Transformer介绍(含大量插图)!

一文搞定自注意力机制(Self-Attention)