基于前述学习文章,我们就可以搭建起 Transformer 中的 Encoder-Decoder 模块了。
子模块
多头自注意力(MultHeadAttention)
import torch.nn as nn
import torch
'''多头自注意力计算模块'''
class MultiHeadAttention(nn.Module):
def __init__(self, args: ModelArgs, is_causal=False):
# 构造函数
# args: 配置对象
super().__init__()
# 隐藏层维度必须是头数的整数倍,因为后面我们会将输入拆成头数个矩阵
assert args.dim % args.n_heads == 0
# 每个头的维度,等于模型维度除以头的总数。
self.head_dim = args.dim // args.n_heads
self.n_heads = args.n_heads
# Wq, Wk, Wv 参数矩阵,每个参数矩阵为 n_embd x dim
# 这里通过三个组合矩阵来代替了n个参数矩阵的组合,其逻辑在于矩阵内积再拼接其实等同于拼接矩阵再内积,
# 不理解的读者可以自行模拟一下,每一个线性层其实相当于n个参数矩阵的拼接
self.wq = nn.Linear(args.n_embd, self.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.n_embd, self.n_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.n_embd, self.n_heads * self.head_dim, bias=False)
# 输出权重矩阵,维度为 dim x dim(head_dim = dim / n_heads)
self.wo = nn.Linear(self.n_heads * self.head_dim, args.dim, bias=False)
# 注意力的 dropout
self.attn_dropout = nn.Dropout(args.dropout)
# 残差连接的 dropout
self.resid_dropout = nn.Dropout(args.dropout)
self.is_causal = is_causal
# 创建一个上三角矩阵,用于遮蔽未来信息
# 注意,因为是多头注意力,Mask 矩阵比之前我们定义的多一个维度
if is_causal:
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
mask = torch.triu(mask, diagonal=1)
# 注册为模型的缓冲区
self.register_buffer("mask", mask)
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
# 获取批次大小和序列长度,[batch_size, seq_len, dim]
bsz, seqlen, _ = q.shape
# 计算查询(Q)、键(K)、值(V),输入通过参数矩阵层,维度为 (B, T, n_embed) x (n_embed, dim) -> (B, T, dim)
xq, xk, xv = self.wq(q), self.wk(k), self.wv(v)
# 将 Q、K、V 拆分成多头,维度为 (B, T, n_head, dim // n_head),然后交换维度,变成 (B, n_head, T, dim // n_head)
# 因为在注意力计算中我们是取了后两个维度参与计算
# 为什么要先按B*T*n_head*C//n_head展开再互换1、2维度而不是直接按注意力输入展开,是因为view的展开方式是直接把输入全部排开,
# 然后按要求构造,可以发现只有上述操作能够实现我们将每个头对应部分取出来的目标
xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_heads, self.head_dim)
xq = xq.transpose(1, 2)
xk = xk.transpose(1, 2)
xv = xv.transpose(1, 2)
# 注意力计算
# 计算 QK^T / sqrt(d_k),维度为 (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
# 掩码自注意力必须有注意力掩码
if self.is_causal:
assert hasattr(self, 'mask')
# 这里截取到序列长度,因为有些序列可能比 max_seq_len 短
scores = scores + self.mask[:, :, :seqlen, :seqlen]
# 计算 softmax,维度为 (B, nh, T, T)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
# 做 Dropout
scores = self.attn_dropout(scores)
# V * Score,维度为(B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
output = torch.matmul(scores, xv)
# 恢复时间维度并合并头。
# 将多头的结果拼接起来, 先交换维度为 (B, T, n_head, dim // n_head),再拼接成 (B, T, n_head * dim // n_head)
# contiguous 函数用于重新开辟一块新内存存储,因为Pytorch设置先transpose再view会报错,
# 因为view直接基于底层存储得到,然而transpose并不会改变底层存储,因此需要额外存储
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
# 最终投影回残差流。
output = self.wo(output)
output = self.resid_dropout(output)
return output
前馈神经网络(FNN)
class MLP(nn.Module):
'''前馈神经网络'''
def __init__(self, dim: int, hidden_dim: int, dropout: float):
super().__init__()
# 定义第一层线性变换,从输入维度到隐藏维度
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
# 定义第二层线性变换,从隐藏维度到输入维度
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
# 定义dropout层,用于防止过拟合
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# 前向传播函数
# 首先,输入x通过第一层线性变换和RELU激活函数
# 最后,通过第二层线性变换和dropout层
return self.dropout(self.w2(F.relu(self.w1(x))))
层归一化
class LayerNorm(nn.Module):
''' Layer Norm 层'''
def __init__(self, features, eps=1e-6):
super().__init__()
# 线性矩阵做映射
self.a_2 = nn.Parameter(torch.ones(features))
self.b_2 = nn.Parameter(torch.zeros(features))
self.eps = eps
def forward(self, x):
# 在统计每个样本所有维度的值,求均值和方差
mean = x.mean(-1, keepdim=True) # mean: [bsz, max_len, 1]
std = x.std(-1, keepdim=True) # std: [bsz, max_len, 1]
# 注意这里也在最后一个维度发生了广播
return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
残差连接
# 注意力计算
h = x + self.attention.forward(self.attention_norm(x))
# 经过前馈神经网络
out = h + self.feed_forward.forward(self.fnn_norm(h))
编码器
Transformer 中的 Encoder 模块由 N 个 EncoderLayer 组成,每个 EncoderLayer 包括一个 Attention 层、一个 MLP 层。
class EncoderLayer(nn.Module):
'''Encoder层'''
def __init__(self, args):
super().__init__()
# 一个 Layer 中有两个 LayerNorm,分别在 Attention 之前和 MLP 之前
self.attention_norm = LayerNorm(args.n_embd)
# Encoder 不需要掩码,传入 is_causal=False
self.attention = MultiHeadAttention(args, is_causal=False)
self.fnn_norm = LayerNorm(args.n_embd)
self.feed_forward = MLP(args.dim, args.dim, args.dropout)
def forward(self, x):
# Layer Norm
norm_x = self.attention_norm(x)
# 自注意力
h = x + self.attention.forward(norm_x, norm_x, norm_x)
# 经过前馈神经网络
out = h + self.feed_forward.forward(self.fnn_norm(h))
return out
class Encoder(nn.Module):
'''Encoder 块'''
def __init__(self, args):
super(Encoder, self).__init__()
# 一个 Encoder 由 N 个 Encoder Layer 组成
self.layers = nn.ModuleList([EncoderLayer(args) for _ in range(args.n_layer)])
self.norm = LayerNorm(args.n_embd)
def forward(self, x):
"分别通过 N 层 Encoder Layer"
for layer in self.layers:
x = layer(x)
return self.norm(x)
解码器
Transformer 中的 Decoder 模块由 N 个 DecoderLayer 组成。DecoderLayer 由 2 个注意力层和一个神经网络层组成,而且第 1 个注意力层是掩码自注意力层,第 2 个注意力层是一个多头注意力层,该层使用第一个注意力层的输出作为 query,使用 Encoder 的输出作为 key 和 value,以此来计算注意力分数;然后,再经过前馈神经网络。
class DecoderLayer(nn.Module):
'''解码层'''
def __init__(self, args):
super().__init__()
# 一个 Layer 中有三个 LayerNorm,分别在 Mask Attention 之前、Self Attention 之前和 MLP 之前
self.attention_norm_1 = LayerNorm(args.n_embd)
# Decoder 的第一个部分是 Mask Attention,传入 is_causal=True
self.mask_attention = MultiHeadAttention(args, is_causal=True)
self.attention_norm_2 = LayerNorm(args.n_embd)
# Decoder 的第二个部分是 类似于 Encoder 的 Attention,传入 is_causal=False
self.attention = MultiHeadAttention(args, is_causal=False)
self.ffn_norm = LayerNorm(args.n_embd)
# 第三个部分是 MLP
self.feed_forward = MLP(args.dim, args.dim, args.dropout)
def forward(self, x, enc_out):
# Layer Norm
norm_x = self.attention_norm_1(x)
# 掩码自注意力
x = x + self.mask_attention.forward(norm_x, norm_x, norm_x)
# 多头注意力
norm_x = self.attention_norm_2(x)
h = x + self.attention.forward(norm_x, enc_out, enc_out)
# 经过前馈神经网络
out = h + self.feed_forward.forward(self.ffn_norm(h))
return out
class Decoder(nn.Module):
'''解码器'''
def __init__(self, args):
super(Decoder, self).__init__()
# 一个 Decoder 由 N 个 Decoder Layer 组成
self.layers = nn.ModuleList([DecoderLayer(args) for _ in range(args.n_layer)])
self.norm = LayerNorm(args.n_embd)
def forward(self, x, enc_out):
"Pass the input (and mask) through each layer in turn."
for layer in self.layers:
x = layer(x, enc_out)
return self.norm(x)
学习内容(Datawhale 开源学习项目):Happy-LLM