transformer实战——点积注意力(Dot-Product Attention)到底是怎么执行的?

156 阅读6分钟

1 点积注意力(Dot-Product Attention)

1.1 输入

  1. 对于输入的句子 X,通过 WordEmbedding 得到该句子中每个字的字向量,同时通过 Positional Encoding 得到所有字的位置向量,将其相加(维度相同,可以直接相加),得到该字真正的向量表示。
  2. 整个矩阵 X 代表一个句子(或一个批量(batch)的句子),不是每一行不是代表一个句子。所以说矩阵中的每一行代表的是句子中的一个组成部分,输入句子中的第 t 个字的向量记作 xtx^t

1.2 单个字的计算attention的流程

1.2.1 定义矩阵

  1. 我们定义三个矩阵 WQW^QWKW^KWVW^V ,使用这三个矩阵分别对所有的字向量进行三次线性变换,于是所有的字向量又衍生出三个新的向量 qtq^tktk^tvtv^t
  2. 我们将所有的 qtq^t向量拼成一个大矩阵,记作查询矩阵Q,将所有的 ktk^t向量拼成一个大矩阵,记作键矩阵K,将所有的 vtv^t向量拼成一个大矩阵,记作值矩阵 V

1.2.2第一个字向量举例计算第一个字自注意力输出的流程

  1. 先计算所有字向量的 qtq^tktk^tvtv^t image.png
  2. 在计算 q1q^1k1k^1,k2k^2,k2k^2的点积得出分数score1-3 image.png
  3. 把score1-3的值经过softmax,使得它们的和为1,得到权重(softmax([2, 4, 4]) = [0.0, 0.5, 0.5]) image.png
  4. 权重分别乘到对应的 值向量vtv^timage.png
  5. 最后将权重化后的值向量求和,得到第一个字的输出,计算的结果是第一个字(词元)在考虑了整个输入序列上下文信息后,得到的一个新的、富含上下文信息的向量表示。这个新向量可以看作是第一个字基于整个句子的“注意力加权和”,它反映了第一个字与句子中所有其他字(包括它自己)的关系强度,并据此融合了它们的信息。 image.png

1.2.3 使用矩阵计算来一次性计算出所有字向量的Q,K,V。

image.png

  • 输入是一个矩阵 X
    • 矩阵第 t 行为第个字的向量表示 xtx^t。例如,如果输入句子有 2个词,那么 X 矩阵就会有 2 行。
    • 每一行的宽度代表词嵌入的维度。 图中 X 矩阵的每一行有 4 个小方块,这示意性地表示了词嵌入的维度。维度在实际中可能是 512。
  • 三个矩阵 WQW^QWKW^KWVW^V是可学习的线性变换矩阵。它们的作用是将原始的词嵌入 X 投影到不同的空间,生成 Q,K,V。
    • 矩阵的行数必须等于 X 的列数 (词嵌入维度)。 例如,如果 X 的词嵌入维度是 512,那么 WQW^QWKW^KWVW^V 都必须有 512 行。
    • 矩阵的列数是dkd^kdvd^v,决定了 Q,K,V 的维度,即 dkd^k (Q和K的维度), dvd^v (V向量的维度)。
    • dkd^kdvd^v是超参数,在设计和配置模型时,需要自己选择它们的值。
    • 总结矩阵 WQW^QWKW^KWVW^V的 输入维度 是原始词嵌入的维度(通常是 dmodeld^model),而它们的 输出维度(列数) 就是dkd^kdvd^v
  • Q,K,V是 X 经过线性变换后得到的矩阵。
    • 行数与 X 的行数相同 (句子长度)。 例如,如果 X 有 2 行(2个词),那么 Q,K,V 也会有 2 行。
    • 它们的列数与对应的权重矩阵的列数相同。Q,K 的维度是 (句子长度, dkd^k),V 的维度是 (句子长度, dvd^v)

1.2.4 代码计算Q,K,V步骤

  1. 输入矩阵X
    • 在代码中,为了提高计算效率,我们很少一个句子一个句子地输入模型。相反,我们会将多个句子(或批次)打包在一起,即经常使用batch_size个句子一起作为encode一次的输入,所以下X的大小从[src_len,dmodel]->[batch_size,src_len,dmodel],即有多个句子。
      • batch_size:代表一个批次中包含的句子数量。
      • src_len:代表每个句子的长度(词元数量)。
      • d_model:代表每个词元向量的维度。
  2. 创建可训练的 WQW^QWKW^KWVW^V矩阵
    1. 在 PyTorch 这样的深度学习框架中,nn.Linear(in_features, out_features) 是用来表示线性变换的常见方式。
      • 它会创建一个执行线性变换的对象,其操作可以概括为:输出 = 输入 * 权重矩阵的转置 + 偏置向量。
      • in_features 对应输入向量的维度。out_features 对应输出向量的维度。
      • 权重矩阵的形状为[out_features,in_features]
    2. 矩阵的建立
      • W_Q = nn.Linear(d_model, d_k);即内部WQW^Q的形状为[d_k,d_model]
      • W_k = nn.Linear(d_model, d_k);即内部WKW^K的形状为[d_k,d_model]
      • W_v = nn.Linear(d_model, d_v);即内部WVW^V的形状为[d_v,d_model]
  3. 计算Q,K,V
    • Q = W_Q(X)
    • K = W_K(X)
    • V = W_V(X)
  4. 案例
    1. 假设 X 为[2,5,512]
    2. W_Q = nn.Linear(512, 64)
    3. nn.Linear会自动将[2,5,512]视为一个包含2*5=10个独立的512维向量的机会,然后对每个512维的向量做output=input * WQTW_Q^T image.png
    4. 最终,Q的输出的形状会是 [batch_size, sequence_length, 64]

2 缩放点积注意力 (Scaled Dot-Product Attention)

2.1 解释

  1. 缩放点积注意力”(Scaled Dot-Product Attention)是在点积注意力的基础上,额外增加了一个除以缩放因子dk\sqrt{d_k}
  2. dkd^k是Q和K的维度。
  3. 公式如下
    • 关于 Q, K, V 矩阵的行 image.png
    • 关于 score 的行 image.png
    • 关于 weight的行

image.png * 关于 Z 矩阵的行 image.png 4. Z 是通过自注意力机制学习到的,每个词元在考虑了所有其他词元(包括它自己)的上下文信息后的新表示。

2.2代码实现

import torch
import torch.nn.functional as F
import math

def attention(query, key, value, mask=None, dropout=None):
    # 输入变量及其形状:
    # query:  形状通常为 (batch_size, sequence_length, d_k)
    # key:    形状通常为 (batch_size, sequence_length, d_k)
    # value:  形状通常为 (batch_size, sequence_length, d_v)
    # mask:   形状通常为 (batch_size, query_seq_len, key_seq_len)

    # 获取查询向量的最后一个维度,即 d_k
    # query.size(-1) 返回张量的最后一个维度的大小。标量,例如 64 或 512
    d_k = query.size(-1) 
    
    # 步骤1:计算 Q K^T 并进行缩放
    #  key 的形状: (b, s, d),key.transpose(-2, -1) 的形状: (b, d, s),实现 K 的转置
    #  torch.matmul(query, key.transpose(-2, -1)) 执行矩阵乘法 Q K^T
    #  query 的形状: (b, s, d)
    #  scores 的形状: (b, s, s)  # (batch_size, query_seq_len, key_seq_len)
    #  scores[i, j, k] 表示批次i中第j个查询与第k个键的点积分数
    scores = torch.matmul(query, key.transpose(-2, -1)) \
             / math.sqrt(d_k)

    # 步骤2: 应用 Softmax 函数
    # F.softmax(scores, dim = -1) 对 scores 的最后一个维度(即 key 的维度)进行 softmax 归一化。
    # 这确保了对于每个查询(scores 的每一行),其所有注意力权重之和为 1。
    weights = F.softmax(scores, dim = -1)
    # weights 的形状: (b, s, s) # (batch_size, query_seq_len, key_seq_len)
    # weights[i, j, k] 表示批次i中第j个查询对第k个键的注意力权重


    # 步骤 3: 将注意力权重与 Value 矩阵相乘
    # torch.matmul(weights, value) 执行加权求和.结果的形状: (b, s, d_v).它代表了输入序列中每个词元经过上下文加权后的新向量表示。
    # weights 是注意力权重矩阵。它直接显示了每个查询词元对序列中其他所有键词元的“关注”程度。在模型分析和可解释性方面非常重要
    return torch.matmul(p_attn, value), weights

参考

  1. Transformer 详解
  2. The Illustrated Transformer
  3. The Annotated Transformer