大模型训练时底层显存占用情况详解

935 阅读5分钟

计算机基础知识

在了解模型训练的底层显存分配情况之前,我们要简单介绍一下对应的计算机基础知识,我们对于内存单位的换算一般基于以下公式:

1GB=1024MB

1MB=1024KB

1KB=1024Byte

1Byte =8Bit

在我们训练模型的时候通常会用到 FP32 单精度和 FP16 半精度的单位,它们使用公式可以表示的单位换算如下:

FP32=32Bits=4Byte

FP16=16Bits=2Byte

所以我们得到了一个重要的结论就是 FP32 和 FP16 分别占用 4 字节和 2 字节。

显存占用五大部分

1 模型的输入输出

我们这里以 llama-13B 的模型为例,参数类型是 FP16 ,也就是 2 字节,具体超参数如下:

b,也就是 batch_size ,设置为 1

s,也就是 sequence_len ,设置为 1024

h,也就是 hidden_size ,设置为 5120

经过 embedding 之后,最终将 Byte 换算成 MB ,输入占用的显存如下:

输入 :b * s * h * 2 / 1024 / 1024 = 10 MB

输出 :同理也是 10MB

总共:20MB

输入占用显存 10MB ,同理输出占用的也是 10MB ,也就是模型的输入和输出一共占用 20MB。

2 模型参数

我们要先进行一个简单的换算来推理出一个实用的结论:

1B=1000^3

1GB=1024^3Byte

因为 1000^3 和 1024^3 非常接近,所以我们可以得出一个结论:

如果有 1B 参数量,每个参数是 1 个字节,我们近似约等占用 1GB 显存

模型参数一共 13B 个,每个参数类型是 FP16,占用 2 字节,也就是 llama-13B 一共近似占用显存如下:

13 * 2 = 26GB

如果模型参数类型是 FP32 ,llama-13B 一共近似占用显存 52GB 。

3 优化器

目前训练大模型最常用的优化器就是 Adam (或者 AdamW ),使用 Adam 的时候,需要为每个参数保留 FP32 的梯度指数平滑值和 FP32 的梯度平方指数平滑值,同时需要保存一份 FP32 的模型参数。所以总共消耗的显存如下:

梯度指数平滑值: 13 * 4 = 52GB

梯度平方指数平滑值: 13 * 4 = 52GB

模型参数: 13 * 4 = 52GB

总共:156GB

可以看出来优化器部分占用的显存相当惊人。

4 梯度

模型的参数有多少个,梯度就会有多少个,我们这里使用 llama-13B ,参数类型是 FP16 ,也就是 2 个字节,所以梯度显存占用情况如下:

13B * 2 = 26G

同理如果模型参数类型是 FP32 ,llama-13B 的梯度一共近似占用显存 52GB 。

5 激活值

我们把激活值占用显存的情况放在最后介绍,是因为前面四个部分的占用是方便估算的,而激活值的显存往往是动态的,通常与 batch_size 和 sequence_len 等超参数密切相关,变化幅度较大。甚至我们可以设置参数不保留激活值,也就是不占用显存。

我们这里要查看论文,这里已经有现成的公式可以套用,具体细节自己可以查看论文介绍。

我们使用第一个公式,也就是在不进行并行的常规情况下, Transformer 每层的激活值占用显存的情况,主要与以下的因素有关:

b,也就是 batch_size ,设置为 1

s,也就是 sequence_len ,设置为 1024

h,也就是 hidden_size ,设置为 5120

l,也就是 num_layers ,设置为 40

a,也就是 num_attention_heads ,设置为 40

同样我们使用 llama-13B ,参数类型是 FP16 ,那么此时所有层占用的显存大小为:

s * b * h * ((34 + 5 * a * s / h ) * l / 1024 / 1024 / 1024 GB =14.45GB

对于大模型训练的时候,batch_size 和 sequence_len 会相当大,所以激活值占用现存的情况也可能相当庞大。一般大模型训练时候 batch_size 设置为 4000000 ,那么会占用 57812500GB 显存,相当感人。

总结

经过上面的介绍和计算,相比应该有了更加深刻的理解,现在我们来汇总一下上面所有部分的显存使用情况,假设我们的模型和超参数沿用上面的设置,那么:

模型输入输出:20MB(可以忽略不计)

模型参数:26GB

优化器:156GB

梯度值:26GB

激活值:14.45G (和 batch_size ,sequence_len 等超参数相关)

合计:222.45GB

也就是说全参数训练一个 llama-13B 规模的模型,需要 222GB ,换算过来就是至少需要 3 张 H800-80G 的高性能显卡,当然真实情况远比这个数字要多,因为我们计算的时候都是用的最低的理想化超参数,实际训练过程中 batch_size ,sequence_len 等超参数相当大,得需要集群才能满足训练需求。