基于llama-recipes库(7.17号版本)训练过程中发现问题,内存不够
现在已经修复了,在8月1号
d8a81bb提交中已经修复
解决方法
如果是8卡,启动8个进程,只在其中一个进程加载模型参数,其他进程以空参数初始化,即torch.device("meta")(和init_empty_weights有什么区别?)。然后再放入GPU,即FSDP实例化过程中,按照每块卡该分配到的参数初始化。
预备知识:
torch meta device
device='meta'
How 🤗 Accelerate runs very large models thanks to PyTorch (huggingface.co)
DEBUG
参数量
| 训练方式 | 模型参数 | 梯度 | 优化器状态 | 激活值 | 总计 |
|---|---|---|---|---|---|
| fp32 | ? | ||||
| 混合精度 | (fp32优化器+存的一份fp32原参数) | ? |
ddp需要保留一份梯度的备份解释。
一个错误的debug例子
为了方便debug,采用gpt2作为测试的模型(参数量124M)
采用2块卡(2个进程)FSDP加载gpt2,每块卡占用2200M显存,与预估一致
一个参数占用32bit,即4字节
0.124B * 4 / 2 = 0.248G = 2480M
但是采用4块卡(4个进程)FSDP加载gpt2,每块卡占用2200M显存,和2块卡一样,不能理解。
| 显卡数量 | 每块卡占用显存 |
|---|---|
| 2 | 2266MiB |
| 4 | 2148MiB |
| 6 | 2108MiB |
| 8 | 2088MiB |
造成上面问题的原因:没有设置对auto_wrap_policy.
GPT2 应该占用多大显存
bin文件548MB,说明存放的是fp32
如果采用单卡 是占用1552M(按理说是4960M
推理1918MiB
fp16占用1278M,只能推测是更底层对小参数模型做了显存优化。
1B(用全连接测试)的参数是4832M
| 标题 | fp32 | fp16 |
|---|---|---|
| gpt2(0.124B)(124439808) | 1552M(理论值4649M) | 1278M(理论值2325M) |
| 全连接(0.124B)(122061000) | 1492M(理论值25102M) | 13896M(理论值12551M) |
| llama2(7B)(6738415616) | (理论值25176M) | (理论值12588M) |
| 全连接(7B)(6738415616) | (理论值25102M) | (理论值12551M) |