MindSpore Transformers 断点续训

1 阅读2分钟

​断点续训是长周期大模型训练的核心保障,MindSpore Transformers(MindFormers)通过 Checkpoint 2.0 机制实现全状态保存与精准恢复,可无缝接续模型权重、优化器状态、学习率调度、数据迭代位置与训练步数,避免异常中断导致的算力浪费昇思MindSpore。其支持单机 / 分布式、扩缩容、增量续训等场景,配置简洁、兼容性强昇思MindSpore。

一、核心原理

断点续训的关键是全状态保存:训练时按固定步长保存 Checkpoint,包含模型参数、优化器状态、动态学习率、当前 epoch/step、随机种子等信息。中断后,通过load_checkpoint加载断点文件,resume_training=True激活续训,系统自动读取latest_checkpointed_iteration.txt定位最新步数,从断点处精准接续昇思MindSpore。分布式训练下,支持策略文件加载与自动切分,适配卡数变更场景昇思MindSpore。

二、配置与代码实现

1. 训练配置(YAML)

# 基础训练配置
model:
  model_config:
    type: LlamaConfig
    seq_length: 1024
    vocab_size: 32000
  arch:
    type: LlamaModel

# 断点保存配置
callbacks:
  - type: CheckpointMonitor
    save_checkpoint_steps: 100  # 每100步保存1次
    keep_checkpoint_max: 10     # 保留最新10个断点
    integrated_save: True       # 全状态保存(含优化器/调度器)
    async_save: True            # 异步保存,不阻塞训练

# 断点续训核心参数
load_checkpoint: "./output/checkpoint"  # 断点目录
resume_training: True                   # 开启续训
load_ckpt_format: "safetensors"         # 权重格式(可选ckpt)

2. 代码实现(Python)

import mindspore as ms
from mindformers import Trainer, MindFormerConfig
from mindformers.tools.logger import logger

# 1. 初始化环境与配置
ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend")
config = MindFormerConfig("configs/llama/pretrain_llama2_7b.yaml")

# 2. 手动配置续训(可选,覆盖YAML)
config.resume_training = True
config.load_checkpoint = "./output/checkpoint"  # 断点路径
config.callbacks[0].save_checkpoint_steps = 100

# 3. 初始化训练器(自动加载断点)
trainer = Trainer(
    config=config,
    task="text_generation",
    model_name="llama2_7b",
    train_dataset="/path/to/train_data.mindrecord"
)

# 4. 启动续训
logger.info("Start resume training from checkpoint...")
trainer.train()

# 5. 分布式续训(多卡)
# 启动脚本:bash scripts/msrun_launcher.sh "python run_mindformer.py --config xxx.yaml --run_mode train" 8

3. 关键 API 说明

  • CheckpointMonitor:回调类,控制断点保存步长、数量、全状态 / 异步保存。
  • load_checkpoint:指定断点路径(文件夹 / 指定 iteration 子文件夹)昇思MindSpore。
  • resume_training:续训开关,True时自动恢复训练状态昇思MindSpore。
  • auto_trans_ckpt:分布式改卡续训时开启,自动转换权重策略昇思MindSpore。

三、使用步骤

  1. 正常训练:配置CheckpointMonitor,自动保存断点至output/checkpoint
  2. 中断处理:训练异常终止(断电、硬件故障)后,保留断点目录与latest_checkpointed_iteration.txt昇思MindSpore。
  3. 启动续训:修改配置resume_training=True,指定load_checkpoint路径,重新启动训练脚本。
  4. 分布式适配:改卡续训时,添加src_strategy_path_or_dirauto_trans_ckpt=True,自动合并 / 切分权重。

四、注意事项

  • 续训时模型结构、超参、数据集需与原训练一致,否则加载失败。
  • 分布式场景下,不改变策略可直接续训;改卡需开启auto_trans_ckpt
  • 建议开启integrated_save=True,确保优化器、学习率调度器状态完整保存。
  • 保留latest_checkpointed_iteration.txt,避免手动删除导致无法定位最新断点。