fsdp训练llama-2-70b

1,149 阅读2分钟

基于llama-recipes库(7.17号版本)训练过程中发现问题,内存不够

现在已经修复了,在8月1号d8a81bb提交中已经修复

关于fsdp,这篇写的很好

在pytorch的issue中也有人提到fsdp加载问题

解决方法

如果是8卡,启动8个进程,只在其中一个进程加载模型参数,其他进程以空参数初始化,即torch.device("meta")(和init_empty_weights有什么区别?)。然后再放入GPU,即FSDP实例化过程中,按照每块卡该分配到的参数初始化。

预备知识:

torch meta device

device='meta'

How 🤗 Accelerate runs very large models thanks to PyTorch (huggingface.co)

DEBUG

参数量 φ\varphi

训练方式模型参数梯度优化器状态激活值总计
fp321×4φ1{\times}4\varphi1×4φ1{\times}4\varphi2×4φ2{\times}4\varphi?16φ16\varphi
混合精度1×2φ1{\times}2\varphi1×2φ1\times2\varphi(2×4+1×4)φ=12φ(2{\times}4+1{\times}4)\varphi=12\varphi (fp32优化器+存的一份fp32原参数)?16φ16\varphi

ddp需要保留一份梯度的备份解释

一个错误的debug例子

为了方便debug,采用gpt2作为测试的模型(参数量124M) 采用2块卡(2个进程)FSDP加载gpt2,每块卡占用2200M显存,与预估一致

一个参数占用32bit,即4字节

0.124B * 4 / 2 = 0.248G = 2480M

但是采用4块卡(4个进程)FSDP加载gpt2,每块卡占用2200M显存,和2块卡一样,不能理解。

显卡数量每块卡占用显存
22266MiB
42148MiB
62108MiB
82088MiB

造成上面问题的原因:没有设置对auto_wrap_policy.

GPT2 应该占用多大显存

bin文件548MB,说明存放的是fp32 如果采用单卡 是占用1552M(按理说是4960M 推理1918MiB

fp16占用1278M只能推测是更底层对小参数模型做了显存优化。

1B(用全连接测试)的参数是4832M

4φ/1024/1024/10244\varphi/1024/1024/1024

标题fp32fp16
gpt2(0.124B)(124439808)1552M(理论值4649M)1278M(理论值2325M)
全连接(0.124B)(122061000)1492M(理论值25102M)13896M(理论值12551M)
llama2(7B)(6738415616)(理论值25176M)(理论值12588M)
全连接(7B)(6738415616)(理论值25102M)(理论值12551M)