听说你还没读过大模型源码? - baichuan

388 阅读2分钟

前言

本文是对 Baichuan-7B 的模型源码进行剖析,来加强自己的 LLM 代码能力。baichuan-7B 与 Baichuan-13B 有一些区别,这里介绍以 7B 为主,13B 的类似,就不赘述了。

详细的代码注释参见:源码阅读

整体描述

image.png

整体划分

我们可以大致将整体的模型划分为三个部分:

  • 旋转位置编码的实现
  • 核心 DecoderLayer 的实现
  • 各个上层 Model 的实现。 此外,还需要注意的是 attention_mask 的变换

本文对除旋转位置编码部分外的所有部分都进行了详细注释,并删减了一些冗余代码。此外还提供了原始的baichuan代码来进行调试学习。

RotaryEmbedding 与 apply_rotary_pos_emb

旋转位置编码的实现,这块比较复杂,单独写一篇来讲述。

RMSNorm与MLP

  • RMSNorm: RMSNorm 的实现,类似 LayerNormalization
  • MLP:全连接层实现,主要包括激活函数和原始的全连接层,baichuan中使用的激活函数为 SwiGLU。 这两个没啥好说的,基本上就是公式的实现。

Attention

  • Attention:Q,K,V 多头注意力机制的实现,这里面需要注意的是, attention_mask 的使用。 这部分没啥好说的,常规Transformer 中 多头自注意力机制的实现。具体看代码注释吧。

值得注意的是,旋转位置编码是在 Attention 计算时融入的,由于比较复杂,这部分后面会单独写一篇介绍。

DecoderLayer

单层Transformer 的实现,主要是如何将:残差,Normalization,多头注意力机制,全连接线性层 这几个大块组成一个 stack。

BaichuanModel

这部分主要包括:

  • 输入的处理: attention_mask ,position_ids 的组成
  • 多层 DecoderLayer的计算 这个也比较简单,主要是理清楚 position_ids, 和 attention_mask 的计算即可。

BaiChuanForCausalLM

BaiChuanForCausalLM:在 BaichuanModel 的基础上,加入了一个线性层。也比较简单。这里值得注意的是,损失的计算部分。