注意力机制拓展-大模型知识点(程序员转行AI大模型学习)

3 阅读8分钟

在上篇文章注意力机制介绍了注意力机制,基于注意力机制,有 3 个延伸拓展的版本。

自注意力(self-attention)

原理

我们在注意力机制文章中,描述了用搜索引擎搜索“python 教程”查询词的例子,对应注意力机制中,Q 是我们提供的查询词“python 教程”、K 是搜索引擎给出的网页标题【python 教程、python 项目、java 教程】,可以理解 Q 和 K 来自 2 个不同的序列,Q 是由我们提供的序列,K 是搜索引擎提供的序列。但是,如果我们是用搜索引擎提供的其中一个网页标题去计算和其他网页标题的注意力相关性时,这就是自注意力机制,因为 Q 和 K 都是来自搜索引擎提供的同一个序列,依次计算这个序列中的其中一个元素(网页标题)和其他元素(其他网页标题)的注意力分数。

实现

我们提供一个序列 x,然后经过变换矩阵 W_q、W_k、W_v 转换,得到 Q、K、V,基于(Q、K、V)计算注意力分数,这就是自注意力。

# 输入x
input = x

# 转换
Q = W_q * x
K = W_k * x
V = W_v * x

# 计算注意力分数
...

掩码注意力

原理

假设我们在写论文,写了前 3 句,现在要写第 4 句,我们做的是把前 3 句内容输入给搜索引擎,然后通过查询词 Q(“第 4 句要写什么内容”)去搜索,注意:这时候搜索引擎是看不到 第 4 句内容的,因为还没有写,所以此时 K 为(“第 1 句内容”,“第 2 句内容”,“第 3 句内容”),计算相关性分数 Q*K.T,然后通过 softmax 进行归一化(记得缩放),得到权重之后,和网页内容进行加权求和,得到第 4 句要写的内容。

输入已写论文x:[第1句,第2句,第3句]
    ↓
查询词Q:第4句要写什么?
    ↓ 搜索,网页看不到第4句及之后的内容,可以理解被屏蔽了
网页标题K:[“第1句内容”,“第2句内容”,“第3句内容”]
    ↓ 计算相关性
得到注意力权重:[0.5,0.3,0.2]
    ↓
生成第4句=0.5*“第1句内容” + 0.3*“第2句内容” + 0.2*“第3句内容”

实现

新增1个下三角掩码矩阵:M,类似如下:
[0, -inf, -inf, -inf] # 此时,只能看到第1个词
[0, 0,    -inf, -inf] # 此时,只能看到第1、2个词
[0, 0,    0,    -inf] # 此时,只能看到第1、2、3个词
[0, 0,    0,    0   ] # 此时,能看到所有词

掩码注意力计算:
Attention = softmax((Q * K.T) + M) * V
# 先通过full函数创建一个1 * seq_len * seq_len的矩阵
mask = torch.full((1, args.max_seq_len, args.max_seq_len), float("-inf"))
通过triu创建一个上三角矩阵
mask = torch.triu(mask, diagonal=1)

scores = scores + mask[:, :seq_len, :seq_len]
scores = F.softmax(scores.float(), dim=-1).type_as(xq)

多头注意力

原理

回到我们通过搜索引擎查询““python 教程”的场景,目前我们都是一组 Q、K、V 计算注意力分数,叫单头注意力。如果我们从多个角度去计算注意力分数,每个角度关注不同的特征,就是多头注意力。通过多头注意力,可以提取丰富的特征。

查询词Q:python教程
    ↓ 多角度搜索
角度1:语法角度
    K1:【“python语法”,“pyhton基础”,“python进阶”】
    V1:【“语法讲解”,“基础教程”,“进阶指南”】
    注意力1:【0.7,0.2,0.1】
    输出1 = 0.7 * “语法讲解” + 0.2 * "基础教程" + 0.1 * "进阶指南"

角度2:项目角度
    K2:【“python项目”,“pyhton实战”,“python案例”】
    V2:【“贪吃蛇”,“数据分析”,“web开发”】
    注意力1:【0.3,0.5,0.2】
    输出2 = 0.3 * “贪吃蛇” + 0.5 * "数据分析" + 0.2 * "web开发"

角度3:就业角度
    K3:【“python就业”,“pyhton面试”,“python薪资”】
    V3:【“就业指导”,“面试技巧”,“薪资分析”】
    注意力1:【0.2,0.3,0.5】
    输出3 = 0.2 * “就业指导” + 0.3 * "面试技巧" + 0.5 * "薪资分析"

    ↓ 拼接所有角度输出
最终输出 = 【输出1, 输出2, 输出3】

实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        """
        多头注意力机制
        
        Args:
            d_model: 模型维度
            n_heads: 注意力头的数量
        """
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads  # 每个头的维度

        # 线性变换矩阵
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)

    def forward(self, X, mask=None):
        """
        前向传播
        
        Args:
            X: (batch_size, seq_len, d_model)
            mask: (batch_size, seq_len, seq_len) 可选
            
        Returns:
            output: (batch_size, seq_len, d_model)
        """
        batch_size, seq_len, d_model = X.shape

        # 1. 计算Q、K、V
        Q = self.W_Q(X)  # (batch_size, seq_len, d_model)
        K = self.W_K(X)  # (batch_size, seq_len, d_model)
        V = self.W_V(X)  # (batch_size, seq_len, d_model)

        # 2. reshape为多头
        # (batch_size, seq_len, n_heads, d_head)
        Q = Q.view(batch_size, seq_len, self.n_heads, self.d_head)
        K = K.view(batch_size, seq_len, self.n_heads, self.d_head)
        V = V.view(batch_size, seq_len, self.n_heads, self.d_head)

        # 3. 转置以便计算注意力
        # (batch_size, n_heads, seq_len, d_head)
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)

        # 4. 计算注意力分数
        # Q: (batch_size, n_heads, seq_len, d_head)
        # K.transpose(-2, -1): (batch_size, n_heads, d_head, seq_len)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_head)

        # 5. 应用掩码(如果提供)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        # 6. Softmax归一化
        attention_weights = F.softmax(scores, dim=-1)

        # 7. 加权求和
        # attention_weights: (batch_size, n_heads, seq_len, seq_len)
        # V: (batch_size, n_heads, seq_len, d_head)
        output = torch.matmul(attention_weights, V)

        # 8. 转置回来
        # (batch_size, seq_len, n_heads, d_head)
        output = output.transpose(1, 2)

        # 9. concat所有头
        # (batch_size, seq_len, d_model)
        output = output.contiguous().view(batch_size, seq_len, d_model)

        # 10. 线性变换
        output = self.W_O(output)

        return output, attention_weights


def demo_multi_head_attention():
    """
    演示多头注意力机制
    """
    print("=== 多头注意力机制演示 ===\n")

    # 参数设置
    batch_size = 2
    seq_len = 4
    d_model = 512
    n_heads = 8
    d_head = d_model // n_heads  # 64

    print(f"参数设置:")
    print(f"  batch_size: {batch_size}")
    print(f"  seq_len: {seq_len}")
    print(f"  d_model: {d_model}")
    print(f"  n_heads: {n_heads}")
    print(f"  d_head: {d_head}\n")
    
    # 创建输入数据
    X = torch.randn(batch_size, seq_len, d_model)
    print(f"输入形状: {X.shape}")  # (2, 4, 512)
    
    # 创建多头注意力层
    mha = MultiHeadAttention(d_model, n_heads)
    
    # 计算多头注意力
    output, attention_weights = mha.forward(X)
    
    print(f"输出形状: {output.shape}")  # (2, 4, 512)
    print(f"注意力权重形状: {attention_weights.shape}")  # (2, 8, 4, 4)
    
    # 显示每个头的注意力权重
    print(f"\n=== 每个头的注意力权重 ===")
    for head in range(n_heads):
        print(f"\n头 {head + 1}:")
        print(f"  注意力权重:")
        for i in range(seq_len):
            print(f"    位置{i + 1}: {attention_weights[0, head, i].tolist()}")


def demo_with_mask():
    """
    演示带掩码的多头注意力
    """
    print("\n=== 带掩码的多头注意力演示 ===\n")
    
    # 参数设置
    batch_size = 1
    seq_len = 5
    d_model = 512
    n_heads = 4
    
    # 创建输入数据
    X = torch.randn(batch_size, seq_len, d_model)
    print(f"输入形状: {X.shape}")  # (1, 5, 512)
    
    # 创建下三角掩码
    mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0)
    print(f"掩码形状: {mask.shape}")  # (1, 5, 5)
    print(f"掩码矩阵:")
    print(mask[0].tolist())
    
    # 创建多头注意力层
    mha = MultiHeadAttention(d_model, n_heads)
    
    # 计算多头注意力(带掩码)
    output, attention_weights = mha.forward(X, mask)
    
    print(f"\n输出形状: {output.shape}")  # (1, 5, 512)
    print(f"注意力权重形状: {attention_weights.shape}")  # (1, 4, 5, 5)
    
    # 显示掩码后的注意力权重
    print(f"\n=== 掩码后的注意力权重 ===")
    for head in range(n_heads):
        print(f"\n头 {head + 1}:")
        print(f"  注意力权重:")
        for i in range(seq_len):
            print(f"    位置{i + 1}: {attention_weights[0, head, i].tolist()}")


def demo_step_by_step():
    """
    逐步演示多头注意力的计算过程
    """
    print("\n=== 逐步演示多头注意力 ===\n")
    
    # 简化参数
    batch_size = 1
    seq_len = 3
    d_model = 12  # 简化维度
    n_heads = 2
    d_head = d_model // n_heads  # 6
    
    # 创建输入数据
    X = torch.randn(batch_size, seq_len, d_model)
    print(f"1. 输入 X:")
    print(f"   形状: {X.shape}")  # (1, 3, 12)
    print(f"   数据:\n{X[0]}\n")
    
    # 创建多头注意力层
    mha = MultiHeadAttention(d_model, n_heads)
    
    # 1. 计算Q、K、V
    Q = mha.W_Q(X)
    K = mha.W_K(X)
    V = mha.W_V(X)
    print(f"2. 计算Q、K、V:")
    print(f"   Q形状: {Q.shape}")  # (1, 3, 12)
    print(f"   K形状: {K.shape}")  # (1, 3, 12)
    print(f"   V形状: {V.shape}")  # (1, 3, 12)
    
    # 2. reshape为多头
    Q = Q.view(batch_size, seq_len, n_heads, d_head)
    K = K.view(batch_size, seq_len, n_heads, d_head)
    V = V.view(batch_size, seq_len, n_heads, d_head)
    print(f"3. reshape为多头:")
    print(f"   Q形状: {Q.shape}")  # (1, 3, 2, 6)
    print(f"   K形状: {K.shape}")  # (1, 3, 2, 6)
    print(f"   V形状: {V.shape}")  # (1, 3, 2, 6)
    
    # 3. 转置
    Q = Q.transpose(1, 2)
    K = K.transpose(1, 2)
    V = V.transpose(1, 2)
    print(f"4. 转置:")
    print(f"   Q形状: {Q.shape}")  # (1, 2, 3, 6)
    print(f"   K形状: {K.shape}")  # (1, 2, 3, 6)
    print(f"   V形状: {V.shape}")  # (1, 2, 3, 6)
    
    # 4. 计算注意力分数
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_head)
    print(f"5. 计算注意力分数:")
    print(f"   scores形状: {scores.shape}")  # (1, 2, 3, 3)
    print(f"   头1的注意力分数:")
    print(f"   {scores[0, 0]}\n")
    
    # 5. Softmax
    attention_weights = F.softmax(scores, dim=-1)
    print(f"6. Softmax归一化:")
    print(f"   attention_weights形状: {attention_weights.shape}")  # (1, 2, 3, 3)
    print(f"   头1的注意力权重:")
    print(f"   {attention_weights[0, 0]}\n")
    
    # 6. 加权求和
    output = torch.matmul(attention_weights, V)
    print(f"7. 加权求和:")
    print(f"   output形状: {output.shape}")  # (1, 2, 3, 6)
    print(f"   头1的输出:")
    print(f"   {output[0, 0]}\n")
    
    # 7. 转置回来
    output = output.transpose(1, 2)
    print(f"8. 转置回来:")
    print(f"   output形状: {output.shape}")  # (1, 3, 2, 6)
    
    # 8. concat
    output = output.contiguous().view(batch_size, seq_len, d_model)
    print(f"9. concat所有头:")
    print(f"   output形状: {output.shape}")  # (1, 3, 12)
    
    # 9. 线性变换
    output = mha.W_O(output)
    print(f"10. 线性变换:")
    print(f"   输出形状: {output.shape}")  # (1, 3, 12)


if __name__ == "__main__":
    # 演示1:基本多头注意力
    demo_multi_head_attention()
    
    # 演示2:带掩码的多头注意力
    demo_with_mask()
    
    # 演示3:逐步演示
    demo_step_by_step()