Distributed Data Parallel 原理及应用
Distributed Data Parallel 原理
训练流程:
- Step1 使用多个进程,每个进程都加载数据和模型
- Step2 各进程同时进行前向传播,得到输出
- Step3 各进程分别计算Loss,反向传播,计算梯度
- Step4 各进程间通信,将梯度在各卡同步
- Step5 各进程分别更新模型
分布式训练中的基本概念
- group:进程组,一个分布式任务对应一个进程组,一般就是所有卡都在一个组里
- world size:全局的并行数,一般情况下等于总的卡数
- node:节点,可以是一台机器,或者一个容器,节点内包含多个GPU
- rank(global rank):整个分布式训练任务内的进程序号
- local rank:每个node内部的相对进程序号
分布式训练中的通信
什么是通信:
- 在分布式模型训练中,通信是不同计算节点之间进行信息交换以协调训练任务的关键组成部分。
通信类型:
- 点对点通信:将数据从一个进程传输到另一个进程称为点对点通信
- 集合通信:一个分组中所有进程的通信模式称之为集合通信
- 6种通信类型:Scatter、Gather、Reduce、All Reduce、Broadcast、All Gather
Distributed Data Parallel 实现注意细节
数据部分
- 如果是在训练进程内对数据集进行划分,注意保证数据划分的一致性,可以通过随机种子控制
- 分布式采样器会为了保证每个进程内数据大小一致,做额外的填充,评估指标可能会存在误差
书写逻辑
- 将分布式的代码看作单进程的代码即可,只是需要分布式的数据采样器以及启动略有不同
- print打印的都是各自进程内的信息,需要全局的信息则需要自行调用通信计算结果
- 数据放置到指定设备上时需要注意使用正确的device_id,一般用local_rank