前言
本文是对 Baichuan-7B 的模型源码进行剖析,来加强自己的 LLM 代码能力。baichuan-7B 与 Baichuan-13B 有一些区别,这里介绍以 7B 为主,13B 的类似,就不赘述了。
详细的代码注释参见:源码阅读
整体描述
整体划分
我们可以大致将整体的模型划分为三个部分:
- 旋转位置编码的实现
- 核心 DecoderLayer 的实现
- 各个上层 Model 的实现。 此外,还需要注意的是 attention_mask 的变换。
本文对除旋转位置编码部分外的所有部分都进行了详细注释,并删减了一些冗余代码。此外还提供了原始的baichuan代码来进行调试学习。
RotaryEmbedding 与 apply_rotary_pos_emb
旋转位置编码的实现,这块比较复杂,单独写一篇来讲述。
RMSNorm与MLP
- RMSNorm: RMSNorm 的实现,类似 LayerNormalization
- MLP:全连接层实现,主要包括激活函数和原始的全连接层,baichuan中使用的激活函数为 SwiGLU。 这两个没啥好说的,基本上就是公式的实现。
Attention
- Attention:Q,K,V 多头注意力机制的实现,这里面需要注意的是, attention_mask 的使用。 这部分没啥好说的,常规Transformer 中 多头自注意力机制的实现。具体看代码注释吧。
值得注意的是,旋转位置编码是在 Attention 计算时融入的,由于比较复杂,这部分后面会单独写一篇介绍。
DecoderLayer
单层Transformer 的实现,主要是如何将:残差,Normalization,多头注意力机制,全连接线性层 这几个大块组成一个 stack。
BaichuanModel
这部分主要包括:
- 输入的处理: attention_mask ,position_ids 的组成
- 多层 DecoderLayer的计算 这个也比较简单,主要是理清楚 position_ids, 和 attention_mask 的计算即可。
BaiChuanForCausalLM
BaiChuanForCausalLM:在 BaichuanModel 的基础上,加入了一个线性层。也比较简单。这里值得注意的是,损失的计算部分。