大模型训练显存计算

1,158 阅读5分钟

大模型训练显存计算

一、引言

在大模型训练和推理过程中,显存(GPU内存)的使用是一个关键问题。显存不足可能导致训练中断或推理失败,因此准确估算显存需求至关重要。本文将结合 batch_sizecutoff_len(上下文长度)、num_epochs(训练轮数)、lora_rank(LoRA参数秩)、target_modules(目标模块)、load_in_4bit(量化加载)等参数,分析显存计算的主要组成部分。


二、显存计算的主要组成部分

(一)模型参数显存

模型参数显存是指存储模型权重所需的显存。

  • 计算公式

    模型参数显存=参数数量×每个参数的字节数TP大小×PP大小\text{模型参数显存} = \frac{\text{参数数量} \times \text{每个参数的字节数}}{\text{TP大小} \times \text{PP大小}}
  • 参数数量:模型的总参数量。

  • 每个参数的字节数:取决于数据精度,例如:

    • FP32:4字节
    • FP16:2字节
    • INT8:1字节
    • INT4:0.5字节。
  • TP大小:模型并行(Tensor Parallelism)的分片数量。

  • PP大小:流水线并行(Pipeline Parallelism)的分片数量。

注意:如果启用了 LoRA(Low-Rank Adaptation),需要额外考虑 LoRA 参数的显存开销:

LoRA参数显存=lora-rank×目标模块参数数量×数据精度字节数\text{LoRA参数显存} = \text{lora-rank} \times \text{目标模块参数数量} \times \text{数据精度字节数}

(二)激活值显存

激活值显存是指模型在前向传播过程中产生的中间计算结果所需的显存。

  • 计算公式

    激活值显存=批次大小×上下文长度×隐藏层维度×激活复杂度系数×模型层数×数据精度字节数TP大小×PP大小×DP大小\text{激活值显存} = \frac{\text{批次大小} \times \text{上下文长度} \times \text{隐藏层维度} \times \text{激活复杂度系数} \times \text{模型层数} \times \text{数据精度字节数}}{\text{TP大小} \times \text{PP大小} \times \text{DP大小}}
  • 激活复杂度系数:通常为 (34 + \frac{5 \times \text{上下文长度} \times \text{注意力头数量}}{\text{隐藏层维度}}),具体值可能因模型结构而异。

  • 数据精度字节数:与模型参数显存的计算方式相同。

  • DP大小:数据并行(Data Parallelism)的分片数量。


(三)KV Cache显存

KV Cache 是在解码过程中存储历史键值对的缓存,主要用于自回归模型(如Transformer结构)。

  • 计算公式

    KV Cache 显存=2×批次大小×上下文长度×隐藏层维度×模型层数×数据精度字节数TP大小×PP大小×DP大小\text{KV Cache 显存} = \frac{2 \times \text{批次大小} \times \text{上下文长度} \times \text{隐藏层维度} \times \text{模型层数} \times \text{数据精度字节数}}{\text{TP大小} \times \text{PP大小} \times \text{DP大小}}
  • 系数2:表示需要存储Key和Value两部分缓存。


(四)优化器状态显存

优化器状态显存是指存储优化器状态(如动量和方差)所需的显存。

  • 计算公式

    优化器状态显存=可训练参数数量×每个参数的优化器状态字节数Zero优化级别对应的GPU数量\text{优化器状态显存} = \frac{\text{可训练参数数量} \times \text{每个参数的优化器状态字节数}}{\text{Zero优化级别对应的GPU数量}}
  • 每个参数的优化器状态字节数:对于 Adam 优化器,通常为 8 字节(每个参数需要两个状态变量,每个变量占用 4 字节)。

注意:如果启用了 Zero 优化,优化器状态会被分布在多个 GPU 上。


(五)梯度显存

梯度显存是指存储反向传播过程中梯度所需的显存。

  • 计算公式

    梯度显存=可训练参数数量×数据精度字节数Zero优化级别对应的GPU数量\text{梯度显存} = \frac{\text{可训练参数数量} \times \text{数据精度字节数}}{\text{Zero优化级别对应的GPU数量}}
  • 数据精度字节数:通常以 FP32 精度存储梯度,因此每个梯度占用 4 字节。


(六)临时显存

临时显存是指在训练过程中,由于临时计算、内存碎片化等原因占用的显存。这部分显存难以精确计算,但可以通过经验估算。

  • 估算公式

    临时显存=模型参数显存×临时显存比例\text{临时显存} = \text{模型参数显存} \times \text{临时显存比例}
  • 临时显存比例:通常取 0.2 到 0.5,具体值取决于模型结构和训练过程中的内存碎片化程度。


(七)总显存需求

将上述各部分显存需求相加,得到总显存需求。

  • 计算公式

    总显存=模型参数显存+激活值显存+KV Cache 显存+优化器状态显存+梯度显存+临时显存\text{总显存} = \text{模型参数显存} + \text{激活值显存} + \text{KV Cache 显存} + \text{优化器状态显存} + \text{梯度显存} + \text{临时显存}

三、并行策略对显存的影响

(一)模型并行(Tensor Parallelism, TP)

通过将模型参数分布在多个 GPU 上,减少单个 GPU 的显存压力。

  • 模型参数显存

    模型参数显存=模型参数显存TP大小\text{模型参数显存} = \frac{\text{模型参数显存}}{\text{TP大小}}
  • 激活值显存

    激活值显存=激活值显存TP大小\text{激活值显存} = \frac{\text{激活值显存}}{\text{TP大小}}

(二)流水线并行(Pipeline Parallelism, PP)

通过将模型分层分布在多个 GPU 上,减少单个 GPU 的显存压力。

  • 模型参数显存

    模型参数显存=模型参数显存PP大小\text{模型参数显存} = \frac{\text{模型参数显存}}{\text{PP大小}}
  • 梯度显存

    梯度显存=梯度显存PP大小\text{梯度显存} = \frac{\text{梯度显存}}{\text{PP大小}}

(三)数据并行(Data Parallelism, DP)

通过将数据分批分布在多个 GPU 上,减少单个 GPU 的显存压力。数据并行不会减少模型参数显存,但可以减少激活值显存。

  • 激活值显存

    激活值显存=激活值显存DP大小\text{激活值显存} = \frac{\text{激活值显存}}{\text{DP大小}}

(四)Zero优化

Zero优化通过将优化器状态和梯度分布在多个 GPU 上,减少单个 GPU 的显存压力。

  • Zero1

    总显存=模型参数显存+优化器状态显存+梯度显存GPU数量\text{总显存} = \text{模型参数显存} + \frac{\text{优化器状态显存} + \text{梯度显存}}{\text{GPU数量}}
  • Zero2

    总显存=模型参数显存+激活值显存+优化器状态显存+梯度显存GPU数量\text{总显存} = \text{模型参数显存} + \text{激活值显存} + \frac{\text{优化器状态显存} + \text{梯度显存}}{\text{GPU数量}}
  • Zero3

    总显存=激活值显存+模型参数显存+优化器状态显存+梯度显存GPU数量+LiveParams\text{总显存} = \text{激活值显存} + \frac{\text{模型参数显存} + \text{优化器状态显存} + \text{梯度显存}}{\text{GPU数量}} + \text{LiveParams}

四、量化策略对显存的影响

通过将模型参数量化为低精度格式(如 INT8 或 INT4),可以显著减少模型参数显存。

  • 模型参数显存

    模型参数显存=参数数量×量化精度字节数\text{模型参数显存} = \text{参数数量} \times \text{量化精度字节数}
  • 量化精度字节数:例如,INT8 为 1 字节,INT4 为 0.5 字节。

注意:启用量化后,可能会引入额外的校准表(Calibration Table)显存开销。


五、显存优化方法

(一)模型量化

通过将模型参数量化为低精度格式(如 INT8 或 INT4),可以显著减少模型参数显存。

(二)梯度累积

通过分批计算梯度并累积,可以减少单次计算所需的显存。

(三)梯度检查点

在反向传播过程中,通过重新计算部分激活值而不是存储它们,可以减少激活值显存的占用。

(四)混合精度训练

使用混合精度训练(如 FP16 + FP32),可以在保持精度的同时减少显存占用。

(五)分布式训练

通过将模型和数据分布在多个 GPU 上,可以减少单个 GPU 的显存压力。


六、示例计算(7B模型)

假设我们有一个 7B 参数的 Transformer 模型,具体参数如下:

  • 模型参数数量:70 亿(7e9)
  • 数据精度:FP16
  • 批次大小(Batch Size) :8
  • 上下文长度(Context Length) :1024
  • 隐藏层维度:4096
  • 模型层数:32
  • KV 头数量:32
  • 注意力头维度:128
  • 优化器:Adam
  • 并行策略:TP=2, PP=2, DP=4
  • 量化策略:无

根据上述公式,我们可以计算各部分显存需求:

(一)模型参数显存

模型参数显存=7e9×2 字节2×2=3.5 GB\text{模型参数显存} = \frac{7e9 \times 2 \text{ 字节}}{2 \times 2} = 3.5 \text{ GB}

(二)激活值显存

激活值显存=8×1024×4096×(34+5×1024×324096)×32×2 字节2×2×448 GB\text{激活值显存} = \frac{8 \times 1024 \times 4096 \times \left(34 + \frac{5 \times 1024 \times 32}{4096}\right) \times 32 \times 2 \text{ 字节}}{2 \times 2 \times 4} \approx 48 \text{ GB}

(三)KV Cache 显存

KV Cache 显存=2×8×1024×4096×32×2 字节2×2×424 GB\text{KV Cache 显存} = \frac{2 \times 8 \times 1024 \times 4096 \times 32 \times 2 \text{ 字节}}{2 \times 2 \times 4} \approx 24 \text{ GB}

(四)优化器状态显存

优化器状态显存=7e9×8 字节2×2=14 GB\text{优化器状态显存} = \frac{7e9 \times 8 \text{ 字节}}{2 \times 2} = 14 \text{ GB}

(五)梯度显存

梯度显存=7e9×4 字节2×2=7 GB\text{梯度显存} = \frac{7e9 \times 4 \text{ 字节}}{2 \times 2} = 7 \text{ GB}

(六)临时显存

假设临时显存比例为 0.3:

临时显存=3.5 GB×0.3=1.05 GB\text{临时显存} = 3.5 \text{ GB} \times 0.3 = 1.05 \text{ GB}

(七)总显存需求

总显存=3.5 GB+48 GB+24 GB+14 GB+7 GB+1.05 GB97.55 GB\text{总显存} = 3.5 \text{ GB} + 48 \text{ GB} + 24 \text{ GB} + 14 \text{ GB} + 7 \text{ GB} + 1.05 \text{ GB} \approx 97.55 \text{ GB}