使用NVFP4低精度模型训练实现更高吞吐量且不损失精度
随着AI模型和数据集规模的持续增长,仅依赖高精度的BF16训练已不再足够。训练吞吐量预期、内存限制和不断上升的成本等关键挑战正成为扩展Transformer模型的主要障碍。
使用更低精度的训练可以应对这些挑战。通过降低计算过程中使用的数值精度,GPU可以在每个周期内处理更多操作,从而提高训练效率并降低成本。
本文直接对比了以下三种低精度训练格式与已确立的BF16精度训练,对比范围涵盖数千亿token的预训练运行和下游基准测试:
- 8位浮点每张量当前缩放 (FP8-CS)
- 使用FP8的混合精度训练 (MXFP8)
- 使用某机构NeMo Megatron Bridge的NVFP4精度训练(该库为某机构NeMo框架的一部分)
本文展示了实用的大规模结果,说明低精度训练如何实现高达约1.6倍的吞吐量、显著的内存节省以及几乎相同的模型质量,且使用的生产就绪配方可立即采用。
什么是低精度训练?
低精度训练在模型训练过程中使用位数更少的数值格式来表示权重和激活值。这减少了内存带宽和计算需求,使GPU能够在每个周期处理更多操作,从而显著提高训练吞吐量。
低精度格式
- FP8-CS:使用基于当前训练步骤中每个张量的统计特性得出的缩放因子,将FP8应用于线性层。
- MXFP8:通过针对某机构Blackwell架构优化的块级缩放扩展了FP8方法,每个块覆盖32个张量元素。
- NVFP4:通过对张量值使用4位格式并采用分层两级缩放策略,进一步提高了内存效率和吞吐量。
低精度训练能否在大规模下达到BF16的精度?
为了验证低精度训练对实际大规模模型预训练的实际影响,团队在两个广泛使用的稠密Transformer架构上评估了训练收敛性和下游任务性能:Llama 3 8B和一个某机构内部研究型8B模型(Research-8B,具有与Llama 3 8B类似的稠密分组查询注意力架构)。模型在1万亿token上进行了训练。
实验设置:隔离精度的影响
进行了以下大规模预训练实验:
- 四种数值精度:BF16(基线)、FP8-CS、MXFP8和NVFP4
- 两种模型架构:Llama 3 8B和Research-8B
- 训练软件和硬件:在某机构B200 GPU上的NeMo Megatron Bridge
- 两个数据集:Lingua DCLM数据集和一个内部数据集。Llama 3 8B在两个数据集上都进行了训练,Research-8B在内部研究数据集上进行了训练
收敛行为:各精度下的训练稳定性
图2、3和4展示了两种模型和数据集上的训练和验证损失曲线。低精度训练与BF16基线紧密跟随,展示了各精度下稳定且一致的收敛。在所有情况下,NVFP4显示出略高的损失,但下游精度未受影响。详情见表1。
下游评估:精度得以保持
为了评估低精度训练是否影响实际性能,我们在标准下游基准上评估了所有预训练模型。所有评估均在BF16精度下运行,以隔离训练精度的影响。
表1显示了结果。尽管训练和验证损失存在微小差异,所有低精度格式都实现了与BF16相当的下游任务精度。
表1. Llama 3 8B和Research-8B在BF16、FP8-CS、MXFP8和NVFP4训练下的下游任务精度(%)
| 模型 | 数据集 | 精度 | MMLU (↑) | HellaSwag (↑) | WinoGrande (↑) | ARC-C (↑) |
|---|---|---|---|---|---|---|
| Llama 3 8B | DCLM | BF16 | 45.98 | 76.44 | 70.17 | 51.28 |
| FP8-CS | 46.00 | 75.25 | 70.24 | 49.91 | ||
| MXFP8 | 46.56 | 75.46 | 71.27 | 51.11 | ||
| NVFP4 | 45.64 | 75.59 | 69.38 | 51.28 | ||
| Llama 3 8B | 内部数据集 | BF16 | 52.73 | 75.71 | 67.88 | 51.37 |
| FP8-CS | 52.46 | 75.65 | 70.17 | 54.52 | ||
| MXFP8 | 53.70 | 75.54 | 69.69 | 51.62 | ||
| NVFP4 | 52.83 | 75.04 | 71.98 | 53.58 | ||
| Research-8B | 内部数据集 | BF16 | 53.00 | 76.98 | 70.40 | 55.89 |
| FP8-CS | 52.62 | 75.81 | 70.80 | 54.44 | ||
| MXFP8 | 52.38 | 76.55 | 69.77 | 53.58 | ||
| NVFP4 | 52.21 | 76.19 | 70.32 | 54.95 |
关键见解
- 低精度训练与BF16收敛相匹配:FP8、MXFP8、NVFP4实现的预训练和验证损失非常接近BF16,退化极小。
- 下游精度得以保持:在所有模型和基准测试中,低精度训练提供的下游任务精度与BF16相当,证明降低精度能保持模型有效性。
- MXFP8略优于标准FP8:这很可能归因于其更细粒度的缩放机制,能更好地捕捉张量内的局部动态范围。
- 经过适当校准的NVFP4尽管压缩激进,仍能提供有竞争力的结果:经验性的最佳配置为:AdamW ϵ=1e-8,LR=6e-4 → 6e-6,GBS=768。
- 选择性BF16层对NVFP4至关重要:消融研究表明,完全使用NVFP4的模型会发散。稳定训练需要将某些层保持在BF16,特别是靠近网络末端的位置,以减轻NVFP4量化误差。在这些实验中,将最后四个Transformer层保持在BF16就足够了。
FP8、MXFP8和NVFP4训练的优势
低精度格式在训练吞吐量和内存效率方面都带来了明显优势,从而在某机构Blackwell GPU上实现更快的端到端训练和更好的可扩展性。
表2. 在某机构GB200 NVL72上训练Llama 3 8B的吞吐量对比,显示NVFP4相比BF16最高可达1.59倍加速
| 精度 | 微批次大小 | 吞吐量 (TFLOP/s/GPU) | 相对于BF16的加速比 |
|---|---|---|---|
| BF16 | 2 | 1165 | – |
| FP8-CS (F1L1) | 2 | 1547 | 1.33x |
| MXFP8 | 2 | 1540 | 1.32x |
| NVFP4 (F0L4) | 4 | 1850 | 1.59x |
更快的端到端训练
使用8位或4位数值格式通过使GPU在每个时钟周期处理更多操作,大幅降低了计算开销。吞吐量提升相比BF16基线最高可达1.59倍(表2)。这些提升直接转化为大规模模型更快的训练时间。
GPU内存节省和更好的可扩展性
使用较低位宽格式减少了权重和激活值的内存占用,允许在相同硬件上使用更大的模型或批次大小。NVFP4的效率使得预训练期间的微批次大小翻倍(从2到4),直接提高了吞吐量和可扩展性。
表3详细列出了训练组件的内存使用情况。低精度格式显著减少了参数和激活存储,同时保留FP32优化器状态,从而在不影响训练稳定性的情况下实现更高的吞吐量和更大的批次大小。
表3. 不同精度格式下各训练组件的内存占用
| 组件 | 精度 |
|---|---|
| 优化器 | FP32 |
| 参数 | FP16 / BF16 / FP8 / FP4 |
| 梯度 | FP16 / BF16 / FP8 |
| 动量 | FP32 |
| 方差 | FP32 |
| 主参数 | FP32 |
| 其他 | FP32 |
使用NeMo Megatron Bridge进行低精度训练
NeMo Megatron Bridge是一个开源的PyTorch原生库,属于某机构NeMo框架的一部分。它在Hugging Face和Megatron Core模型检查点之间提供双向连接。它提供了优化训练和多节点并行机制,用于以最大吞吐量预训练、SFT和LoRA调优生成式AI模型。
使用NeMo Megatron Bridge库采用低精度训练非常简单。可以使用针对各种模型的即用型低精度配方,通过更改单个配置标志来试验不同的精度格式。Llama 3 8B的示例如下:
from megatron.bridge.recipes.llama import llama3_8b_low_precision_pretrain_config as low_precision_pretrain_config
from megatron.bridge.training.gpt_step import forward_step
precision = "bf16_with_fp8_current_scaling_mixed" # 可选值:["bf16_with_mxfp8_mixed", "bf16_with_fp8_current_scaling_mixed", "bf16_with_nvfp4_mixed"]
cfg = low_precision_pretrain_config(
mixed_precision_recipe = precision,
train_iters = 100,
lr_warmup_iters = 10,
lr_decay_iters = 90,
mock = True, # 使用模拟数据集
)
pretrain(config=cfg, forward_step_func=forward_step)
可以轻松地在精度格式之间切换,以评估性能、内存节省和收敛行为,而无需修改模型代码或优化器逻辑。
更快地训练,高效地扩展
与广泛采用的BF16相比,诸如当前缩放的FP8、MXFP8和NVFP4等低精度训练格式为更快、更高效的深度学习训练提供了令人兴奋的新途径。它们在速度和内存节省方面的优势为训练更大、更复杂的模型打开了大门。来自Llama 3 8B和内部研究模型的实证证据证实,低精度训练在预训练指标和下游任务上都与BF16性能相匹配。
开始使用低精度训练
随着模型规模持续扩大,低精度训练将成为构建下一代模型的基础。借助原生某机构Blackwell GPU支持和NeMo Megatron Bridge中的生产就绪低精度配方,可以立即尝试这些技术。
如需快速上手,请尝试Megatron Bridge训练教程笔记本。该教程端到端地演示了如何使用这些低精度配方,并展示了它们如何显著加速训练工作负载。FINISHED