batch size、sequence length 对显存的非线性影响

18 阅读5分钟

几乎所有 OOM,都是“我以为还能再加一点”

如果你做过大模型微调,你一定经历过这种时刻:

  • batch size 调小一点 → 能跑

  • sequence length 加一点 → 还能跑

  • 两个一起微调 → 显存直接炸

 

你看着监控面板,心里非常困惑:

 

“不对啊,我是算过的。”

 

这正是问题的关键。

 

**你算的,是线性的;

而显存消耗,从来不是。**

 

31.png 工程师心里计算的显存 vs 实际显存曲线对比

 

一个必须先说清楚的结论

在正式展开之前,我先把这篇文章最重要的一句话写出来:

 

**batch size 和 sequence length,

不是两个独立的显存旋钮,

而是一个相互放大的乘法因子。**

 

如果你还在用:

  • “batch 翻倍,显存翻倍”

  • “长度翻倍,显存翻倍”

 

这样的直觉来理解显存,

那你几乎一定会被 OOM 教育。

 

第一层误解:把显存消耗理解成“参数规模问题”

很多人一说显存,第一反应是:

  • 模型多大

  • 参数多少

  • 是不是该用 LoRA / QLoRA

 

这些当然重要,

但它们只决定了显存的底座

 

真正让你在训练时反复爆显存的,往往不是参数,而是:

 

中间态(activations)。

 

而 batch size 和 sequence length,

正是中间态的最大放大器。

 

第二层:为什么 sequence length 比你想象中“更贵”

很多人会觉得:

 

“sequence length 只是多几个 token,

显存应该线性增加吧?”

 

这是一个非常危险的直觉。

 

一个必须面对的事实

在 Transformer 里,sequence length 影响的不只是:

  • embedding

  • attention 输入

 

而是:

  • attention score

  • KV cache

  • 每一层的中间激活

 

尤其是 self-attention

它的计算和存储复杂度是:

 

O(L²)

 

也就是说:

  • length 从 1024 → 2048

  • token 数翻倍

  • attention 相关显存,可能直接 ×4

 

这就是为什么你“只是把 max_length 调大了一点”,

显存却突然不讲道理。

 

32.png sequence length ↑ → attention 显存平方增长

 

第三层:batch size 为什么会“乘上” sequence length

单看 batch size,好像也很直观:

  • batch ×2 → 数据 ×2

 

但问题在于:

 

**batch size 决定的是:

同一时间,有多少条序列在走完整前向和反向。**

 

于是显存里同时存在的,是:

 


batch_size × sequence_length × hidden_dim × layer_count

 

这不是加法,是堆叠

 

当你把 batch size 和 sequence length 同时往上拉时,

你做的事情其实是:

 

让显存同时承载更多、更长、而且还没释放的中间态。

 

第四层:非线性真正出现的地方——反向传播

如果只是前向,其实很多时候还能勉强扛住。

 

真正让显存爆炸的,是:

 

反向传播阶段。

 

原因很简单:

  • 前向:可以边算边丢

  • 反向:必须留住中间态

 

这意味着:

  • batch 越大 → 需要保留的中间结果越多

  • length 越长 → 每一层要保存的激活越重

 

于是显存曲线会出现一个非常典型的形态:

 

**前向看着还行,

反向直接炸。**

 

33.png 前向 vs 反向 显存占用对比

 

第五层:为什么“只加一点点”,却跨过了临界点

这是最让人崩溃的地方。

 

你可能经历过:

  • batch=2,length=2048,OK

  • batch=3,length=2048,OOM

 

你会觉得:

 

“就多了一条样本,怎么就炸了?”

 

原因在于:

 

**显存不是连续可用的,

而是存在碎片和临界点的。**

 

当你跨过某个阈值:

  • CUDA 需要分配一整块新的 buffer

  • allocator 找不到足够连续空间

  • 于是直接失败

 

这就是为什么:

 

**显存不是“慢慢用完”的,

而是“突然不够用”的。**

 

第六层:梯度累积,为什么没你想得那么“省”

很多人会说:

 

“batch 太大?那我用 gradient accumulation。”

 

这确实能缓解一部分问题,

但它并不是免费午餐。

 

因为:

  • accumulation 并不会减少单步的 activation 显存

  • 它只是减少了一次 forward/backward 中的 batch

 

如果你的 OOM 来自:

  • sequence length 太长

  • attention 中间态太重

 

那梯度累积几乎救不了你

 

这也是为什么有些人会困惑:

 

“我 batch 已经很小了,为什么还 OOM?”

 

答案往往是:

 

真正压垮显存的,是 length,不是 batch。

 

第七层:评估阶段为什么反而更容易炸显存

这是一个很多人没想到的坑。

 

在评估时,你可能会:

  • 关掉 dropout

  • 不算 loss

  • 以为显存会更省

 

但实际情况是:

  • 推理 batch 往往更大

  • sequence length 往往更长

  • KV cache 占用持续存在

 

于是你会看到:

 

训练能跑,评估反而 OOM。

 

这不是 bug,

而是你在评估阶段:

 

把 batch × length 推到了另一个非线性区域。

 

一个非常真实的“显存误判路径”

 


我算过参数显存 → 应该够

我减过 batch → 应该稳

我只加了点 length → 应该没事

OOM

 

注意:

每一步判断,单独看都“合理”。

 

错的是:

你在用线性思维,面对非线性系统。

 

那工程上该怎么“正确理解” batch 和 length?

不是给你一个公式,

而是给你一个更安全的判断方式

 

**sequence length 决定了“单样本的重量”,

batch size 决定了“同时搬多少个”。**

 

当你不知道哪里该省的时候,优先问:

  • 单条样本,是不是太重了?

  • attention 的 L² 是否已经不可接受?

 

很多时候:

 

减 length,比减 batch 更有效。

 

一个非常实用的显存自检问题

在你准备调 batch 或 length 之前,可以问自己一句话:

 

**如果显存炸了,

我更希望模型“少看几条”,

还是“每条看短一点”?**

 

如果你无法回答,

说明你对当前显存结构还不够清楚。

 

很多团队在 batch size 和 sequence length 上反复试探显存上限,本质问题不是参数没算清,而是缺乏对“中间态显存结构”的直观感知。用LLaMA-Factory online观察不同 batch / length 配置下的训练行为,更容易理解:是哪一部分在非线性放大,而不是盲目试错。

 

总结:显存不是被“用完”的,而是被“触发”的

我用一句话,把这篇文章彻底收住:

 

**batch size 和 sequence length

并不是慢慢吃掉显存的,

而是在某个点上,

一起把你推下悬崖。**

 

当你开始:

  • 把显存理解成结构问题

  • 把 length 当成一等公民

  • 放弃“线性估算”的安全感

 

你才真正开始工程化地调显存