在大语言模型(LLM)后训练中,SFT(Supervised Fine-Tuning) 是最常用的手段:简单、高效、快速模仿专家数据。然而,很多人发现 SFT 泛化能力差,特别是在遇到数据分布变化或推理难度高的任务时,模型容易过拟合。
最近的一篇论文 [On the Generalization of SFT: A Reinforcement Learning Perspective with Reward Rectification] 提出了一种极简的改进方法 DFT(Dynamic Fine-Tuning) ,只需 一行代码,就能在多种任务上显著提升泛化性能,甚至在某些场景下超过复杂的 RL 方法。
本文将带你深入了解:
- 论文的主要内容
- 核心创新点与关键技术
- 适用的实际应用场景
- 一个最小可运行 Demo
一、论文主要内容
作者从 强化学习(RL) 的角度重新审视了 SFT。通过数学推导,他们发现:
-
SFT 梯度等价于带隐式奖励的策略梯度(policy gradient)*
这个隐式奖励是:
[
r(x, y) = \frac{\mathbf{1}[y = y^*]}{\pi_\theta(y|x)}
]
- 只有完全匹配专家答案才奖励(奖励稀疏)
- 奖励值与模型预测概率成反比(低概率 token 的梯度被放大)
-
问题在于:
- 当模型对专家答案赋低概率时,梯度方差会变得极大
- 容易对少数罕见 token 过拟合
- 泛化能力下降,特别是在难数据集或分布偏移时
二、论文提出的创新点
作者针对这个隐式奖励问题,提出了 Dynamic Fine-Tuning (DFT) :
-
核心思想*
在计算 SFT 损失时,用当前 token 的预测概率 \pi_\theta 作为缩放因子,并阻断梯度流:
[
L_{\text{DFT}} = - \text{sg}(\pi_\theta(y^\t)) \cdot \log \pi\theta(y^_t)
]
这里 sg() 表示 stop-gradient。
-
作用
- 把原本的逆概率加权(1/π)抵消掉
- 让奖励在专家 token 上均匀分布(常数 1)
- 稳定梯度更新,减少过拟合
-
实现难度*
只需一行代码改动,计算量几乎不变。
三、关键技术亮点
-
理论贡献
- 精确建立了 SFT 梯度与策略梯度的数学等价关系
- 揭示了逆概率加权是泛化差的根源
-
方法优势
- 代码简单(只改损失函数)
- 不需要奖励模型、负样本或在线交互
- 在难数据集上避免性能下降
-
实验结果
- 数学推理任务上,DFT 显著优于 SFT
- 离线 RL 场景中,DFT 甚至超过 PPO、GRPO 等复杂方法
- 收敛更快、早期性能更高
四、实际应用场景
DFT 适合 只有正样本(专家演示)的微调任务,尤其在以下场景效果突出:
-
领域适配
- 法律、金融、医疗等专用问答
- 学科专用 LLM(数学、物理、化学)
-
数据有限任务
- 小语种翻译 / 对话
- 垂直领域客服机器人
-
复杂推理
- 链式思维(CoT)数学/逻辑推理
- 长文档总结、代码生成
-
RLHF 替代方案
- 没有奖励模型或在线交互条件的情况
- 企业内部知识库适配
-
多模态正样本微调
- 图文对齐、视觉问答、OCR 等
五、最小可运行 Demo
下面给出一个可运行的 PyTorch / Hugging Face 示例,把标准 SFT 损失改成论文里的 DFT:
# demo_dft.py
# pip install transformers torch
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, AdamW
device = "cuda" if torch.cuda.is_available() else "cpu"
# 1) 载入小模型与分词器
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
model.train()
# 2) 示例数据(真实场景换成专家数据)
texts = [
"Q: 2+2 = ? A: 4",
"Q: 3*5 = ? A: 15",
"Q: integrate x dx = ? A: 1/2 x^2 + C",
]
enc = tokenizer(texts, return_tensors="pt", padding=True)
input_ids = enc.input_ids.to(device)
attention_mask = enc.attention_mask.to(device)
labels = input_ids.clone()
optimizer = AdamW(model.parameters(), lr=5e-5)
# 3) DFT 损失函数
def compute_dft_loss(logits, labels, attention_mask):
log_probs = F.log_softmax(logits, dim=-1)
labels_expanded = labels.unsqueeze(-1)
token_logp = torch.gather(log_probs, dim=-1, index=labels_expanded).squeeze(-1)
token_p = token_logp.exp()
token_weights = token_p.detach() # stop-gradient
token_loss = - token_weights * token_logp
mask = attention_mask.float()
token_loss = token_loss * mask
return token_loss.sum() / mask.sum().clamp_min(1.0)
# 4) 单步训练
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
loss = compute_dft_loss(outputs.logits, labels, attention_mask)
print("DFT loss:", loss.item())
loss.backward()
optimizer.step()
optimizer.zero_grad()
关键一行:
token_weights = token_p.detach()
这就是论文的核心改动,抵消了 SFT 的逆概率加权,让奖励均匀分布。
六、总结
DFT 的意义在于:
它用一个极小的改动,解决了 SFT 泛化差的理论根源,并在实际任务中获得了大幅性能提升。
这不仅让我们更好地理解了 SFT 与 RL 的关系,也为资源有限、数据有限的微调任务提供了一个高性价比的替代方案。
如果你正在做只有正样本的 LLM 微调,不妨试试这个“一行代码”的改进,也许会让你的模型在新任务上表现惊喜。