import torch.nn as nn
import numpy as np
import torch
import math
# 多头注意力
class MHA(nn.Module):
def __init__(self, num_head, dimension_k, dimension_v, d_k, d_v, d_o):
# d_k表示head dimension,d_k * num_head 就是embedding的长度
super().__init__()
self.num_head = num_head
self.d_k = d_k
self.d_v = d_v
self.d_o = d_o
self.fc_q = nn.Linear(dimension_k, num_head * d_k)
self.fc_k = nn.Linear(dimension_k, num_head * d_k)
self.fc_v = nn.Linear(dimension_v, num_head * d_v)
self.fc_o = nn.Linear(num_head * d_v, d_o)
self.softmax = nn.Softmax(dim=2)
def forward(self, q, k, v, mask):
batch, n_q, dimension_q = q.size()
batch, n_k, dimension_k = k.size()
batch, n_v, dimension_v = v.size()
q = self.fc_q(q)
k = self.fc_k(k)
v = self.fc_v(v)
q = q.view(batch, n_q, self.num_head, self.d_k).permute(2, 0, 1, 3).contiguous().view(-1, n_q, self.d_k)
k = k.view(batch, n_k, self.num_head, self.d_k).permute(2, 0, 1, 3).contiguous().view(-1, n_k, self.d_k)
v = v.view(batch, n_v, self.num_head, self.d_v).permute(2, 0, 1, 3).contiguous().view(-1, n_v, self.d_v)
attention = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.d_k)
mask = mask.repeat(self.num_head, 1, 1)
attention = attention + mask
attention = self.softmax(attention)
output = torch.matmul(attention, v)
output = output.view(self.num_head, batch, n_q, self.d_v).permute(1, 2, 0, 3).contiguous().view(batch, n_q, -1)
output = self.fc_o(output)
return attention, output
# Multi query attention
class MQA(nn.Module):
def __init__(self, num_head, dimension_k, dimension_v, d_k, d_v, d_o):
super().__init__()
self.num_head = num_head
self.d_k = d_k
self.d_v = d_v
self.d_o = d_o
self.fc_q = nn.Linear(dimension_k, num_head * d_k)
self.fc_k = nn.Linear(dimension_k, d_k)
self.fc_v = nn.Linear(dimension_v, d_v)
self.fc_o = nn.Linear(num_head * d_v, d_o)
self.softmax = nn.Softmax(dim=2)
def forward(self, q, k, v, mask):
batch, n_q, dimension_q = q.size()
batch, n_k, dimension_k = k.size()
batch, n_v, dimension_v = v.size()
q = self.fc_q(q)
k = self.fc_k(k)
v = self.fc_v(v)
q = q.view(batch, n_q, self.num_head, self.d_k).permute(2, 0, 1, 3).contiguous().view(-1, n_q, self.d_k)
k = k.repeat(self.num_head, 1, 1)
v = v.repeat(self.num_head, 1, 1)
attention = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.d_k)
mask = mask.repeat(self.num_head, 1, 1)
attention = attention + mask
attention = self.softmax(attention)
output = torch.matmul(attention, v)
output = output.view(self.num_head, batch, n_q, self.d_v).permute(1, 2, 0, 3).contiguous().view(batch, n_q, -1)
output = self.fc_o(output)
return attention, output
batch = 10
num_head = 8
n_q, n_k, n_v = 2, 4, 4 # sequence 长度
dimension_q, dimension_k, dimension_v = 128, 128, 64 # embedding的长度
d_k, d_v, d_o = 16, 16, 8
q = torch.randn(batch, n_q, dimension_q)
k = torch.randn(batch, n_k, dimension_k)
v = torch.randn(batch, n_v, dimension_v)
mask = torch.full((batch, n_q, n_k), -np.inf)
mask = torch.triu(mask)
mha = MHA(num_head, dimension_k, dimension_v, d_k, d_v, d_o)
attention, output = mha(q, k, v, mask)
print(attention.size(), output.size())
mqa = MQA(num_head, dimension_k, dimension_v, d_k, d_v, d_o)
attention, output = mqa(q, k, v, mask)
print(attention.size(), output.size())
核心背景:这段代码实现了Transformer模型的核心组件——注意力机制,分为两种:MHA(多头注意力,最常用)和MQA(多查询注意力,更高效),两者逻辑相似,核心区别在“K/V是否共享”,后面会逐行对比。
第一部分:导入所需库(逐行拆解)
import torch.nn as nn
import numpy as np
import torch
import math
逐行详细解释(含用途+语法+区别):
-
import torch.nn as nn- PyTorch专属:torch是PyTorch的核心库,nn是torch下的“neural network(神经网络)”模块
- 作用:提供所有搭建神经网络需要的组件(比如线性层、激活函数、模型父类等)
- as nn:给这个模块起一个简称“nn”,后面用的时候不用写全称torch.nn,简化代码(基础Python的“别名”用法)
-
import numpy as np- 普通Python库:numpy是Python用于数值计算的核心库,主要处理数组、矩阵
- 作用:这里只用来生成“负无穷”(后面做mask用),你只要知道它能生成特殊数值即可
- as np:同样是起别名,简化写法
-
import torch- PyTorch核心库:所有PyTorch的核心功能都在这里,比如生成张量(类似Python的列表/数组,但能做GPU加速、矩阵运算)、矩阵乘法等
- 作用:后面生成输入数据(Q/K/V)、做矩阵运算、处理张量形状,都要用到这个库
-
import math- 普通Python内置库:提供基础数学运算(比如平方根、三角函数等)
- 作用:这里只用来计算“根号下d_k”(注意力计算的缩放因子),避免数值太大导致softmax失效
第二部分:定义MHA(多头注意力)类(核心代码,逐行拆解)
先明确:类(class)是基础Python语法,用来封装一组相关的属性(变量)和方法(函数);这里的MHA类,是一个“注意力模型”,里面包含了“初始化模型结构”和“执行计算”两个核心方法。
第一步:定义类,继承PyTorch的模型父类
# 多头注意力
class MHA(nn.Module):
-
# 多头注意力:单行注释(基础Python),说明这个类的功能,不影响代码运行 -
class MHA(nn.Module):class MHA::定义一个名为MHA的类(基础Python语法)nn.Module:PyTorch专属,这是PyTorch中所有神经网络模型的“父类”- 继承(nn.Module)的作用:让MHA类自动拥有PyTorch模型的所有基础功能(比如参数管理、前向传播、GPU加速等),不用我们自己写
- 注意:类定义后面有冒号,下面的代码要缩进(基础Python缩进规则,缩进代表属于这个类)
第二步:__init__方法(初始化模型结构,定义所有层和参数)
__init__是基础Python的“构造函数”,作用是:当我们创建MHA类的实例(比如后面的mha = MHA(...))时,自动执行这个方法,初始化模型的参数和网络层。
def __init__(self, num_head, dimension_k, dimension_v, d_k, d_v, d_o):
# d_k表示head dimension,d_k * num_head 就是embedding的长度
super().__init__()
self.num_head = num_head
self.d_k = d_k
self.d_v = d_v
self.d_o = d_o
self.fc_q = nn.Linear(dimension_k, num_head * d_k)
self.fc_k = nn.Linear(dimension_k, num_head * d_k)
self.fc_v = nn.Linear(dimension_v, num_head * d_v)
self.fc_o = nn.Linear(num_head * d_v, d_o)
self.softmax = nn.Softmax(dim=2)
逐行拆解(每一行都讲透):
-
def __init__(self, num_head, dimension_k, dimension_v, d_k, d_v, d_o):-
def:定义函数(基础Python),这里是类里面的方法,叫__init__(固定名称,不能改) -
self:基础Python类的固定参数,代表“类的实例本身”,后面用self.xxx,就是给这个实例设置属性(比如self.num_head,就是实例的“头数量”属性) -
后面的6个参数(num_head, dimension_k, ..., d_o):是我们创建MHA实例时,需要手动传入的参数,每个参数的含义(关键,必须懂):
num_head:注意力的“头数量”(比如8头、16头),多头的作用是让模型能关注到不同维度的信息dimension_k:输入“键(K)”的特征维度(比如后面测试代码中是128)dimension_v:输入“值(V)”的特征维度(比如后面测试代码中是64)d_k:每个头的Q/K维度(比如后面测试代码中是16),注释里写了“d_k * num_head 就是embedding的长度”,意思是:所有头的Q/K合起来,维度等于输入的embedding长度d_v:每个头的V维度(比如后面测试代码中是16)d_o:注意力机制的最终输出维度(比如后面测试代码中是8)
-
-
# d_k表示head dimension,d_k * num_head 就是embedding的长度:单行注释,解释d_k的含义,帮助理解代码,不影响运行 -
super().__init__():- PyTorch专属+基础Python:super()是调用父类的方法,这里调用父类(nn.Module)的__init__方法
- 作用:必须写!因为我们继承了nn.Module,只有调用父类的初始化,才能让MHA类拥有PyTorch模型的所有功能(比如参数管理)
- 固定写法:只要是继承nn.Module的类,__init__方法里第一行必须写这个
-
self.num_head = num_head、self.d_k = d_k、self.d_v = d_v、self.d_o = d_o:- 基础Python:把传入的参数,赋值给类的实例(self),变成实例的属性
- 作用:后面的方法(比如forward)需要用到这些参数,这样就能直接通过self.xxx调用,不用再重复传入
-
self.fc_q = nn.Linear(dimension_k, num_head * d_k):-
self.fc_q:给这个线性层起个名字叫fc_q(fc=fully connected,全连接层),作为实例的属性 -
nn.Linear(in_features, out_features):PyTorch专属,定义一个“线性层”(也叫全连接层) -
线性层的作用:做一次线性变换(y = x × W + b,W是权重,b是偏置,PyTorch会自动初始化W和b,不用我们管)
-
参数解释:
in_features=dimension_k:输入到这个线性层的数据维度(这里是Q的输入维度,和K的输入维度一样,都是dimension_k)out_features=num_head * d_k:线性层的输出维度(把Q映射到“头数量×每个头的Q维度”,这样才能拆分成多个头)
-
简单说:这个线性层的作用,是把原始的Q,转换成适合多头注意力计算的维度
-
-
self.fc_k = nn.Linear(dimension_k, num_head * d_k):- 和fc_q逻辑完全一样,只是作用于“键(K)”
- 输入维度:dimension_k(K的原始维度),输出维度:num_head * d_k(拆分多头后的K维度)
-
self.fc_v = nn.Linear(dimension_v, num_head * d_v):- 作用于“值(V)”的线性层
- 输入维度:dimension_v(V的原始维度),输出维度:num_head * d_v(拆分多头后的V维度)
- 注意:V的输入维度(dimension_v)可以和Q/K的维度(dimension_k)不一样(比如后面测试代码中,dimension_k=128,dimension_v=64)
-
self.fc_o = nn.Linear(num_head * d_v, d_o):- 最终的输出线性层,作用是:把多个头的计算结果“拼接”后,映射到我们需要的最终输出维度d_o
- 输入维度:num_head * d_v(所有头的V结果拼接后的维度)
- 输出维度:d_o(最终想要的输出维度)
-
self.softmax = nn.Softmax(dim=2):nn.Softmax():PyTorch专属,激活函数,作用是把一组数值归一化到0~1之间,且所有数值的和为1(用来计算注意力权重)dim=2:关键参数,指定“在哪个维度上做归一化”- 这里的dim=2,对应后面的注意力矩阵形状(batch×num_head, 序列长度, 序列长度),dim=2就是在“序列长度”这个维度上归一化,确保每个位置的注意力权重和为1
第三步:forward方法(前向传播,真正执行注意力计算)
forward是PyTorch模型的“固定方法名”,作用是:当我们把数据传入模型(比如mha(q, k, v, mask))时,自动执行forward方法,完成一次前向计算,输出结果。
核心逻辑:输入Q/K/V → 线性变换 → 拆分多头 → 计算注意力权重 → 用权重加权V → 拼接多头 → 最终线性变换 → 输出结果
def forward(self, q, k, v, mask):
batch, n_q, dimension_q = q.size()
batch, n_k, dimension_k = k.size()
batch, n_v, dimension_v = v.size()
q = self.fc_q(q)
k = self.fc_k(k)
v = self.fc_v(v)
q = q.view(batch, n_q, self.num_head, self.d_k).permute(2, 0, 1, 3).contiguous().view(-1, n_q, self.d_k)
k = k.view(batch, n_k, self.num_head, self.d_k).permute(2, 0, 1, 3).contiguous().view(-1, n_k, self.d_k)
v = v.view(batch, n_v, self.num_head, self.d_v).permute(2, 0, 1, 3).contiguous().view(-1, n_v, self.d_v)
attention = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.d_k)
mask = mask.repeat(self.num_head, 1, 1)
attention = attention + mask
attention = self.softmax(attention)
output = torch.matmul(attention, v)
output = output.view(self.num_head, batch, n_q, self.d_v).permute(1, 2, 0, 3).contiguous().view(batch, n_q, -1)
output = self.fc_o(output)
return attention, output
逐行拆解(重点讲PyTorch语法和计算逻辑,每个步骤的目的都讲清楚):
-
def forward(self, q, k, v, mask):-
forward:固定方法名,PyTorch会自动识别,不能随便改 -
self:依然是类的实例本身 -
4个输入参数(必须懂,对应后面的测试数据):
q:查询(Query),形状是 [batch, n_q, dimension_q](后面会具体讲形状含义)k:键(Key),形状是 [batch, n_k, dimension_k]v:值(Value),形状是 [batch, n_v, dimension_v]mask:掩码,形状是 [batch, n_q, n_k],作用是“屏蔽无效位置”(比如解码器中,不能看到未来的序列),后面会详细讲
-
-
batch, n_q, dimension_q = q.size():-
q.size():PyTorch专属,获取张量q的“形状”(类似Python列表的len(),但能获取多维数组的形状) -
张量(Tensor):PyTorch的核心数据结构,类似Python的列表/数组,但支持GPU加速、矩阵运算,这里的q/k/v都是张量
-
形状含义(关键,必须懂):
batch:批次大小,意思是“一次同时处理多少条数据”(比如后面测试代码中batch=10,就是一次处理10条数据)n_q:Q的“序列长度”(比如后面测试代码中n_q=2,就是每个Q有2个token)dimension_q:Q的“特征维度”(和前面的dimension_k一致,比如128)
-
举例:如果q的形状是 [10, 2, 128],就表示:10个批次,每个批次有2个序列,每个序列的特征维度是128
-
后面两行
k.size()、v.size()逻辑完全一样,分别获取k和v的形状,拆分出batch、n_k(k的序列长度)、n_v(v的序列长度)等参数
-
-
q = self.fc_q(q)、k = self.fc_k(k)、v = self.fc_v(v):- 调用前面__init__中定义的线性层,对Q/K/V分别做线性变换
- 举例:q原本的形状是 [batch, n_q, dimension_q] = [10, 2, 128],经过fc_q(输入128,输出8×16=128)变换后,形状还是 [10, 2, 128](因为dimension_k=128,num_head×d_k=8×16=128)
- 目的:把原始Q/K/V,转换成适合多头注意力计算的维度(虽然形状没变,但内部数值已经做了线性映射)
-
q = q.view(batch, n_q, self.num_head, self.d_k).permute(2, 0, 1, 3).contiguous().view(-1, n_q, self.d_k):-
这一行是“拆分多头”的核心代码,很长,拆成4个步骤逐字讲,每个步骤都讲目的和效果:
-
步骤1:
q.view(batch, n_q, self.num_head, self.d_k).view(shape):PyTorch专属,改变张量的形状,但不改变内部数据(类似Python numpy的reshape)- 原始q的形状:[batch, n_q, num_head×d_k] → 比如 [10, 2, 128]
- 变换后形状:[batch, n_q, num_head, d_k] → 比如 [10, 2, 8, 16]
- 目的:把“num_head×d_k”这个维度,拆分成“num_head(头数)”和“d_k(每个头的维度)”,这样就能把Q拆分成多个头
-
步骤2:
.permute(2, 0, 1, 3).permute(维度顺序):PyTorch专属,交换张量的维度顺序(比如原来的维度是0:batch, 1:n_q, 2:num_head, 3:d_k,permute(2,0,1,3)就是把维度2(num_head)放到最前面)- 变换后形状:[num_head, batch, n_q, d_k] → 比如 [8, 10, 2, 16]
- 目的:把“头数”维度放到最前面,方便后续所有头并行计算(PyTorch会自动并行处理多个头)
-
步骤3:
.contiguous()- PyTorch专属:确保张量的内存是“连续的”
- 原因:permute交换维度后,张量的内存会变得不连续,后面再用view()会报错,所以必须加contiguous(),让内存重新整理成连续的
- 不用深究内存细节,记住:permute之后,加一个contiguous(),再用view(),就不会报错
-
步骤4:
.view(-1, n_q, self.d_k)-1:PyTorch专属,表示“自动计算这个维度的大小”,不用我们手动算- 前面的形状是 [num_head, batch, n_q, d_k],view(-1, n_q, d_k) 就是把前两个维度(num_head和batch)合并成一个维度
- 合并后形状:[num_head×batch, n_q, d_k] → 比如 [8×10=80, 2, 16]
- 目的:把“多个头+多个批次”合并成一个大批次,让PyTorch能一次性并行计算所有头、所有批次的Q,提高效率
-
总结这一行:把Q从 [batch, n_q, num_head×d_k] → 拆分成 [num_head×batch, n_q, d_k],为后续注意力计算做准备
-
后面两行
k.view(...)、v.view(...)逻辑完全一样,只是针对k和v,最终k的形状是 [num_head×batch, n_k, d_k],v的形状是 [num_head×batch, n_v, d_v]
-
-
attention = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.d_k):-
这一行是“计算注意力相似度”的核心,拆成3部分讲:
-
部分1:
torch.matmul(a, b)- PyTorch专属:矩阵乘法(和数学中的矩阵相乘一致)
- 这里a是q(形状 [num_head×batch, n_q, d_k]),b是k.transpose(-1, -2)(后面讲)
- 矩阵乘法规则:a的最后一个维度,必须和b的倒数第二个维度相等;结果的形状是 [a的前n-1个维度, b的最后一个维度]
- 举例:q是 [80, 2, 16],k.transpose后是 [80, 16, 4],matmul后形状是 [80, 2, 4]
-
部分2:
k.transpose(-1, -2).transpose(dim1, dim2):PyTorch专属,交换两个维度的位置-1:表示“最后一个维度”,-2:表示“倒数第二个维度”- k原本的形状是 [num_head×batch, n_k, d_k],transpose(-1,-2)后,形状变成 [num_head×batch, d_k, n_k]
- 目的:因为矩阵乘法需要“q的最后一个维度(d_k)”和“k的倒数第二个维度(d_k)”相等,所以要转置k的最后两个维度
-
部分3:
/ math.sqrt(self.d_k)- math.sqrt():普通Python数学方法,计算平方根(比如d_k=16,sqrt(16)=4)
- 目的:“缩放因子”,防止d_k太大时,q和k的矩阵乘法结果数值太大,导致softmax后数值趋近于0或1,梯度消失(简单说:避免计算出错)
-
总结这一行:计算每个Q和每个K的相似度,得到注意力矩阵,形状是 [num_head×batch, n_q, n_k](比如 [80, 2, 4])
-
-
mask = mask.repeat(self.num_head, 1, 1):.repeat(次数):PyTorch专属,复制张量的指定维度(括号里的3个数字,对应张量的3个维度)- 原始mask的形状:[batch, n_q, n_k](比如 [10, 2, 4])
repeat(self.num_head, 1, 1):表示“第一个维度复制num_head次,第二、三个维度复制1次(不复制)”- 复制后mask的形状:[num_head×batch, n_q, n_k](比如 [8×10=80, 2, 4]),和前面的attention矩阵形状一致
- 目的:因为attention矩阵是“num_head×batch”开头,而原始mask是“batch”开头,为了能和attention矩阵相加,需要把mask复制num_head次,匹配维度
-
attention = attention + mask:- 基础Python的加法,但这里是“张量对应位置相加”(PyTorch自动支持)
- mask的作用:mask中无效位置的值是“-np.inf(负无穷)”,attention矩阵中对应位置加上负无穷后,后续softmax会把这个位置的权重变成0(因为softmax(-inf)=0)
- 简单说:通过加法,屏蔽掉无效位置的注意力(比如解码器中,不能关注到未来的序列,就把未来位置的mask设为负无穷)
-
attention = self.softmax(attention):- 调用前面定义的softmax激活函数,在dim=2(序列长度维度)上做归一化
- 效果:attention矩阵中每个位置的数值,都会变成0~1之间,且每个“n_k维度”(第二个序列长度)的和为1
- 目的:把“相似度”转换成“注意力权重”,权重越大,说明这个K对应的V越重要
-
output = torch.matmul(attention, v):- 矩阵乘法:用注意力权重(attention)对值(v)做加权求和
- attention的形状:[num_head×batch, n_q, n_k]
- v的形状:[num_head×batch, n_v, d_v](注意:n_k和n_v通常相等,后面测试代码中都是4)
- 相乘后output的形状:[num_head×batch, n_q, d_v](比如 [80, 2, 16])
- 目的:把重要的V(权重高的)保留,不重要的V(权重低的)弱化,得到每个Q对应的加权结果
-
output = output.view(self.num_head, batch, n_q, self.d_v).permute(1, 2, 0, 3).contiguous().view(batch, n_q, -1):-
这一行是“拼接多头”的核心代码,和前面“拆分多头”的逻辑相反,拆成4个步骤:
-
步骤1:
output.view(self.num_head, batch, n_q, self.d_v)- 原始output形状:[num_head×batch, n_q, d_v] → 比如 [80, 2, 16]
- 变换后形状:[num_head, batch, n_q, d_v] → 比如 [8, 10, 2, 16]
- 目的:把“num_head×batch”这个合并的维度,拆分成“num_head”和“batch”,恢复成拆分多头前的维度顺序(只是头数在最前面)
-
步骤2:
.permute(1, 2, 0, 3)- 交换维度顺序:把原来的 [num_head, batch, n_q, d_v] → 变成 [batch, n_q, num_head, d_v]
- 目的:把“batch”维度放回最前面,“头数”维度放到倒数第二个,方便后续拼接
-
步骤3:
.contiguous()- 和前面一样,permute后内存不连续,加contiguous()避免后续view()报错
-
步骤4:
.view(batch, n_q, -1)-1:自动计算维度大小,这里是把“num_head”和“d_v”两个维度合并成一个维度(num_head×d_v)- 变换后形状:[batch, n_q, num_head×d_v] → 比如 [10, 2, 8×16=128]
- 目的:把多个头的计算结果,拼接成一个完整的特征向量,方便后续的线性层处理
-
-
output = self.fc_o(output):- 调用最终的线性层,把拼接后的特征向量(形状 [batch, n_q, num_head×d_v]),映射到我们需要的最终输出维度d_o
- 举例:拼接后是 [10, 2, 128],fc_o的输入是128,输出是8,所以最终output形状是 [10, 2, 8]
-
return attention, output:-
基础Python函数的return语句,返回两个结果:
attention:注意力权重矩阵(形状 [num_head×batch, n_q, n_k]),可以用来查看模型关注了哪些位置output:注意力机制的最终输出(形状 [batch, n_q, d_o]),作为后续网络的输入
-
第三部分:定义MQA(多查询注意力)类(逐行拆解,重点讲和MHA的区别)
MQA和MHA的核心逻辑完全一致,唯一的区别是:MHA的K/V是“多头”(每个头都有独立的K/V),MQA的K/V是“单头”(所有头共享一个K/V) ,所以代码只有少量差异,重点讲差异部分,相同部分简要带过。
# Multi query attention
class MQA(nn.Module):
def __init__(self, num_head, dimension_k, dimension_v, d_k, d_v, d_o):
super().__init__()
self.num_head = num_head
self.d_k = d_k
self.d_v = d_v
self.d_o = d_o
self.fc_q = nn.Linear(dimension_k, num_head * d_k)
self.fc_k = nn.Linear(dimension_k, d_k)
self.fc_v = nn.Linear(dimension_v, d_v)
self.fc_o = nn.Linear(num_head * d_v, d_o)
self.softmax = nn.Softmax(dim=2)
逐行拆解(重点讲和MHA的区别):
-
# Multi query attention:注释,说明这个类是多查询注意力 -
class MQA(nn.Module)::和MHA一样,继承PyTorch的模型父类,不多说 -
def __init__(self, ...)::参数和MHA完全一样,不多说 -
super().__init__():固定写法,不多说 -
self.num_head = num_head等4行:和MHA一样,保存参数,不多说 -
self.fc_q = nn.Linear(dimension_k, num_head * d_k):和MHA完全一样,Q依然是“多头”,输出维度是num_head×d_k -
self.fc_k = nn.Linear(dimension_k, d_k):和MHA的核心区别1- MHA的fc_k:输出维度是 num_head×d_k(多头K)
- MQA的fc_k:输出维度是 d_k(单头K)
- 原因:MQA的所有头,共享一个K,所以不需要拆分多头,只需要单头维度即可
-
self.fc_v = nn.Linear(dimension_v, d_v):和MHA的核心区别2- MHA的fc_v:输出维度是 num_head×d_v(多头V)
- MQA的fc_v:输出维度是 d_v(单头V)
- 原因:和K一样,MQA的所有头共享一个V
-
self.fc_o = nn.Linear(num_head * d_v, d_o)、self.softmax = nn.Softmax(dim=2):和MHA完全一样,不多说
MQA的forward方法(逐行拆解,重点讲和MHA的区别)
def forward(self, q, k, v, mask):
batch, n_q, dimension_q = q.size()
batch, n_k, dimension_k = k.size()
batch, n_v, dimension_v = v.size()
q = self.fc_q(q)
k = self.fc_k(k)
v = self.fc_v(v)
q = q.view(batch, n_q, self.num_head, self.d_k).permute(2, 0, 1, 3).contiguous().view(-1, n_q, self.d_k)
k = k.repeat(self.num_head, 1, 1)
v = v.repeat(self.num_head, 1, 1)
attention = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.d_k)
mask = mask.repeat(self.num_head, 1, 1)
attention = attention + mask
attention = self.softmax(attention)
output = torch.matmul(attention, v)
output = output.view(self.num_head, batch, n_q, self.d_v).permute(1, 2, 0, 3).contiguous().view(batch, n_q, -1)
output = self.fc_o(output)
return attention, output
逐行拆解(相同部分简要带过,重点讲差异):
-
前6行(def forward到v = self.fc_v(v)):和MHA完全一样,获取张量形状、对Q/K/V做线性变换,不多说
-
q = q.view(...):和MHA完全一样,Q依然要拆分多头,形状变成 [num_head×batch, n_q, d_k],不多说 -
k = k.repeat(self.num_head, 1, 1):和MHA的核心区别3- MHA的k:是通过view()拆分多头,变成 [num_head×batch, n_k, d_k]
- MQA的k:因为fc_k输出的是单头(形状 [batch, n_k, d_k]),所以需要用repeat()复制num_head次,变成 [num_head×batch, n_k, d_k]
- 目的:Q是多头(num_head×batch),K需要和Q的维度匹配,才能做矩阵乘法,所以把单头K复制num_head次,让所有头共享同一个K
-
v = v.repeat(self.num_head, 1, 1):和MHA的核心区别4- 和K一样,MQA的V是单头(形状 [batch, n_v, d_v]),复制num_head次,变成 [num_head×batch, n_v, d_v],和Q的维度匹配
- 目的:所有头共享同一个V
-
后面的代码(计算attention、mask、softmax、output拼接、fc_o):和MHA完全一样,不多说
MQA和MHA的核心区别总结(必懂):
- MHA:Q/K/V都是多头(每个头有独立的Q/K/V),参数多、计算慢,但效果可能更好
- MQA:Q是多头,K/V是单头(所有头共享),参数少、计算快,适合需要高效推理的场景(比如大模型部署)
- 最终输出形状:两者完全一样,只是内部计算方式不同
第四部分:测试代码(逐行拆解,运行过程全讲解)
这部分代码的作用:设置超参数、生成模拟输入数据、创建MHA和MQA实例、运行模型、打印输出形状,验证模型是否能正常运行。
batch = 10
num_head = 8
n_q, n_k, n_v = 2, 4, 4 # sequence 长度
dimension_q, dimension_k, dimension_v = 128, 128, 64 # embedding的长度
d_k, d_v, d_o = 16, 16, 8
q = torch.randn(batch, n_q, dimension_q)
k = torch.randn(batch, n_k, dimension_k)
v = torch.randn(batch, n_v, dimension_v)
mask = torch.full((batch, n_q, n_k), -np.inf)
mask = torch.triu(mask)
mha = MHA(num_head, dimension_k, dimension_v, d_k, d_v, d_o)
attention, output = mha(q, k, v, mask)
print(attention.size(), output.size())
mqa = MQA(num_head, dimension_k, dimension_v, d_k, d_v, d_o)
attention, output = mqa(q, k, v, mask)
print(attention.size(), output.size())
逐行拆解(每一行的目的、结果都讲清楚):
-
batch = 10:设置批次大小为10(一次处理10条数据) -
num_head = 8:设置注意力头数为8 -
n_q, n_k, n_v = 2, 4, 4:设置序列长度- n_q:Q的序列长度为2(每个Q有2个token)
- n_k:K的序列长度为4(每个K有4个token)
- n_v:V的序列长度为4(每个V有4个token),通常n_k = n_v
-
dimension_q, dimension_k, dimension_v = 128, 128, 64:设置输入特征维度- dimension_q:Q的特征维度为128
- dimension_k:K的特征维度为128(和Q一致)
- dimension_v:V的特征维度为64(和Q/K可以不一致)
-
d_k, d_v, d_o = 16, 16, 8:设置单头维度和最终输出维度- d_k:每个头的Q/K维度为16(8头×16=128,和dimension_k一致)
- d_v:每个头的V维度为16(8头×16=128)
- d_o:最终输出维度为8
-
q = torch.randn(batch, n_q, dimension_q):torch.randn(shape):PyTorch专属,生成形状为shape、服从“标准正态分布”(均值0,方差1)的随机张量- q的形状:[10, 2, 128](10批次,每个批次2个序列,每个序列128维特征)
- 目的:模拟真实的Q输入数据(实际使用时,这里会替换成真实的文本/图像嵌入数据)
-
k = torch.randn(batch, n_k, dimension_k):生成模拟K数据,形状 [10, 4, 128] -
v = torch.randn(batch, n_v, dimension_v):生成模拟V数据,形状 [10, 4, 64] -
mask = torch.full((batch, n_q, n_k), -np.inf):torch.full(shape, value):