用一行代码改进 SFT 泛化能力:DFT(Dynamic Fine-Tuning)详解

364 阅读4分钟

在大语言模型(LLM)后训练中,SFT(Supervised Fine-Tuning) 是最常用的手段:简单、高效、快速模仿专家数据。然而,很多人发现 SFT 泛化能力差,特别是在遇到数据分布变化或推理难度高的任务时,模型容易过拟合。

最近的一篇论文 [On the Generalization of SFT: A Reinforcement Learning Perspective with Reward Rectification] 提出了一种极简的改进方法 DFT(Dynamic Fine-Tuning) ,只需 一行代码,就能在多种任务上显著提升泛化性能,甚至在某些场景下超过复杂的 RL 方法。

本文将带你深入了解:

  1. 论文的主要内容
  2. 核心创新点与关键技术
  3. 适用的实际应用场景
  4. 一个最小可运行 Demo

一、论文主要内容

作者从 强化学习(RL) 的角度重新审视了 SFT。通过数学推导,他们发现:

  • SFT 梯度等价于带隐式奖励的策略梯度(policy gradient)*

    这个隐式奖励是:

    [

    r(x, y) = \frac{\mathbf{1}[y = y^*]}{\pi_\theta(y|x)}

    ]

    • 只有完全匹配专家答案才奖励(奖励稀疏)
    • 奖励值与模型预测概率成反比(低概率 token 的梯度被放大)
  • 问题在于:

    • 当模型对专家答案赋低概率时,梯度方差会变得极大
    • 容易对少数罕见 token 过拟合
    • 泛化能力下降,特别是在难数据集或分布偏移时

二、论文提出的创新点

作者针对这个隐式奖励问题,提出了 Dynamic Fine-Tuning (DFT)

  • 核心思想*

    在计算 SFT 损失时,用当前 token 的预测概率 \pi_\theta 作为缩放因子,并阻断梯度流:

    [

    L_{\text{DFT}} = - \text{sg}(\pi_\theta(y^\t)) \cdot \log \pi\theta(y^_t)

    ]

    这里 sg() 表示 stop-gradient。

  • 作用

    • 把原本的逆概率加权(1/π)抵消掉
    • 让奖励在专家 token 上均匀分布(常数 1)
    • 稳定梯度更新,减少过拟合
  • 实现难度*

    只需一行代码改动,计算量几乎不变。


三、关键技术亮点

  1. 理论贡献

    • 精确建立了 SFT 梯度与策略梯度的数学等价关系
    • 揭示了逆概率加权是泛化差的根源
  2. 方法优势

    • 代码简单(只改损失函数)
    • 不需要奖励模型、负样本或在线交互
    • 在难数据集上避免性能下降
  3. 实验结果

    • 数学推理任务上,DFT 显著优于 SFT
    • 离线 RL 场景中,DFT 甚至超过 PPO、GRPO 等复杂方法
    • 收敛更快、早期性能更高

四、实际应用场景

DFT 适合 只有正样本(专家演示)的微调任务,尤其在以下场景效果突出:

  1. 领域适配

    • 法律、金融、医疗等专用问答
    • 学科专用 LLM(数学、物理、化学)
  2. 数据有限任务

    • 小语种翻译 / 对话
    • 垂直领域客服机器人
  3. 复杂推理

    • 链式思维(CoT)数学/逻辑推理
    • 长文档总结、代码生成
  4. RLHF 替代方案

    • 没有奖励模型或在线交互条件的情况
    • 企业内部知识库适配
  5. 多模态正样本微调

    • 图文对齐、视觉问答、OCR 等

五、最小可运行 Demo

下面给出一个可运行的 PyTorch / Hugging Face 示例,把标准 SFT 损失改成论文里的 DFT:

# demo_dft.py
# pip install transformers torch

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, AdamW

device = "cuda" if torch.cuda.is_available() else "cpu"

# 1) 载入小模型与分词器
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
model.train()

# 2) 示例数据(真实场景换成专家数据)
texts = [
    "Q: 2+2 = ? A: 4",
    "Q: 3*5 = ? A: 15",
    "Q: integrate x dx = ? A: 1/2 x^2 + C",
]
enc = tokenizer(texts, return_tensors="pt", padding=True)
input_ids = enc.input_ids.to(device)
attention_mask = enc.attention_mask.to(device)
labels = input_ids.clone()

optimizer = AdamW(model.parameters(), lr=5e-5)

# 3) DFT 损失函数
def compute_dft_loss(logits, labels, attention_mask):
    log_probs = F.log_softmax(logits, dim=-1)
    labels_expanded = labels.unsqueeze(-1)
    token_logp = torch.gather(log_probs, dim=-1, index=labels_expanded).squeeze(-1)
    token_p = token_logp.exp()
    token_weights = token_p.detach()  # stop-gradient
    token_loss = - token_weights * token_logp
    mask = attention_mask.float()
    token_loss = token_loss * mask
    return token_loss.sum() / mask.sum().clamp_min(1.0)

# 4) 单步训练
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
loss = compute_dft_loss(outputs.logits, labels, attention_mask)
print("DFT loss:", loss.item())
loss.backward()
optimizer.step()
optimizer.zero_grad()

关键一行:

token_weights = token_p.detach()

这就是论文的核心改动,抵消了 SFT 的逆概率加权,让奖励均匀分布。


六、总结

DFT 的意义在于:

它用一个极小的改动,解决了 SFT 泛化差的理论根源,并在实际任务中获得了大幅性能提升。

这不仅让我们更好地理解了 SFT 与 RL 的关系,也为资源有限、数据有限的微调任务提供了一个高性价比的替代方案。


如果你正在做只有正样本的 LLM 微调,不妨试试这个“一行代码”的改进,也许会让你的模型在新任务上表现惊喜。