引言
继续我们的“从零开始训练自己的小参数量大模型”系列。
在上几篇文章中,我们精心准备了数据集,相当于为我们的模型备好了“精神食粮”,而且训练好了我们的词库,并且使用PyTorch的API搭建好了我们的基础模型。现在,万事俱备,只欠一个强大的“炼丹炉”——也就是我们的训练脚本。
这个脚本是整个训练过程的“中央控制室”,它负责加载数据、初始化模型、执行训练循环、监控性能、保存进度……。一个健壮、高效的训练脚本是决定我们能否顺利“炼出好丹”的关键。
今天的文章,我们将逐行解析这份核心的训练代码。别担心,虽然代码看起来很长,但我们会把它拆解成一个个清晰的模块,让大家彻底明白每一部分的作用。让我们开始吧!
完整的代码可以参考文末附录的链接。
超参数与全局配置
在任何一个大工程开始前,我们都需要一张蓝图。在模型训练中,这张蓝图就是我们的超参数(Hyperparameters) 和配置。它们定义了训练的方方面面,从模型的大小到训练的方式。
# 1. 超参数和配置
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TOKENIZER_FILE = "bpe_tokenizer_from_arrow/tokenizer.json"
TRAIN_DATASET_PATH = "split_openwebtext_custom_sample/train"
VAL_DATASET_PATH = "split_openwebtext_custom_sample/validation"
MAX_SEQ_LENGTH = 512
# 大家按照自己的显卡来判断,BATCH_SIZE=64是会超出16GB显存的,大概是16500MB左右,如果大家显卡只有16GB,可以适当减低BATCH_SIZE
BATCH_SIZE = 64
NUM_EPOCHS = 2
CLIP_GRAD_NORM = 1.0
GRADIENT_ACCUMULATION_STEPS = 8
WARMUP_STEPS = 4000 // GRADIENT_ACCUMULATION_STEPS
# SCHEDULER_TYPE = "cosine" # 使用余弦退火训练
SCHEDULER_TYPE = 'constant' # 使用固定学习率训练
LEARNING_RATE = 5e-4 # 余弦退火方案训练时的起始学习率
CONSTANT_LEARNING_RATE = 1.5e-4 # 固定学习率训练时的学习率
# 模型参数
EMBED_DIM = 512 # 词向量维度
NUM_HEADS = 8 # 注意力头的数量
FF_DIM = EMBED_DIM * 4 # 前馈网络中间层维度
NUM_LAYERS = 8 # Transformer层数
DROPOUT = 0.1
TIE_WEIGHTS = True
# ... 其他参数 ...
# BF16/AMP Configuration
USE_AMP = True
AMP_DTYPE = torch.bfloat16
USE_SCALER = True
-
路径与设备 (
DEVICE,..._PATH):配置了训练使用的设备(优先用CUDA显卡)和所需文件的路径。 -
核心训练参数 (
MAX_SEQ_LENGTH,BATCH_SIZE,...):MAX_SEQ_LENGTH:模型的“短期记忆”有多长。512表示模型一次最多处理512个token。BATCH_SIZE:每次送入显卡进行计算的样本数。64是一个比较大的“微批次”(micro-batch)。使用64大小的BATCH_SIZE显存占用大概是16500MB左右,如果大家显卡只有16GB,需要降低BATCH_SIZE的值GRADIENT_ACCUMULATION_STEPS:梯度累积步数,这是个“省显存大法”。它允许我们用较小的显存模拟出大批次的训练效果。这里设置为8,意味着有效批次大小(effective batch size)是64 * 8 = 512。模型会计算8次微批次的梯度,然后才统一更新一次参数。LEARNING_RATE:学习率。想象你在下山,学习率就是你每一步迈多大。步子太大容易“一步踩空”(错过最优点),步子太小则“天黑了也下不了山”(训练太慢)。CLIP_GRAD_NORM:梯度裁剪阈值。为了防止训练过程中梯度突然变得过大(梯度爆炸)导致训练不稳定,我们把梯度的范数(可以理解为梯度向量的长度)限制在1.0以内。
-
模型结构参数 (
EMBED_DIM,NUM_HEADS,...):这些数字共同定义了我们模型的“三围”。我们这里构建的是一个8层、512维度、8个头的“小”模型,这是一个非常适合学习和实验的尺寸。 -
效率与显存优化 (
USE_AMP,AMP_DTYPE,USE_SCALER):这是我们的“加速黑科技”——自动混合精度(AMP)。USE_AMP = True:开启这个功能。AMP_DTYPE = torch.bfloat16:使用bfloat16这种“半精度”浮点数进行大部分计算。- 一句话解释AMP:就像我们做数学题,大部分草稿(矩阵乘法)可以用不那么精确的数字快速计算,只有在关键步骤(梯度更新)才用高精度数字保证准确性。这样既能大幅提升训练速度,又能节省显存。
USE_SCALER=True开启梯度缩放器,是配合AMP防止数值下溢的关键组件。
训练的心脏 —— train_epoch 函数
这里是模型学习和成长的核心地带。一个epoch(一轮)的训练就是在这个函数里完成的。代码的精髓在于梯度累积的实现。
def train_epoch(model, dataloader, optimizer, ..., scaler, lr_scheduler):
model.train() # 1. 切换到“训练模式”
optimizer.zero_grad(set_to_none=True) # 2. 在循环开始前清空梯度
for batch_idx, batch in progress_bar:
# ... 数据准备 ...
# 3. 开启混合精度计算上下文
with autocast(enabled=USE_AMP, dtype=AMP_DTYPE, ...):
logits = model(model_input) # 4. 模型进行预测
# 5. 计算损失,并根据累积步数进行缩放
loss = criterion(logits.view(-1, ...), targets.view(-1))
loss = loss / GRADIENT_ACCUMULATION_STEPS
# 6. 反向传播(累积梯度)
scaler.scale(loss).backward()
# 7. 达到累积步数,执行一次参数更新
if (batch_idx + 1) % GRADIENT_ACCUMULATION_STEPS == 0:
scaler.unscale_(optimizer) # 将梯度缩放回来以便裁剪
torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_GRAD_NORM) # 裁剪梯度
scaler.step(optimizer) # 更新模型参数
scaler.update() # 更新缩放器
lr_scheduler.step() # 8. 更新学习率
optimizer.zero_grad(set_to_none=True) # 清空梯度,为下一轮累积做准备
# ... 日志记录与模型保存 ...
model.train():切换模型到训练状态,这会启用像Dropout这样的训练特有层。optimizer.zero_grad():在循环开始前,先把梯度清零。注意,在梯度累积逻辑中,它也会在每次参数更新后被调用。with autocast(...):启用混合精度,PyTorch会自动将计算转换为bfloat16来提速。model(model_input):把数据喂给模型,让它做出预测。loss = loss / GRADIENT_ACCUMULATION_STEPS:计算损失后,将其除以累积步数。这是为了保证累积了多次的梯度在量级上与没有累积时大致相当。scaler.scale(loss).backward():这是整个训练过程的“灵魂”!它根据缩放后的损失,计算出模型每个参数应该如何调整(即计算梯度)。由于我们没有清空梯度,所以每次调用,新的梯度会累加到旧的梯度上。if ... % GRADIENT_ACCUMULATION_STEPS == 0:这是梯度累积的核心判断。只有当处理了足够数量的“微批次”后,才执行真正的“知错就改”步骤:scaler.unscale_和clip_grad_norm_:在更新前,先将梯度恢复原始大小,并进行裁剪,防止模型“用力过猛”。scaler.step(optimizer):根据累积了多次的梯度,真正地去更新模型的每一个参数。这一步相当于用一个大批次的数据在训练。
lr_scheduler.step():我们不希望学习率一成不变。lr_scheduler是学习率的“配速员”,它会根据预设策略(如预热+余弦退火)调整学习率。关键是,它在参数更新(optimizer.step)之后调用,而不是在每个微批次后。
定期“考试” —— validate_epoch 函数
光闷头学习不行,我们得定期“考试”来检验模型学习成果。validate_epoch函数就扮演了这个角色。
def validate_epoch(model, dataloader, ...):
model.eval() # 1. 切换到“评估模式”
with torch.no_grad(): # 2. 关闭梯度计算
for batch_idx, batch in progress_bar:
# ... 数据准备 ...
with autocast(...): # 同样使用混合精度进行推理
logits = model(model_input)
loss = criterion(...) # 计算损失
perplexity = math.exp(avg_loss) # 3. 计算困惑度
# ... 记录日志 ...
model.eval():切换模型到验证状态,这会关闭Dropout等层,确保评估结果的稳定性。with torch.no_grad():验证时不需要学习,所以我们关闭梯度计算。这能节省大量显存和计算资源,让评估跑得更快。- 困惑度(Perplexity):这是评估语言模型最常用的指标之一。我们可以直观地理解为:模型对于接下来要出现哪个词,有多“困惑”。困惑度越低,说明模型对文本的预测能力越强,学得越好。
断点续训与模型保存
训练大模型动辄几天甚至几周,如果中途断电或者程序崩溃,一切从头再来,那绝对是欲哭无泪。所以,断点续训(Resume Training) 和 模型保存(Checkpointing) 是我们的“救生索”。我们的脚本实现了非常完善的机制。
1. 断点续训
在main函数中,我们有这样一段逻辑:
# 设置为你想恢复的checkpoint的路径
RESUME_CHECKPOINT_PATH: Optional[str] = None
# ... 在 main 函数里 ...
if RESUME_CHECKPOINT_PATH and os.path.exists(RESUME_CHECKPOINT_PATH):
print(f"Resuming training from checkpoint: {RESUME_CHECKPOINT_PATH}")
checkpoint = torch.load(RESUME_CHECKPOINT_PATH, map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict']) # 恢复模型
optimizer.load_state_dict(checkpoint['optimizer_state_dict']) # 恢复优化器
if USE_SCALER:
scaler.load_state_dict(checkpoint['scaler_state_dict']) # 恢复混合精度缩放器
global_step = checkpoint['global_step']
start_epoch = checkpoint['epoch'] + 1
# 将学习率调度器快进到正确的步数
for _ in range(global_step):
lr_scheduler.step()
# ...
当我们设置了RESUME_CHECKPOINT_PATH,脚本启动时会先检查这个文件。如果存在,它会像读取“游戏存档”一样,恢复所有关键状态:
- 模型的权重 (
model_state_dict)。代码还智能地处理了模型是否被DataParallel等包装过的情况,自动去掉module.前缀,非常稳健。 - 优化器的状态 (
optimizer_state_dict,这很重要,因为它包含了AdamW等优化器的动量信息)。 - 混合精度缩放器的状态 (
scaler_state_dict)。 - 训练进度 (
global_step和epoch)。 - 学习率调度器状态:通过一个循环,将调度器“快进”到
global_step,确保学习率曲线能够完美衔接。
这样,训练就能无缝衔接,从上次中断的地方继续,分毫不差!
2. 定期保存与滚动删除
训练过程中,我们不能只等结束了再保存。
# 配置
SAVE_INTERVAL = 10000 // GRADIENT_ACCUMULATION_STEPS # 每多少个 optimizer_step 保存一次
MAX_CHECKPOINT_TO_KEEP = 99 # 最多保留几个
# ... 在 train_epoch 函数里,参数更新后 ...
if global_step > 0 and global_step % SAVE_INTERVAL == 0:
# ... 保存 checkpoint ...
torch.save(save_dict, checkpoint_path)
# ... 管理队列,自动删除旧的checkpoint ...
if saved_checkpoints_queue.maxlen is not None:
if len(saved_checkpoints_queue) == saved_checkpoints_queue.maxlen:
path_to_remove_on_disk = saved_checkpoints_queue.popleft()
os.remove(path_to_remove_on_disk)
saved_checkpoints_queue.append(checkpoint_path)
SAVE_INTERVAL:我们设置每训练一定步数(注意是参数更新的步数,不是微批次数),就自动存一次档。MAX_CHECKPOINT_TO_KEEP和deque:模型存档非常占硬盘空间。我们不希望无限保存下去。这里使用了一个deque(双端队列)实现滚动保存。它永远只保留最近的MAX_CHECKPOINT_TO_KEEP个存档。每当一个新的存档产生,最老的那一个就会被自动删除。这极大地节省了我们的硬盘空间。- 恢复时自动填充队列:更棒的是,如果从断点恢复,脚本会扫描检查点目录,将已有的最近的检查点文件加载到这个管理队列中,避免了恢复后队列为空的尴尬。
- 保存最佳模型:除了定期保存,我们还会在每次验证后(由
SAVE_INTERVAL触发),如果发现模型的验证损失创下新低,就单独保存一个名为best_model.pt的文件。这确保我们总能拿到训练过程中表现最好的那个模型!
启动引擎 —— main 函数全览
main函数就是我们的总指挥。它按照顺序,一步步地把所有模块串联起来:
- 加载“三件套”:加载分词器(Tokenizer)、数据集(Dataset)和数据加载器(DataLoader)。
DataLoader使用了pin_memory和persistent_workers等优化参数,可以提升数据加载效率。 - 初始化“核心组件”:创建模型实例、损失函数(
CrossEntropyLoss)、优化器(AdamW)和学习率调度器。值得注意的是,学习率调度器(lr_scheduler)的总步数MAX_TRAIN_STEPS是根据数据集大小、Epoch数和梯度累积步数精确计算出来的,这让学习率的衰减过程更加合理。 - 执行“断点续训”检查:如上所述,尝试从存档恢复,并初始化检查点管理队列。
- 启动“训练-验证”主循环:
for epoch in range(start_epoch, NUM_EPOCHS): # 训练一轮,期间会按SAVE_INTERVAL进行保存和验证 train_loss, global_step, best_val_loss = train_epoch(...) # 每个epoch结束后,进行一次完整的验证 val_loss, val_perplexity = validate_epoch(...) # 检查是否是最佳模型并保存 if val_loss < best_val_loss: # ... save best_model.pt ... - 结束:训练完成后,关闭日志记录器,打印结束信息。
总结与展望
到这里,我们学习了“炼丹炉”的搭建。现在,我们不仅知道代码在做什么,更重要的是,理解了为什么要这么做。这份脚本凝聚了许多大型模型训练的最佳实践。
我们掌握了:
- ✅ 如何通过超参数配置来定义我们的训练任务。
- ✅ 使用梯度累积,在有限的显存下实现大批次训练。
- ✅ 混合精度(AMP)与梯度缩放(GradScaler)如何为我们提速省显存。
- ✅ 验证集和困惑度如何客观地评价模型性能。
- ✅ 如何实现一个非常健壮的断点续训和智能检查点管理策略。
至此,我们万事具备,是时候让显卡开始“燃烧”了!在下一篇文章中,我们将把所有部分整合起来,正式开启训练,并观察模型的学习曲线,见证一个语言模型从“随机乱说”到“言之有物”的诞生过程。敬请期待!
关注我的公众号不走丢
附录
GitHub链接:github.com/JimmysAIPG/…