并行化
数据并行(Data Parallelism)
在数据并行训练中,数据集被分割成几个碎片,每个碎片被分配到一个设备上。这相当于沿批次(Batch)维度对训练过程进行并行化。每个设备将持有一个完整的模型副本,并在分配的数据集碎片上进行训练。典型的数据并行实现:PyTorch DDP。
模型并行(Model Parallelism)
通常有两种类型的模型并行:张量并行和流水线并行
张量并行(Tensor Parallelism)
利用了分块矩阵的原理,专注于解决单层参数过大的问题。垂直划分模型,层内并行
流水线并行(Pipeline Parallelism)
将模型一层一层分开,不同层放入不同 GPU 进行计算,训练设备容易出现空闲状态,因为后一个阶段需要等待前一个阶段执行完毕。水平划分模型,层间并行
混合精度
模型参数(fp16)、模型梯度(fp16)和Adam状态(fp32的模型参数备份,fp32的momentum和fp32的variance)。假设模型参数量为W,则共需要2W+2W+(4W+4W+4W)=16W字节存储。
梯度累积
- 梯度累积是一个比较简单的优化技术,其从batch size的层面来降低显存占用的。一般情况下,显存的占用直接受到输入数据的影响,包括batch size、sequence length等,如果显存溢出,我们最直接的做法就是将batch size调低。但是对于预训练和指令微调时,扩大batch size是提高模型训练效果的重要因素,降低batch size可能会降低模型的效果。
- 为了不降低batch size,可以采用梯度累积的方法。梯度累积是指在前向传播之后所计算梯度并不立刻用于参数更新,而是接着继续下一轮的前向传播,每次计算的梯度会暂时存储下来,待在若干次前向传播之后,一并对所有梯度进行参数更新。因此梯度累积相当于是拿时间换空间。
梯度检查点 (Gradient Checkpointing)
- 由于模型反向传播需要中间结果计算梯度,大量中间结果占用大量显存。
- Checkpointing 思路是保存部分隐藏层的结果(作为检查点),其余的中间结果直接释放。当反向传播需要计算梯度时,从检查点开始重新前向传播计算中间结果,得到梯度后再次释放,因此梯度检查点相当于是拿时间换空间。
ZeRO(Zero Redundancy Optimizer)
零冗余优化器是一种用于大规模分布式深度学习的新型内存优化技术。在普通的数据并行策略中,每个 GPU 都独立地维护一组完整的模型参数,计算与通信效率较高,但内存效率较差。这个问题在训练大型模型时尤为突出。ZeRO 可以有效地减少显存消耗量,这意味着在同样的显存下,可以训练更大的模型。
显存分析
训练深度学习模型时的显存消耗可以分为两大部分:
- 模型状态(model states)。对于大型模型来说,大部分显存消耗都是被模型状态占用的,主要包括三部分:优化器的状态(Optimizer States)、梯度(Gradients)、参数(Parameters)。三者简称为 OPG。
- 残余状态(residual states)。剩余状态(residual states): 除了模型状态之外的显存占用,包括激活值(activation)、各种临时缓冲区(buffer)以及无法使用的显存碎片(fragmentation)
ZeRO优化阶段
ZeRO分为三个阶段,分别对应 O、P 和 G。每个 GPU 仅保存部分 OPG,三个阶段逐级递加: