Step-02 实现注意力机制

92 阅读2分钟

一、背景

从RNN 到 Gated RNN(LSTM\GRU)都无法解决Long Sequence中的hidden layer记忆管理!

  • 抓重点

  • 单token不同上下文语义不同

二、本质

何谓注意力机制?注意力在机器学习中本质上是一个加权平均机制,让模型可以更“聚焦”于输入中的重要部分,忽略无关部分。

  • 并行计算:自注意力机制适合并行计算,因此在处理长序列时效率更高。

  • 长距离依赖:自注意力可以直接捕捉句中任何两个词之间的依赖关系,即使距离较远。

  • 自适应性:每个词的上下文表示会动态变化,模型可以在不同语境下灵活调整每个词对整体句子理解的贡献。

三、最佳实践

self-attention mechanism is also called "scaled dot-product attention

  • Dot-product : 上下文中谁重要利用了矩阵的点积运算,比较标量大小值

  • scaled:缩放指的是的【下式】分母目的是当点积结果非常大时,softmax 函数的梯度会变得非常小,这会影响模型的训练。通过缩放,可以将点积结果控制在一个合理的范围内,避免梯度消失。

    Image

Image

四、快速开始

Image

五、实际工作

class RagMultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x)  
        queries = self.W_query(x)
        values = self.W_value(x)

        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        attn_scores = queries @ keys.transpose(2, 3)  

        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = (attn_weights @ values).transpose(1, 2)

        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec) 

        return context_vec