一文读懂Zero-Stage的计算过程

184 阅读3分钟

背景

随着模型规模越来越大,训练过程中对计算资源的需求呈现指数级增长。为了解决这些问题,人们提出了Zero-Stage的训练方法,旨在通过对数据并行中的冗余数据进行优化,来显著减少显存占用。

实现方式

Zero主要通过对模型状态(参数、梯度和优化器状态)进行切分存储来实现显存优化,它分为三个阶段,即 Zero-Stage1、Zero-Stage2 和 Zero-Stage3,每个阶段的优化程度逐渐加深。

Stage1

首先,从优化器参数开始优化。把优化器参数切分成N份,每块GPU上各自维护一份。如下图,假设W=fp16, G=fp16, O=fp32,整体计算过程

  • 每块GPU上存一份完整的参数W.将第一个batch的数据分成3份,每块GPU上各一份,做一轮forward和backward之后,各自算的一份梯度。
  • 对梯度做一次allreduce,得到完整的梯度G, 产生单卡通讯2产生单卡通讯量 2Φ 

image.png

  • 得到完整梯度G之后,就可以对W进行更新。W的更新由优化器状态和梯度共同决定。由于每块GPU上只保存部分优化器状态,因此只能将相应的W进行更新(蓝色部分) image.png
  • 此时,每块GPU上都有部分W没有完成更新(图中白色部分),所以我们需要对W做一次Allgather,从别的GPU上把更新好的部分W取回来。产生单卡通讯量Φ  做完Stage1后,假设GPU个数为Nd, 显存和通讯量的情况如下:

image.png

Stage2

Stage2把梯度也做了进一步切分,此时Stage2的整体流程如下:

  • 每块GPU上存一份完整梯度W,将一个batch的数据分为3份,每块GPU上各维护一份,做一轮forward和backward之后,算得一份梯度(下图中绿色+白色
  • 对梯度做一次reduce-scatter, 保证每个GPU上所维持的那块梯度是聚合梯度。例如对GPU1, 它负责维护G1,因此其他的GPU只需要把G1对应位置的梯度发给GPU1做加总就可。汇总完毕后,白色块对GPU无用,可以从显存中移除。产生单卡通讯量Φ
  • 每块GPU用自己的O和G去更新对应的W.更新完毕后,每块GPU维持了一块更新完毕的W.同理,对W做一次allgather,将别的GPU算好的W同步到自己这来。产生单卡通讯量Φ

image.png 下面是Stage2的显存和通讯量分析

image.png 和朴素DP对比,存储变为了1/8, 单卡通讯量持平

Stage3

Stage3进一步对参数进行了切分,整体流程如下

  • 每块GPU上只保存部分参数W。将一个batch的数据分为3份,每块GPU各自维护一份
  • 做forward时,对W做一次allgather,取回分布在别的GPU上的W,得到一份完整的梯度W, 产生单卡通讯量Φ。forward做完,立刻把不是自己维护的W丢弃
  • 做backward时,算得一份完整的梯度G,对G做一次reduce-scatter,从别的GPU上聚合自己维护的那部分梯度,产生单卡通讯量Φ。聚合操作结束后,立刻把不是自己维护的G丢弃
  • 用自己维护的O和G,更新W。由于只维护部分W, 因此无需对W做任何allreduce操作。

image.png stage3显存和通讯量如下:至此,我们用1.5倍通讯开销,换回120倍显存。只需要地图计算和异步更新做的好,通讯时间可以被计算时间隐藏,因此这样实现的额外通讯开销,也是划算的。

image.png

参考文章

1、zhuanlan.zhihu.com/p/618865052