大模型训练显存计算
一、引言
在大模型训练和推理过程中,显存(GPU内存)的使用是一个关键问题。显存不足可能导致训练中断或推理失败,因此准确估算显存需求至关重要。本文将结合 batch_size、cutoff_len(上下文长度)、num_epochs(训练轮数)、lora_rank(LoRA参数秩)、target_modules(目标模块)、load_in_4bit(量化加载)等参数,分析显存计算的主要组成部分。
二、显存计算的主要组成部分
(一)模型参数显存
模型参数显存是指存储模型权重所需的显存。
-
计算公式:
-
参数数量:模型的总参数量。
-
每个参数的字节数:取决于数据精度,例如:
- FP32:4字节
- FP16:2字节
- INT8:1字节
- INT4:0.5字节。
-
TP大小:模型并行(Tensor Parallelism)的分片数量。
-
PP大小:流水线并行(Pipeline Parallelism)的分片数量。
注意:如果启用了 LoRA(Low-Rank Adaptation),需要额外考虑 LoRA 参数的显存开销:
(二)激活值显存
激活值显存是指模型在前向传播过程中产生的中间计算结果所需的显存。
-
计算公式:
-
激活复杂度系数:通常为 (34 + \frac{5 \times \text{上下文长度} \times \text{注意力头数量}}{\text{隐藏层维度}}),具体值可能因模型结构而异。
-
数据精度字节数:与模型参数显存的计算方式相同。
-
DP大小:数据并行(Data Parallelism)的分片数量。
(三)KV Cache显存
KV Cache 是在解码过程中存储历史键值对的缓存,主要用于自回归模型(如Transformer结构)。
-
计算公式:
-
系数2:表示需要存储Key和Value两部分缓存。
(四)优化器状态显存
优化器状态显存是指存储优化器状态(如动量和方差)所需的显存。
-
计算公式:
-
每个参数的优化器状态字节数:对于 Adam 优化器,通常为 8 字节(每个参数需要两个状态变量,每个变量占用 4 字节)。
注意:如果启用了 Zero 优化,优化器状态会被分布在多个 GPU 上。
(五)梯度显存
梯度显存是指存储反向传播过程中梯度所需的显存。
-
计算公式:
-
数据精度字节数:通常以 FP32 精度存储梯度,因此每个梯度占用 4 字节。
(六)临时显存
临时显存是指在训练过程中,由于临时计算、内存碎片化等原因占用的显存。这部分显存难以精确计算,但可以通过经验估算。
-
估算公式:
-
临时显存比例:通常取 0.2 到 0.5,具体值取决于模型结构和训练过程中的内存碎片化程度。
(七)总显存需求
将上述各部分显存需求相加,得到总显存需求。
-
计算公式:
三、并行策略对显存的影响
(一)模型并行(Tensor Parallelism, TP)
通过将模型参数分布在多个 GPU 上,减少单个 GPU 的显存压力。
-
模型参数显存:
-
激活值显存:
(二)流水线并行(Pipeline Parallelism, PP)
通过将模型分层分布在多个 GPU 上,减少单个 GPU 的显存压力。
-
模型参数显存:
-
梯度显存:
(三)数据并行(Data Parallelism, DP)
通过将数据分批分布在多个 GPU 上,减少单个 GPU 的显存压力。数据并行不会减少模型参数显存,但可以减少激活值显存。
-
激活值显存:
(四)Zero优化
Zero优化通过将优化器状态和梯度分布在多个 GPU 上,减少单个 GPU 的显存压力。
-
Zero1:
-
Zero2:
-
Zero3:
四、量化策略对显存的影响
通过将模型参数量化为低精度格式(如 INT8 或 INT4),可以显著减少模型参数显存。
-
模型参数显存:
-
量化精度字节数:例如,INT8 为 1 字节,INT4 为 0.5 字节。
注意:启用量化后,可能会引入额外的校准表(Calibration Table)显存开销。
五、显存优化方法
(一)模型量化
通过将模型参数量化为低精度格式(如 INT8 或 INT4),可以显著减少模型参数显存。
(二)梯度累积
通过分批计算梯度并累积,可以减少单次计算所需的显存。
(三)梯度检查点
在反向传播过程中,通过重新计算部分激活值而不是存储它们,可以减少激活值显存的占用。
(四)混合精度训练
使用混合精度训练(如 FP16 + FP32),可以在保持精度的同时减少显存占用。
(五)分布式训练
通过将模型和数据分布在多个 GPU 上,可以减少单个 GPU 的显存压力。
六、示例计算(7B模型)
假设我们有一个 7B 参数的 Transformer 模型,具体参数如下:
- 模型参数数量:70 亿(7e9)
- 数据精度:FP16
- 批次大小(Batch Size) :8
- 上下文长度(Context Length) :1024
- 隐藏层维度:4096
- 模型层数:32
- KV 头数量:32
- 注意力头维度:128
- 优化器:Adam
- 并行策略:TP=2, PP=2, DP=4
- 量化策略:无
根据上述公式,我们可以计算各部分显存需求:
(一)模型参数显存
(二)激活值显存
(三)KV Cache 显存
(四)优化器状态显存
(五)梯度显存
(六)临时显存
假设临时显存比例为 0.3: