解读《Learning Dynamics in Continual Pre-Training for Large Language Models》:持续预训练的

380 阅读4分钟

大型语言模型的持续预训练(Continual Pre-Training,简称 CPT)已成为一种流行而有效的展开方式,它允许我们在保持 LLM 通用能力的同时,展示出在特定下水领域上的强大表现。然而,如何量化地理解 CPT 过程中模型性能的变化,仍然是当前社区的难点。

这篇 ICML 2025 官方收录的论文《Learning Dynamics in Continual Pre-Training for Large Language Models》就是对这个问题的一次精精有声的解答。文章揭示了 CPT 过程中性能的一系列关键规律,并提出一套新的「损失缩放规律(scaling law)」用于量化地预测不同训练步骤下的性能变化。


一、论文概述

论文主要有两个研究目标:

  1. 能否找到一个包含 CPT 各项关键资源的规律全方程式?
  2. 能否控制 CPT 全过程中性能的时序性变化,而不仅仅是最终性能?

为了答复这两个问题,论文分析了 CPT 过程中两个关键影响因素:

  • 分布偏移值 (Distribution Shift) :从 Dpt (通用预训练数据) 转向 Dcpt (下水领域数据) 时,系统性地会导致性能均值的偏移,即一种 "损失折线"。
  • 学习率退灯 (Learning Rate Annealing) :学习率的下降会导致 loss 有折线性下降。

论文结合这两者,提出了一套 CPT 损失缩放规律:

L(t)=L0+A(S1)αC1S2(pt)C2S2(cpt)+B(1(1+ES1(cpt))β)L(t) = L_0 + A(S_1)^{-\alpha} - C_1 S_2^{(pt)} - C_2 S_2^{(cpt)} + B (1 - (1 + E S_1^{(cpt)})^{-\beta})


二、论文的创新点和关键技术

1. 提出 CPT 学习动态缩放规律 (Scaling Law)

  • 量化地表示 CPT 过程中的 loss 变化
  • 分解为“退灯规律 + 分布偏移规律”两部分
  • 分布偏移首次使用平积式 (power law)

2. 定义「隐藏预训练曲线 (Hidden PT Curve)」概念

  • CPT 实际 loss 曲线 = 从 Dpt 继续训练的隐藏曲线 + 偏移值

3. 导入「Loss Potential」概念

  • 衡量一个预训练模型在 CPT 中继续降低 loss 的能力
  • 指定 loss potential 有助于 CPT 中的定点开始

4. 可预测 OOD (预训外) 数据集表现

  • 分析表明:

    LOOD=λ1LDpt+λ2LDcptL_{OOD} = \lambda_1 L_{Dpt} + \lambda_2 L_{Dcpt}

  • 可用简单线性组合预测未要实际测试的零样本值

5. 支持开源模型 CPT 场景

  • 不需知道开源模型的预训练数据,可用通用数据作为 proxy Dpt
  • 将未知值看作拟合参数进行拟合

三、实际应用场景

场景 1:企业部署基于领域数据的 LLM

  • 例如:金融、医疗、法律
  • 用户需要培养领域能力,但不想损失基础能力
  • 利用 scaling law 定量控制 CPT 步数 / 重播比 / 学习率

场景 2:开源模型的内部应用

  • 重新培养 LLaMA / Mistral / Baichuan 等
  • 不知预训详细信息,但需 CPT
  • 用 loss potential + proxy Dpt 进行拟合和调整

场景 3:自动化调参 + 训练监控

  • 利用损失预测形成自动调参系统
  • 避免暴力网格搜索,减少训练成本

场景 4:预测未知预训表现 (OOD)

  • 较难获得大量验证数据,但想知道不同模型的表现
  • 通过 OOD = Dpt + Dcpt 的线性衡量预测 loss

四、最小可运行示例代码

以 Python 实现 CPT loss 曲线的简单预模型:

import numpy as np
import matplotlib.pyplot as plt

# 参数
L0 = 3.0
A = 0.5
alpha = 0.5
C1 = 0.3
C2 = 0.3
B = 0.5
E = 0.001
beta = 0.2

# 学习率(cosine 规律)
def lr_schedule(step, total_steps, lr_max=1e-4):
    return lr_max * 0.5 * (1 + np.cos(np.pi * step / total_steps))

def compute_loss(total_steps):
    Spt1 = 0.0
    Spt2 = 0.0
    Scpt1 = 0.0
    Scpt2 = 0.0
    losses = []
    eta_prev = 0

    for t in range(1, total_steps + 1):
        eta = lr_schedule(t, total_steps)
        Spt1 += eta
        Spt2 += (eta_prev - eta) * t
        eta_prev = eta

        base_loss = L0 + A * (Spt1 + Scpt1) ** (-alpha) - C1 * Spt2 - C2 * Scpt2
        shift_loss = B * (1 - (1 + E * Scpt1) ** (-beta))

        losses.append(base_loss + shift_loss)
        Scpt1 += eta
        Scpt2 += (eta_prev - eta) * t

    return losses

steps = 60000
loss_curve = compute_loss(steps)
plt.plot(loss_curve)
plt.xlabel("Step")
plt.ylabel("Validation Loss")
plt.title("Simulated CPT Loss Curve")
plt.grid(True)
plt.show()

结论

论文《Learning Dynamics in Continual Pre-Training for LLMs》揭示了 CPT 过程中的学习动态,并通过数学形式完整描述 loss 的变化轨迹。其提出的 scaling law 不仅可用于理论分析,更具有强大的实际应用价值:用于调参,预测,衡量 CPT 效果,甚至支持不完整信息下的