大模型第二课:编码注意力机制

13 阅读6分钟

我们在第一课讲述了Transformer的基本原理,这一课,我们用代码实现注意力机制。

注意力机制代码

nn.Module 是PyTorch模型的一个基本构建块,为模型层的创建和管理提供必要的功能。

注意力的本质是全局信息加权和。记得在第一课,我们用Q(query), K(key), V(value)来构建注意力打分。

A(q,K,V)=softmax ⁣(qKTdk)VA(q, K, V) = \mathrm{softmax}\!\left( \frac{q K^{\mathsf T}}{\sqrt{d_k}} \right) V

自注意力机制

下面代码来自[1]: 一个非常简化的自注意类。

为什么叫做自注意力机制? 因为这里Q, K, V 都来自同一个输入x.


import torch.nn as nn

class SelfAttention_v1(nn.Module):

    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key   = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        
        attn_scores = queries @ keys.T 
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

        context_vec = attn_weights @ values
        return context_vec

## 使用方法如下:
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

在介绍因果注意力机制之前,我们优化这个简单的版本。

class SelfAttention_v2(nn.Module):

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

        context_vec = attn_weights @ values
        return context_vec

v2 更好在:参数初始化更合理 + 支持 bias + 更通用的输入维度 + 更工程化易扩展

因果注意力机制

因果注意力 = 只能看当前和过去,不能看未来。

CausalAttention(Q,K,V)=softmax(QKdk+Mask)V\text{CausalAttention}(Q,K,V) = \text{softmax} \left( \frac{QK^\top}{\sqrt{d_k}} + \text{Mask} \right)V
Maskij={0jij>i\text{Mask}_{ij} = \begin{cases} 0 & j \le i \\ -\infty & j > i \end{cases}

由于掩码遮蔽了未来,我们看到tensor看起来这个样子。

✓ 0 0 0
✓ ✓ 0 0
✓ ✓ ✓ 0
✓ ✓ ✓ ✓

下面,我们来实现一个简单的因果注意力机制。

class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length,
                 dropout, qkv_bias=False):
        super().__init__()

        # 输出特征维度(即 d_k)
        self.d_out = d_out

        # 线性映射:X -> Q, K, V
        # 对应公式:
        # Q = XW_Q
        # K = XW_K
        # V = XW_V
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        # 注意力权重的 dropout(Transformer 标准做法)
        self.dropout = nn.Dropout(dropout)

        # 构造因果 Mask(上三角矩阵,不含对角线)
        # 形状: (context_length, context_length)
        # 作用:
        #   mask[i, j] = 1  表示 j > i (未来位置)
        #   mask[i, j] = 0  表示 j <= i(当前或过去)
        #
        # 后续会把 mask==1 的位置填成 -inf,
        # 从而保证 softmax 后这些位置权重为 0
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        """
        x 形状: (batch_size, num_tokens, d_in)
        """

        b, num_tokens, d_in = x.shape  # b: batch_size, T: 序列长度

        # ========= 第一步:线性映射得到 Q, K, V =========
        # 形状: (b, T, d_out)
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # ========= 第二步:计算注意力分数 QK^T =========
        # keys.transpose(1, 2) 变为 (b, d_out, T)
        # 乘完后 attn_scores 形状: (b, T, T)
        #
        # 对应公式:
        # S = QK^T
        
        # 注意这里不直接用keys.T是因为droput在矩阵里面。我们只需要转置前面两个。
        attn_scores = queries @ keys.transpose(1, 2)

        # ========= 第三步:加入因果 Mask =========
        # 将未来位置 (j > i) 填为 -inf
        # softmax 后这些位置权重会变成 0
        #
        # 对应公式:
        # S' = S + M
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens],
            -torch.inf
        )

        # ========= 第四步:缩放 + softmax =========
        # 除以 sqrt(d_k) 防止数值过大
        #
        # 对应公式:
        # A = softmax(S' / sqrt(d_k))
        attn_weights = torch.softmax(
            attn_scores / (keys.shape[-1] ** 0.5),
            dim=-1
        )

        # 对注意力权重做 dropout
        attn_weights = self.dropout(attn_weights)

        # ========= 第五步:加权求和 =========
        # Z = A V
        # 输出形状: (b, T, d_out)
        context_vec = attn_weights @ values

        return context_vec

关于masked_fill这个函数,我们简单说明下。

x = torch.tensor([[1., 2.],
                  [3., 4.]])
mask = torch.tensor([[False, True],
                     [False, False]])
x.masked_fill_(mask, -1)
print(x)


## 输出
tensor([[ 1., -1.],
        [ 3.,  4.]])

发生了什么?

  • mask 为 True 的位置是 (0,1)
  • 那个位置被改成 -1
  • 其他位置不变

self.mask.bool()[:num_tokens, :num_tokens], 按顺序解释:

  1. self.mask → 取 tensor
  2. .bool() → 转成布尔 tensor
  3. [:num_tokens, :num_tokens] → 取左上角 num_tokens × num_tokens 子矩阵

多头注意力机制


Multi-Head Attention 数学表达式

设输入为:

XRT×dmodelX \in \mathbb{R}^{T \times d_{\text{model}}}

头数为:

hh

每个头的维度为:

dh=dmodelhd_h = \frac{d_{\text{model}}}{h}

1. 线性投影

对于每个 head i=1,,hi = 1, \dots, h

Qi=XWQ(i),Ki=XWK(i),Vi=XWV(i)Q_i = X W_Q^{(i)}, \quad K_i = X W_K^{(i)}, \quad V_i = X W_V^{(i)}

其中:

WQ(i),WK(i),WV(i)Rdmodel×dhW_Q^{(i)}, W_K^{(i)}, W_V^{(i)} \in \mathbb{R}^{d_{\text{model}} \times d_h}

2. 每个 head 的注意力

headi=softmax(QiKidh)Vi\text{head}_i = \text{softmax} \left( \frac{Q_i K_i^\top}{\sqrt{d_h}} \right) V_i

3. 拼接所有 head

Concat(head1,,headh)RT×(hdh)\text{Concat}(\text{head}_1, \dots, \text{head}_h) \in \mathbb{R}^{T \times (h d_h)}

4. 输出投影

MultiHead(X)=Concat(head1,,headh)WO\text{MultiHead}(X) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W_O

其中:

WOR(hdh)×dmodelW_O \in \mathbb{R}^{(h d_h) \times d_{\text{model}}}

最终整体表达式

MultiHead(X)=Concat(softmax(XWQ(1)(XWK(1))dh)XWV(1),,softmax(XWQ(h)(XWK(h))dh)XWV(h))WO\text{MultiHead}(X) = \text{Concat}\Big( \text{softmax}\left(\frac{X W_Q^{(1)} (X W_K^{(1)})^\top}{\sqrt{d_h}}\right) X W_V^{(1)}, \dots,\\ \text{softmax}\left(\frac{X W_Q^{(h)} (X W_K^{(h)})^\top}{\sqrt{d_h}}\right) X W_V^{(h)} \Big) W_O
import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads

        # 每个 head 的维度
        # 多头的核心思想:把总维度 d_out 拆成 num_heads 个子空间
        self.head_dim = d_out // num_heads  

        # ==========================
        # 为什么需要 projection?
        # ==========================
        # 如果直接用 X 做 attention:
        #   Q = K = V = X
        # 那么所有 head 都在同一个特征空间里做注意力,
        # 表达能力会受到限制。
        #
        # 加入可学习的线性映射后:
        #   Q = X W_Q
        #   K = X W_K
        #   V = X W_V
        #
        # 模型可以学习:
        #   - 哪些特征用于“查询”
        #   - 哪些特征用于“匹配”
        #   - 哪些特征用于“输出”
        #
        # 多头的关键:
        #   每个 head 使用不同的 W_Q, W_K, W_V,
        #   相当于在不同的线性子空间中做 attention。
        #
        # 这就是 multi-head 比 single-head 强的根本原因。
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        # 输出投影层
        # 作用:
        #   把多个 head 的结果拼接后再做一次线性变换,
        #   让不同 head 之间的信息进行融合。
        #
        # 如果没有这一层,各个 head 是彼此独立的。
        self.out_proj = nn.Linear(d_out, d_out)

        self.dropout = nn.Dropout(dropout)

        # 因果 mask(上三角)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        # ==========================
        # 1. 线性投影得到 Q, K, V
        # ==========================
        # 形状: (b, T, d_out)
        keys    = self.W_key(x)
        queries = self.W_query(x)
        values  = self.W_value(x)

        # ==========================
        # 2. 拆成多个 head
        # ==========================
        # (b, T, d_out)
        # -> (b, T, num_heads, head_dim)
        #
        # 本质:把一个大向量拆成多个小向量,
        # 每个 head 在自己的子空间中做注意力。
        keys    = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        values  = values.view(b, num_tokens, self.num_heads, self.head_dim)

        # ==========================
        # 3. 调整维度顺序
        # ==========================
        # (b, T, h, d)
        # -> (b, h, T, d)
        #
        # 这样可以对每个 head 并行计算注意力
        keys    = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values  = values.transpose(1, 2)

        # ==========================
        # 4. 每个 head 计算注意力
        # ==========================
        # (b, h, T, d) @ (b, h, d, T)
        # -> (b, h, T, T)
        attn_scores = queries @ keys.transpose(2, 3)

        # 因果 mask(未来位置置为 -inf)
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        # 缩放 + softmax
        attn_weights = torch.softmax(
            attn_scores / (self.head_dim ** 0.5),
            dim=-1
        )

        attn_weights = self.dropout(attn_weights)

        # ==========================
        # 5. 加权求和
        # ==========================
        # (b, h, T, T) @ (b, h, T, d)
        # -> (b, h, T, d)
        context = attn_weights @ values

        # ==========================
        # 6. 合并多个 head
        # ==========================
        # (b, h, T, d)
        # -> (b, T, h, d)
        context = context.transpose(1, 2)

        # -> (b, T, d_out)
        context = context.contiguous().view(b, num_tokens, self.d_out)

        # ==========================
        # 7. 输出投影
        # ==========================
        # 融合多个 head 的信息
        context = self.out_proj(context)

        return context


参考:

[1].《从零构建大模型》