前言
本文参考大神蓝斯诺特的代码实现简单结构的 LLAMA 模型,不仅复现了 LLAMA 的结构,并且在实现模型结构的过程中,将涉及到的技术点都进行了介绍,方便大家学习。
为了简单实现,这里将层数修改为 4 ,隐层维度修改为1024 。
LlamaRMSNorm
import math,torch
class LlamaRMSNorm(torch.nn.Module):
def __init__(self):
super().__init__()
self.weight = torch.nn.Parameter(torch.ones(1024))
def forward(self, x):
var = x.pow(2).mean(2, keepdim=True)
x = x * (var + 1e-5).rsqrt()
return self.weight * x
RMSNorm(Root Mean Square Normalization)的数学原理是一种归一化方法,它通过规范化输入的均方根(RMS)值来平衡特征的尺度。它既保留了归一化层的重要功能,又简化了计算流程,在性能和效率之间取得了很好的平衡,是现代大模型中的常用技术之一。RMSNorm 的核心公式如下:
- :输入的第 i 个特征值
- :输入张量的均方根(Root Mean Square),对输入进行 RMS 归一化后,可以将数据的整体幅度缩放到一个固定范围(理论上 RMS 接近 1)。
- ϵ:一个很小的常数,用于数值稳定性(防止分母为零)
- :可学习的缩放权重参数,用于调节每个特征的幅度。
计算均方根 RMS :
RMSNorm 优点:
- RMSNorm 假设输入数据已经被中心化(均值为 0)。这种假设在很多模型(如深度学习中的 Transformer 等)中普遍成立。
- RMSNorm 更简单、更高效,仅依赖均方值,而不需要计算均值和方差,降低了计算复杂度,在性能表现与 LayerNorm 接近的情况下,减少了计算开销,尤其是当输入的维度较高时(如 1024 或更多)。
- RMSNorm 的主要目的是解决特征的尺度不平衡问题,它对特征值进行归一化,保证模型的输入和输出在训练过程中稳定,避免梯度消失或爆炸。
- 通过可学习的参数 g,它还能调整每个通道的幅度,从而增强模型的表达能力。
RoPE
# 旋转位置编码生成
@torch.no_grad()
def llama_rotary_embedding(length):
#这一部分计算了一个频率范围,用于生成旋转的位置编码。inv_freq 是位置编码的逆频率,它决定了每个位置的周期性。计算 500000.0**inv_freq 是一种高频率的衰减方式,用于让较小的频率影响较远的位置。
inv_freq = torch.arange(0, 32, 2) / 32
inv_freq = 1. / (500000.0**inv_freq )
inv_freq = inv_freq.reshape(1, 16, 1)
# 位置 ID
position_ids = torch.arange(length).reshape(1, 1, -1).float()
# 通过矩阵乘法(matmul),将位置 ID 与逆频率(inv_freq)结合,计算出每个位置的频率
freqs = inv_freq.matmul(position_ids).transpose(1, 2)
# 表示每个位置的正余弦频率
emb = torch.cat((freqs, freqs), 2)
return emb.cos(), emb.sin()
# 应用旋转位置编码
def apply_rotary_pos_emb(x, cos, sin):
def rotate_half(x):
# 前 16 维的特征表示左半部分
left = x[..., :16]
# 后 16 维的特征表示右半部分的负数
right = -x[..., 16:]
return torch.cat((right, left), -1)
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
x = (x * cos) + (rotate_half(x) * sin)
return x
# 用法示范
cos, sin = llama_rotary_embedding(length)
q = apply_rotary_pos_emb(q, cos ,sin)
k = apply_rotary_pos_emb(k, cos, sin)
这段代码实现了 RoPE(Rotary Positional Embedding,旋转位置编码),原作者苏神介绍“这是一种配合Attention机制能达到“绝对位置编码的方式实现相对位置编码”的设计。而也正因为这种设计,它还是目前唯一一种可用于线性Attention的相对位置编码。”这是一种用于自然语言处理模型中表示位置信息的技术,特别是在处理序列数据时非常有用。RoPE 的本质目的是通过引入旋转的方式,在 Transformer 模型中将每个位置的位置信息与模型的特征相结合,如 q、k、v 这三个常见的特征,Attentio 的核心运算是内积,所以它们的内积结果带有相对位置信息,这种方式能够改善传统位置编码方法(如绝对位置编码)对序列长距离依赖的处理能力。使得每个位置的编码不仅包含其绝对位置的信息,还能够反映出该位置与其他位置的相对关系。
优点:
-
不仅包含其绝对位置的信息,还能够反映出该位置与其他位置的相对关系
-
尤其对于长序列数据(如长文本、长时间序列)表现尤为突出
-
增强模型在长序列上的泛化能力
llama_rotary_embedding(length):
这个函数用于生成旋转位置编码,主要分为以下几步:
-
计算反频率(inv_freq) :使用一个固定的公式生成反频率。
inv_freq是通过一个固定的常数和频率序列来计算的,目的是生成不同位置的频率信息。 -
计算位置ID(position_ids) :位置ID是表示每个位置在序列中的相对位置。
position_ids是一个从 0 到length-1的序列,表示每个位置的编号。 -
计算旋转频率:通过将位置ID和反频率相乘,得到旋转频率(
freqs)。这个频率信息被用于后续的位置编码。 -
返回余弦和正弦:最终生成旋转位置编码的余弦(cos)和正弦(sin)值,作为旋转位置编码的表示。
apply_rotary_pos_emb(x, cos, sin):
这个函数将旋转位置编码应用到输入张量 x 上。具体步骤如下:
-
rotate_half(x):这个操作将输入x的前半部分与后半部分交换。它是旋转位置编码的核心操作,使得每个位置的表示能够以旋转的方式进行编码。 -
应用旋转位置编码:通过将输入
x和cos、sin的余弦和正弦值结合,按照旋转的方式将位置信息加到输入张量x上。具体来说,是将x的每一部分与cos和sin相乘,然后将交换过的部分与sin相乘并加上。
LlamaMLP
class LlamaMLP(torch.nn.Module):
def __init__(self):
super().__init__()
self.gate_proj = torch.nn.Linear(1024, 14336, bias=False)
self.up_proj = torch.nn.Linear(1024, 14336, bias=False)
self.down_proj = torch.nn.Linear(14336,1024, bias=False)
self.act_fn = torch.nn.SiLU()
def forward(self, x):
# 这个部分的作用是生成一个“门控”信号,控制信息的流动
left = self.act_fn(self.gate_proj(x))
# 这部分的作用是从原始输入中提取更多的特征
right = self.up_proj(x)
# 这个输出是经过门控机制调节后的特征,包含了原始输入的丰富特征信息,并且通过激活函数和逐元素相乘的操作来控制信息的流动
return self.down_proj(left * right)
这个 LlamaMLP 类实现了一个非常经典的多层感知机结构,但采用了一些创新设计,能够处理复杂的模式和非线性关系,提升模型的性能,尤其是在对模型容量和表达能力有较高要求时,作为 Transformer 或其他类似架构的中间层。主要有以下三个部分内容:
门控机制:left 向量是通过激活的 gate_proj(x) 获得的,代表了输入 x 在高维空间中的非线性变换。right 向量是通过 up_proj(x) 获得的,代表了输入 x 在高维空间中的线性映射。 通过逐元素相乘的方式将这两个向量结合,形成了一种门控机制,控制了信息流的传递。
高维映射和降维:gate_proj 和 up_proj 都将输入映射到一个较高的维度(14336),然后 down_proj 将其降回到原始维度(1024)。这种设计通过先在高维空间中计算信息流,再将其压缩回较低的维度,可以帮助模型更好地捕捉非线性和复杂的关系。
激活函数(SiLU) :SiLU(Sigmoid Linear Unit)激活函数,也被称为Swish函数,由谷歌大脑的研究者提出。它是一种自门控的激活函数,意味着输出的值是由输入值本身决定是否需要被放大或抑制。使用 SiLU 激活函数而非传统的 ReLU 或 Tanh,是为了提高非线性变换的能力。SiLU函数在一些深度学习模型中表现出了良好的性能,尤其是在自然语言处理(NLP)和计算机视觉任务中。它被认为是 ReLU 和 Leaky ReLU 的一种替代品。
SiLU函数的公式如下:
其中, 是 sigmoid 函数, 是一个可学习的参数(在实际应用中通常设为1以简化计算), 是输入值。当 设置为 1 时,SiLU 函数简化为:
如图所示:
Sigmoid 函数的定义是:
SiLU函数的特点包括:
- 平滑性:由于SiLU函数是连续可微的,因此它在反向传播过程中可以提供平滑的梯度。
- 非单调性:SiLU函数在输入值较小的时候接近于0,在输入值较大的时候接近于1,这意味着它可以根据输入值的大小自适应地调整输出。
- 参数效率:SiLU函数不需要额外的参数,除了输入值本身,这使得它在参数效率上具有优势。
- 自门控:SiLU函数可以根据输入值自动决定是否需要激活,这有助于减少信息在深层网络中的丢失。
LlamaAttention
# 用于确保在自注意力机制中每个 token 只能关注到它之前的 token(即“因果”关系)。该函数通常用于训练中的自回归任务,比如生成任务。
def get_causal_mask(attention_mask):
b, length = attention_mask.shape
min_value = -1e15
# 是一个上三角矩阵,表示一个 token 只能关注之前的 token,防止未来的信息泄漏
causal_mask = torch.full((length, length), min_value).triu(diagonal=1)
causal_mask = causal_mask.reshape(1, 1, length, length).repeat(b, 1, 1, 1)
causal_mask = causal_mask.to(attention_mask.device)
# 将原本 attention_mask 中为 0 的位置也填充为 min_value ,表示这些位置不需要关注
mask = attention_mask.reshape(b, 1, 1, length) == 0
causal_mask = causal_mask.masked_fill(mask, min_value)
return causal_mask
def repeat_kv(x):
shape = list(x.shape)
shape[1] *= 4
return x.unsqueeze(2).repeat(1,1,4,1,1).reshape(shape)
class LlamaAttention(torch.nn.Module):
def __init__(self):
super().__init__()
self.q_proj = torch.nn.Linear(1024, 1024, bias=False)
self.k_proj = torch.nn.Linear(1024, 256, bias=False)
self.v_proj = torch.nn.Linear(1024, 256, bias=False)
self.o_proj = torch.nn.Linear(1024, 1024, bias=False)
def forward(self, hidden_states, attention_mask):
b, length, _ = hidden_states.shape
# qkv 线性转换,并拆分多头
q = self.q_proj(hidden_states).reshape(b, length, 32, 32).transpose(1,2)
k = self.k_proj(hidden_states).reshape(b, length, 8, 32).transpose(1,2)
v = self.v_proj(hidden_states).reshape(b, length, 8, 32).transpose(1,2)
# 加入 RoPE
cos, sin = llama_rotary_embedding(length)
cos, sin = cos.to(hidden_states.device), sin.to(hidden_states.device)
q = apply_rotary_pos_emb(q, cos ,sin)
k = apply_rotary_pos_emb(k, cos, sin)
# 复制多份方便计算
k = repeat_kv(k)
v = repeat_kv(v)
# 计算注意力
attn = q.matmul(k.transpose(2,3)) / math.sqrt(32)
attention_mask = get_causal_mask(attention_mask)
attn = (attn + attention_mask).softmax(3)
attn = attn.matmul(v)
# 合并多头
attn = attn.transpose(1,2).reshape(b, length, 1024)
attn = self.o_proj(attn)
return attn
这里就是常见的注意力机制实现,有两点不同的是:
- q、k、v 的多头数量不同
- 对 q、k、v 融入旋转位置编码
这里需要解释的是第二点,在多头注意力机制中,q(查询),k(键),v(值)的 head 数量不同并不常见,但它是可以这样设计的,具体原因通常与模型目标或优化目标相关。以下是一些可能的原因和解释:
- 不同的表示能力
- 查询(q) :一般来说,查询向量用于表示当前目标或输入对其他输入的“关注”程度。可能需要更多的维度来捕捉更细粒度的信息,尤其是当问题需要较高的表达能力时。因此,
q的维度可能需要更多的头数(32)来表示更多的“关注”模式。 - 键(k)与值(v) :相比之下,键和值通常用于对输入进行编码,并传递给注意力机制。它们的头数(这里是 8)可能比查询更少,因为它们的作用是提供信息,而不是直接控制注意力分布。减少键和值的头数可能是为了降低计算量,特别是在训练大型模型时,减少冗余的表示。
- 计算复杂度优化
- 如果
q,k, 和v的维度数相同,通常它们会被拆分为相同数量的头。在这种设计下,头数不相同可以在不同头之间进行计算量的分配,平衡性能和计算效率。通过减少k和v的头数,可以降低整体计算的复杂度,尤其是在计算键值对时(比如在多头注意力中的点积计算)。 - 比如,
q的头数多(32)表示它需要更多的细粒度查询,而k和v的头数少(8)意味着它们更注重对全局信息的编码和传递。
LlamaDecoderLayer
class LlamaDecoderLayer(torch.nn.Module):
def __init__(self):
super().__init__()
self.self_attn = LlamaAttention()
self.mlp = LlamaMLP()
self.input_layernorm = LlamaRMSNorm()
self.post_attention_layernorm = LlamaRMSNorm()
def forward(self, hidden_states, attention_mask):
res = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask) + res
res = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states) + res
return hidden_states
该 LlamaDecoderLayer 的实现遵循了典型 Transformer 解码器层的设计,包含以下几个关键组件:
- 自注意力层:利用
self_attn计算输入序列的自注意力表示,能够捕捉序列中各位置之间的依赖关系。 - 前馈神经网络(MLP) :通过
mlp层进一步增强模型的表达能力,处理经过自注意力计算后的隐藏状态。 - 层归一化(RMSNorm) :通过
LlamaRMSNorm对每一层的输入进行标准化,以提高训练稳定性,防止梯度爆炸或消失。 - 残差连接:每一层(自注意力层和 MLP 层)都采用了残差连接,使得输入可以直接绕过每一层传递,帮助加速收敛,并且改善深度网络的训练效果。
LlamaModel
class LlamaModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.embed_tokens = torch.nn.Embedding(128256, 1024, None)
self.layers = torch.nn.ModuleList([LlamaDecoderLayer() for _ in range(4)])
self.norm = LlamaRMSNorm()
def forward(self, input_ids, attention_mask):
hidden_states = self.embed_tokens(input_ids)
for layer in self.layers:
hidden_states = layer(hidden_states, attention_mask=attention_mask)
hidden_states = self.norm(hidden_states)
return hidden_states
可以很明显看出来,和传统的 transformer 的嵌入层不同,这里只有一个简单的 embed_tokens ,位置嵌入已经放入了后面的注意力机制计算中和 q、k、v 进行融合了。
LlamaForCausalLM
class LlamaForCausalLM(torch.nn.Module):
def __init__(self):
super().__init__()
self.model = LlamaModel()
self.lm_head = torch.nn.Linear(1024, 128256, bias=False)
def forward(self, input_ids, attention_mask, labels=None):
logits = self.model(input_ids=input_ids, attention_mask=attention_mask)
logits = self.lm_head(logits)
loss = None
if labels is not None:
shift_logits = logits[:, :-1].reshape(-1, 128256)
shift_labels = labels[:, 1:].reshape(-1)
loss = torch.nn.functional.cross_entropy(shift_logits, shift_labels)
return loss, logits
整个 llama 因果模型,其实就是在 LlamaModel 之上加入了一层 lm_head ,其实就是一个全连接层,用来映射到全词库。
比对
from transformers import LlamaConfig, LlamaForCausalLM as LM
config = "{'vocab_size': 128256, 'max_position_embeddings': 8192, 'hidden_size': 4096, 'intermediate_size': 14336, 'num_hidden_layers': 32, 'num_attention_heads': 32, 'num_key_value_heads': 8, 'hidden_act': 'silu', 'initializer_range': 0.02, 'rms_norm_eps': 1e-05, 'pretraining_tp': 1, 'use_cache': True, 'rope_theta': 500000.0, 'rope_scaling': None, 'attention_bias': False, 'attention_dropout': 0.0, 'mlp_bias': False, 'return_dict': True, 'output_hidden_states': False, 'output_attentions': False, 'torchscript': False, 'torch_dtype': 'bfloat16', 'use_bfloat16': False, 'tf_legacy_loss': False, 'pruned_heads': {}, 'tie_word_embeddings': False, 'chunk_size_feed_forward': 0, 'is_encoder_decoder': False, 'is_decoder': False, 'cross_attention_hidden_size': None, 'add_cross_attention': False, 'tie_encoder_decoder': False, 'max_length': 20, 'min_length': 0, 'do_sample': False, 'early_stopping': False, 'num_beams': 1, 'num_beam_groups': 1, 'diversity_penalty': 0.0, 'temperature': 1.0, 'top_k': 50, 'top_p': 1.0, 'typical_p': 1.0, 'repetition_penalty': 1.0, 'length_penalty': 1.0, 'no_repeat_ngram_size': 0, 'encoder_no_repeat_ngram_size': 0, 'bad_words_ids': None, 'num_return_sequences': 1, 'output_scores': False, 'return_dict_in_generate': False, 'forced_bos_token_id': None, 'forced_eos_token_id': None, 'remove_invalid_values': False, 'exponential_decay_length_penalty': None, 'suppress_tokens': None, 'begin_suppress_tokens': None, 'architectures': ['LlamaForCausalLM'], 'finetuning_task': None, 'id2label': {0: 'LABEL_0', 1: 'LABEL_1'}, 'label2id': {'LABEL_0': 0, 'LABEL_1': 1}, 'tokenizer_class': None, 'prefix': None, 'bos_token_id': 128000, 'pad_token_id': None, 'eos_token_id': 128001, 'sep_token_id': None, 'decoder_start_token_id': None, 'task_specific_params': None, 'problem_type': None, '_name_or_path': '', 'transformers_version': '4.38.2', 'model_type': 'llama'}"
config = LlamaConfig.from_dict(eval(config))
config.hidden_size = 1024
config.num_hidden_layers = 4
model1 = LM(config)
model2 = LlamaForCausalLM()
model2.load_state_dict(model1.state_dict())
input = {
'input_ids': torch.randint(100, 50000, [4, 125]),
'attention_mask': torch.ones(4, 125).long(),
'labels': torch.randint(100, 50000, [4, 125])
}
input['attention_mask'][:, 120:] = 0
out = model1(**input)
loss, logits = model2(**input)
print(out.loss, out.logits.shape)
print(loss, logits.shape)
out.loss == loss, (out.logits == logits).all()
tensor(11.9773, grad_fn=<NllLossBackward0>) torch.Size([4, 125, 128256])
tensor(11.9773, grad_fn=<NllLossBackward0>) torch.Size([4, 125, 128256])
通过加载原始 LLAMA 模型,并将两个重要的参数修改成和我们一样,其他保持不变,然后和我们实现的模型进行比对,发现计算的结果完全一样,说明我们的模型结构没有问题。