从零开始训练大模型:搭建炼丹炉,详解训练脚本的每一行代码

233 阅读12分钟

引言

继续我们的“从零开始训练自己的小参数量大模型”系列。

在上几篇文章中,我们精心准备了数据集,相当于为我们的模型备好了“精神食粮”,而且训练好了我们的词库,并且使用PyTorch的API搭建好了我们的基础模型。现在,万事俱备,只欠一个强大的“炼丹炉”——也就是我们的训练脚本

这个脚本是整个训练过程的“中央控制室”,它负责加载数据、初始化模型、执行训练循环、监控性能、保存进度……。一个健壮、高效的训练脚本是决定我们能否顺利“炼出好丹”的关键。

今天的文章,我们将逐行解析这份核心的训练代码。别担心,虽然代码看起来很长,但我们会把它拆解成一个个清晰的模块,让大家彻底明白每一部分的作用。让我们开始吧!

完整的代码可以参考文末附录的链接。

超参数与全局配置

在任何一个大工程开始前,我们都需要一张蓝图。在模型训练中,这张蓝图就是我们的超参数(Hyperparameters) 和配置。它们定义了训练的方方面面,从模型的大小到训练的方式。

# 1. 超参数和配置
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TOKENIZER_FILE = "bpe_tokenizer_from_arrow/tokenizer.json"
TRAIN_DATASET_PATH = "split_openwebtext_custom_sample/train"
VAL_DATASET_PATH = "split_openwebtext_custom_sample/validation"

MAX_SEQ_LENGTH = 512
# 大家按照自己的显卡来判断,BATCH_SIZE=64是会超出16GB显存的,大概是16500MB左右,如果大家显卡只有16GB,可以适当减低BATCH_SIZE
BATCH_SIZE = 64
NUM_EPOCHS = 2
CLIP_GRAD_NORM = 1.0
GRADIENT_ACCUMULATION_STEPS = 8
WARMUP_STEPS = 4000 // GRADIENT_ACCUMULATION_STEPS

# SCHEDULER_TYPE = "cosine"  # 使用余弦退火训练
SCHEDULER_TYPE = 'constant' # 使用固定学习率训练
LEARNING_RATE = 5e-4 # 余弦退火方案训练时的起始学习率
CONSTANT_LEARNING_RATE = 1.5e-4 # 固定学习率训练时的学习率

# 模型参数
EMBED_DIM = 512      # 词向量维度
NUM_HEADS = 8        # 注意力头的数量
FF_DIM = EMBED_DIM * 4 # 前馈网络中间层维度
NUM_LAYERS = 8       # Transformer层数
DROPOUT = 0.1
TIE_WEIGHTS = True

# ... 其他参数 ...

# BF16/AMP Configuration
USE_AMP = True
AMP_DTYPE = torch.bfloat16
USE_SCALER = True
  • 路径与设备 (DEVICE, ..._PATH):配置了训练使用的设备(优先用CUDA显卡)和所需文件的路径。

  • 核心训练参数 (MAX_SEQ_LENGTH, BATCH_SIZE, ...)

    • MAX_SEQ_LENGTH:模型的“短期记忆”有多长。512表示模型一次最多处理512个token。
    • BATCH_SIZE:每次送入显卡进行计算的样本数。64是一个比较大的“微批次”(micro-batch)。使用64大小的BATCH_SIZE显存占用大概是16500MB左右,如果大家显卡只有16GB,需要降低BATCH_SIZE的值
    • GRADIENT_ACCUMULATION_STEPS梯度累积步数,这是个“省显存大法”。它允许我们用较小的显存模拟出大批次的训练效果。这里设置为8,意味着有效批次大小(effective batch size)是 64 * 8 = 512。模型会计算8次微批次的梯度,然后才统一更新一次参数。
    • LEARNING_RATE:学习率。想象你在下山,学习率就是你每一步迈多大。步子太大容易“一步踩空”(错过最优点),步子太小则“天黑了也下不了山”(训练太慢)。
    • CLIP_GRAD_NORM:梯度裁剪阈值。为了防止训练过程中梯度突然变得过大(梯度爆炸)导致训练不稳定,我们把梯度的范数(可以理解为梯度向量的长度)限制在1.0以内。
  • 模型结构参数 (EMBED_DIM, NUM_HEADS, ...):这些数字共同定义了我们模型的“三围”。我们这里构建的是一个8层、512维度、8个头的“小”模型,这是一个非常适合学习和实验的尺寸。

  • 效率与显存优化 (USE_AMP, AMP_DTYPE, USE_SCALER):这是我们的“加速黑科技”——自动混合精度(AMP)

    • USE_AMP = True:开启这个功能。
    • AMP_DTYPE = torch.bfloat16:使用bfloat16这种“半精度”浮点数进行大部分计算。
    • 一句话解释AMP:就像我们做数学题,大部分草稿(矩阵乘法)可以用不那么精确的数字快速计算,只有在关键步骤(梯度更新)才用高精度数字保证准确性。这样既能大幅提升训练速度,又能节省显存。USE_SCALER=True 开启梯度缩放器,是配合AMP防止数值下溢的关键组件。

训练的心脏 —— train_epoch 函数

这里是模型学习和成长的核心地带。一个epoch(一轮)的训练就是在这个函数里完成的。代码的精髓在于梯度累积的实现。

def train_epoch(model, dataloader, optimizer, ..., scaler, lr_scheduler):
    model.train() # 1. 切换到“训练模式”
    optimizer.zero_grad(set_to_none=True) # 2. 在循环开始前清空梯度

    for batch_idx, batch in progress_bar:
        # ... 数据准备 ...
        
        # 3. 开启混合精度计算上下文
        with autocast(enabled=USE_AMP, dtype=AMP_DTYPE, ...):
            logits = model(model_input) # 4. 模型进行预测
            # 5. 计算损失,并根据累积步数进行缩放
            loss = criterion(logits.view(-1, ...), targets.view(-1))
            loss = loss / GRADIENT_ACCUMULATION_STEPS
        
        # 6. 反向传播(累积梯度)
        scaler.scale(loss).backward()
        
        # 7. 达到累积步数,执行一次参数更新
        if (batch_idx + 1) % GRADIENT_ACCUMULATION_STEPS == 0:
            scaler.unscale_(optimizer)    # 将梯度缩放回来以便裁剪
            torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_GRAD_NORM) # 裁剪梯度
            scaler.step(optimizer)        # 更新模型参数
            scaler.update()               # 更新缩放器
            lr_scheduler.step()           # 8. 更新学习率
            optimizer.zero_grad(set_to_none=True) # 清空梯度,为下一轮累积做准备
            
            # ... 日志记录与模型保存 ...
  1. model.train():切换模型到训练状态,这会启用像Dropout这样的训练特有层。
  2. optimizer.zero_grad():在循环开始前,先把梯度清零。注意,在梯度累积逻辑中,它也会在每次参数更新后被调用。
  3. with autocast(...):启用混合精度,PyTorch会自动将计算转换为bfloat16来提速。
  4. model(model_input):把数据喂给模型,让它做出预测。
  5. loss = loss / GRADIENT_ACCUMULATION_STEPS:计算损失后,将其除以累积步数。这是为了保证累积了多次的梯度在量级上与没有累积时大致相当。
  6. scaler.scale(loss).backward():这是整个训练过程的“灵魂”!它根据缩放后的损失,计算出模型每个参数应该如何调整(即计算梯度)。由于我们没有清空梯度,所以每次调用,新的梯度会累加到旧的梯度上。
  7. if ... % GRADIENT_ACCUMULATION_STEPS == 0:这是梯度累积的核心判断。只有当处理了足够数量的“微批次”后,才执行真正的“知错就改”步骤:
    • scaler.unscale_clip_grad_norm_:在更新前,先将梯度恢复原始大小,并进行裁剪,防止模型“用力过猛”。
    • scaler.step(optimizer):根据累积了多次的梯度,真正地去更新模型的每一个参数。这一步相当于用一个大批次的数据在训练。
  8. lr_scheduler.step():我们不希望学习率一成不变。lr_scheduler是学习率的“配速员”,它会根据预设策略(如预热+余弦退火)调整学习率。关键是,它在参数更新(optimizer.step)之后调用,而不是在每个微批次后。

定期“考试” —— validate_epoch 函数

光闷头学习不行,我们得定期“考试”来检验模型学习成果。validate_epoch函数就扮演了这个角色。

def validate_epoch(model, dataloader, ...):
    model.eval() # 1. 切换到“评估模式”
    
    with torch.no_grad(): # 2. 关闭梯度计算
        for batch_idx, batch in progress_bar:
            # ... 数据准备 ...
            with autocast(...): # 同样使用混合精度进行推理
                logits = model(model_input)
                loss = criterion(...) # 计算损失
    
    perplexity = math.exp(avg_loss) # 3. 计算困惑度
    # ... 记录日志 ...
  1. model.eval():切换模型到验证状态,这会关闭Dropout等层,确保评估结果的稳定性。
  2. with torch.no_grad():验证时不需要学习,所以我们关闭梯度计算。这能节省大量显存和计算资源,让评估跑得更快。
  3. 困惑度(Perplexity):这是评估语言模型最常用的指标之一。我们可以直观地理解为:模型对于接下来要出现哪个词,有多“困惑”。困惑度越低,说明模型对文本的预测能力越强,学得越好。

断点续训与模型保存

训练大模型动辄几天甚至几周,如果中途断电或者程序崩溃,一切从头再来,那绝对是欲哭无泪。所以,断点续训(Resume Training)模型保存(Checkpointing) 是我们的“救生索”。我们的脚本实现了非常完善的机制。

1. 断点续训

main函数中,我们有这样一段逻辑:

# 设置为你想恢复的checkpoint的路径
RESUME_CHECKPOINT_PATH: Optional[str] = None 

# ... 在 main 函数里 ...
if RESUME_CHECKPOINT_PATH and os.path.exists(RESUME_CHECKPOINT_PATH):
    print(f"Resuming training from checkpoint: {RESUME_CHECKPOINT_PATH}")
    checkpoint = torch.load(RESUME_CHECKPOINT_PATH, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict']) # 恢复模型
    optimizer.load_state_dict(checkpoint['optimizer_state_dict']) # 恢复优化器
    if USE_SCALER:
        scaler.load_state_dict(checkpoint['scaler_state_dict']) # 恢复混合精度缩放器
    global_step = checkpoint['global_step']
    start_epoch = checkpoint['epoch'] + 1
    
    # 将学习率调度器快进到正确的步数
    for _ in range(global_step):
        lr_scheduler.step()
    # ...

当我们设置了RESUME_CHECKPOINT_PATH,脚本启动时会先检查这个文件。如果存在,它会像读取“游戏存档”一样,恢复所有关键状态:

  • 模型的权重 (model_state_dict)。代码还智能地处理了模型是否被DataParallel等包装过的情况,自动去掉module.前缀,非常稳健。
  • 优化器的状态 (optimizer_state_dict,这很重要,因为它包含了AdamW等优化器的动量信息)。
  • 混合精度缩放器的状态 (scaler_state_dict)。
  • 训练进度 (global_stepepoch)。
  • 学习率调度器状态:通过一个循环,将调度器“快进”到global_step,确保学习率曲线能够完美衔接。

这样,训练就能无缝衔接,从上次中断的地方继续,分毫不差!

2. 定期保存与滚动删除

训练过程中,我们不能只等结束了再保存。

# 配置
SAVE_INTERVAL = 10000 // GRADIENT_ACCUMULATION_STEPS # 每多少个 optimizer_step 保存一次
MAX_CHECKPOINT_TO_KEEP = 99 # 最多保留几个

# ... 在 train_epoch 函数里,参数更新后 ...
if global_step > 0 and global_step % SAVE_INTERVAL == 0:
    # ... 保存 checkpoint ...
    torch.save(save_dict, checkpoint_path)
    
    # ... 管理队列,自动删除旧的checkpoint ...
    if saved_checkpoints_queue.maxlen is not None:
        if len(saved_checkpoints_queue) == saved_checkpoints_queue.maxlen:
            path_to_remove_on_disk = saved_checkpoints_queue.popleft()
            os.remove(path_to_remove_on_disk)
        saved_checkpoints_queue.append(checkpoint_path)
  • SAVE_INTERVAL:我们设置每训练一定步数(注意是参数更新的步数,不是微批次数),就自动存一次档。
  • MAX_CHECKPOINT_TO_KEEPdeque:模型存档非常占硬盘空间。我们不希望无限保存下去。这里使用了一个deque(双端队列)实现滚动保存。它永远只保留最近的MAX_CHECKPOINT_TO_KEEP个存档。每当一个新的存档产生,最老的那一个就会被自动删除。这极大地节省了我们的硬盘空间。
  • 恢复时自动填充队列:更棒的是,如果从断点恢复,脚本会扫描检查点目录,将已有的最近的检查点文件加载到这个管理队列中,避免了恢复后队列为空的尴尬。
  • 保存最佳模型:除了定期保存,我们还会在每次验证后(由SAVE_INTERVAL触发),如果发现模型的验证损失创下新低,就单独保存一个名为best_model.pt的文件。这确保我们总能拿到训练过程中表现最好的那个模型!

启动引擎 —— main 函数全览

main函数就是我们的总指挥。它按照顺序,一步步地把所有模块串联起来:

  1. 加载“三件套”:加载分词器(Tokenizer)、数据集(Dataset)和数据加载器(DataLoader)。DataLoader使用了pin_memorypersistent_workers等优化参数,可以提升数据加载效率。
  2. 初始化“核心组件”:创建模型实例、损失函数(CrossEntropyLoss)、优化器(AdamW)和学习率调度器。值得注意的是,学习率调度器(lr_scheduler)的总步数MAX_TRAIN_STEPS是根据数据集大小、Epoch数和梯度累积步数精确计算出来的,这让学习率的衰减过程更加合理。
  3. 执行“断点续训”检查:如上所述,尝试从存档恢复,并初始化检查点管理队列。
  4. 启动“训练-验证”主循环
    for epoch in range(start_epoch, NUM_EPOCHS):
        # 训练一轮,期间会按SAVE_INTERVAL进行保存和验证
        train_loss, global_step, best_val_loss = train_epoch(...)
        
        # 每个epoch结束后,进行一次完整的验证
        val_loss, val_perplexity = validate_epoch(...)
        
        # 检查是否是最佳模型并保存
        if val_loss < best_val_loss:
            # ... save best_model.pt ...
    
  5. 结束:训练完成后,关闭日志记录器,打印结束信息。

总结与展望

到这里,我们学习了“炼丹炉”的搭建。现在,我们不仅知道代码在做什么,更重要的是,理解了为什么要这么做。这份脚本凝聚了许多大型模型训练的最佳实践。

我们掌握了:

  • 如何通过超参数配置来定义我们的训练任务。
  • 使用梯度累积,在有限的显存下实现大批次训练。
  • 混合精度(AMP)与梯度缩放(GradScaler)如何为我们提速省显存。
  • 验证集和困惑度如何客观地评价模型性能。
  • 如何实现一个非常健壮的断点续训和智能检查点管理策略。

至此,我们万事具备,是时候让显卡开始“燃烧”了!在下一篇文章中,我们将把所有部分整合起来,正式开启训练,并观察模型的学习曲线,见证一个语言模型从“随机乱说”到“言之有物”的诞生过程。敬请期待!


关注我的公众号不走丢

附录

GitHub链接:github.com/JimmysAIPG/…