10-让模型更小更聪明,学而不忘:知识蒸馏与持续学习

8 阅读6分钟

开篇:两个核心问题

在大模型的实际应用中,我们常常面临两个看似矛盾的需求:

问题1:模型太大怎么办?

GPT-4: 1.76万亿参数  推理成本高昂
Llama 3 70B: 70B参数  需要多张A100
  
能否压缩到更小的模型,但保持性能?

示例

场景问题
边缘设备部署手机/IoT设备无法运行70B模型
实时响应大模型推理延迟高(秒级)
成本控制API调用费用过高

解决方案知识蒸馏(Knowledge Distillation) - 用小模型学习大模型的知识


问题2:模型会"遗忘"怎么办?

场景:持续学习新任务
  任务A(医疗问答) → 训练后效果好
  任务B(法律咨询) → 训练后效果好,但任务A性能下降!
  任务C(代码生成) → 训练后效果好,但任务AB都变差!

示例

训练阶段医疗问答准确率法律咨询准确率代码生成准确率
只训练医疗90%--
加入法律训练65% ⚠️88%-
加入代码训练45% ⚠️⚠️62% ⚠️85%

这就是灾难性遗忘(Catastrophic Forgetting)

解决方案回放机制(Experience Replay) - 让模型记住旧知识


第一部分:知识蒸馏(Knowledge Distillation)

什么是知识蒸馏?

核心思想:让小模型(学生)学习大模型(教师)的"思考过程",而不仅仅是最终答案。

类比:学习数学

传统训练(学标准答案):
  题目:"2 + 2 = ?"
  标签:4
  学生学到:2 + 2 = 4

知识蒸馏(学老师的思路):
  题目:"2 + 2 = ?"
  老师的想法:
    - 4的概率:95%(最可能)
    - 3的概率:3%(也有点像)
    - 5的概率:2%(数字接近)
    - 其他:0.1%
  学生学到:不仅是答案,还有"为什么其他答案不对"的信息

关键洞察:大模型的**软标签(Soft Label)**包含了比硬标签(Hard Label)更丰富的信息。

蒸馏的基本原理

硬标签 vs 软标签

硬标签(One-hot)

# 问题:"天空是什么颜色?"
hard_label = {
    "蓝色": 1.0,
    "红色": 0.0,
    "绿色": 0.0,
    "黑色": 0.0
}

软标签(概率分布)

# 大模型的输出
soft_label = {
    "蓝色": 0.85,  # 大多数情况
    "黑色": 0.10,  # 夜晚的天空
    "红色": 0.03,  # 日出/日落
    "灰色": 0.02   # 阴天
}

优势:软标签包含了上下文信息不确定性

深入理解:LLM中的硬标签vs软标签

上面的例子是分类任务,但LLM是生成任务,硬标签和软标签具体是什么样子呢?

关键理解:LLM的每一步生成本质上是分类任务

虽然LLM看起来是"生成"任务,但每预测一个词,本质上是在整个词表(通常50,000个词)上做分类:

# LLM的生成过程
输入:"今天天气"
↓
需要预测:下一个词是什么?
↓
这是一个 50,000 分类问题(从词表中选一个词)

具体例子:预测下一个词

场景:输入 "今天天气",预测下一个词

硬标签(传统训练)

# 训练数据
输入:  "今天天气"
标签:  "很好"  # Token ID: 1234

# One-hot 硬标签(词表大小=50000)
hard_label = [0, 0, 0, ..., 1, ..., 0, 0]
                           ↑
                      位置12341
                      其他49999个位置都是0

# 损失函数
loss = CrossEntropy(model_output, hard_label)
# 只关心位置1234的预测概率

问题

  • ❌ 只告诉模型"正确答案是'很好'"
  • ❌ 没有告诉模型为什么"不错"也可以(同义)
  • ❌ 没有告诉模型为什么"真好"也合理(相似表达)
  • ❌ 没有告诉模型为什么"糟糕"不对(反义)

软标签(教师模型的输出)

# 大模型(教师)的实际输出
输入: "今天天气"

# 教师模型的 logits(未归一化的分数)
teacher_logits = {
    "很好":  8.5,
    "不错":  8.2,
    "真好":  7.8,
    "挺好":  7.5,
    "还行":  6.2,
    "一般":  4.5,
    "不好":  2.1,
    "糟糕":  1.3,
    ... # 其他49992个词
}

# 经过 softmax(T=1) 得到概率(软标签)
soft_label = {
    "很好":  0.52,  # 最可能
    "不错":  0.31,  # 也很可能
    "真好":  0.10,  # 还算可能
    "挺好":  0.04,  # 有点可能
    "还行":  0.02,  # 稍微可能
    "一般":  0.008, # 不太可能
    "不好":  0.001, # 很不可能
    "糟糕":  0.0005,# 几乎不可能
    其他词:   0.0015 # 极不可能
}

丰富的信息

  1. ✅ "很好"是最佳答案(52%)
  2. ✅ "不错"也很合理(31%)— 同义词信息
  3. ✅ "真好"、"挺好"可以接受 — 相似表达
  4. ✅ "一般"不太好但不是完全错误 — 中性词
  5. ✅ "糟糕"几乎不可能 — 反义词

完整的生成过程对比

让我们看一个完整句子的生成:

输入:"写一首关于春天的诗"

传统训练(硬标签)

步骤1:
输入: "写一首关于春天的诗\n"
硬标签:"春"
模型学习:第一个字必须是"春" ✗

步骤2:
输入: "写一首关于春天的诗\n春"
硬标签:"风"
模型学习:第二个字必须是"风" ✗

步骤3:
输入: "写一首关于春天的诗\n春风"
硬标签:"拂"
模型学习:第三个字必须是"拂" ✗

结果:每一步只知道"正确答案",不知道其他选项为什么不对
→ 缺乏灵活性,容易过拟合

知识蒸馏(软标签)

步骤1:
输入: "写一首关于春天的诗\n"

教师模型的软标签(概率分布):
{
    "春": 0.45,  # 最常见的开头
    "暖": 0.15,  # 也可以
    "阳": 0.10,  # 比较常见
    "万": 0.08,  # "万物复苏"
    "东": 0.05,  # "东风"
    "柳": 0.03,  # 也是春天的意象
    其他: 0.14
}

学生模型学习到:
✓ "春"字最好(直接点题)
✓ "暖""阳"也不错(温暖的意象)
✓ "万"可以(万物复苏)
✓ 其他字可能性很低

步骤2:
输入: "写一首关于春天的诗\n春"

教师模型的软标签:
{
    "风": 0.35,  # "春风"搭配
    "光": 0.20,  # "春光"
    "雨": 0.15,  # "春雨"
    "日": 0.10,  # "春日"
    "意": 0.08,  # "春意"
    "花": 0.05,  # "春花"
    其他: 0.07
}

学生模型学习到:
✓ "风"最佳(春风是经典搭配)
✓ "光""雨""日"都是好的选择
✓ 这些词都和"春"搭配良好
✓ 词语搭配的概率分布

结果:学生模型理解了多种可能性,生成更灵活、自然

数学形式对比

硬标签 Loss

Lhard=logPstudent(ytruex)L_{\text{hard}} = -\log P_{\text{student}}(y_{\text{true}} | x)

只优化"正确答案"的概率。

软标签 Loss(蒸馏)

Lsoft=i=1VPteacher(yix)logPstudent(yix)L_{\text{soft}} = -\sum_{i=1}^{V} P_{\text{teacher}}(y_i | x) \log P_{\text{student}}(y_i | x)

其中:

  • VV:词表大小(如50,000)
  • Pteacher(yix)P_{\text{teacher}}(y_i | x):教师模型对第ii个词的预测概率
  • Pstudent(yix)P_{\text{student}}(y_i | x):学生模型对第ii个词的预测概率

含义:学生模型要学习教师模型的整个概率分布(50,000个词的概率),而不仅仅是最高概率的那个词。

实际案例:翻译任务

任务:翻译 "The weather is nice today"

硬标签训练的模型

训练数据只有一个答案:
"The weather is nice today""今天天气很好"

模型学到的:
步骤1: "The""今" (概率1.0)
步骤2: "weather""天" (概率1.0)
...

结果:模型只会生成这一种翻译 ✗
缺乏灵活性,不能处理同义表达

软标签训练的模型

教师模型(GPT-4)的输出分布:

步骤1: "The" →
{
    "今": 0.60,    # 最常见
    "今日": 0.25,  # 也可以
    "今儿": 0.08,  # 口语化
    其他: 0.07
}

步骤4: "nice" →
{
    "很": 0.35,  # "很好"
    "不": 0.30,  # "不错"
    "真": 0.15,  # "真好"
    "挺": 0.10,  # "挺好"
    其他: 0.10
}

学生模型学到:
✓ 多种表达方式的概率
✓ "今天天气很好" (最常见)
✓ "今天天气不错" (也很好)
✓ "今日天气真好" (稍正式)
✓ 不同表达的合理性 ✓

代码实现对比

import torch
import torch.nn.functional as F

vocab_size = 50000  # 词表大小

# 输入:"今天天气",预测下一个词
input_text = "今天天气"

# ===== 硬标签训练 =====
# 真实标签:"很好" (token_id = 1234)
hard_label = torch.zeros(vocab_size)
hard_label[1234] = 1.0  # one-hot

student_logits = student_model(input_text)
loss_hard = F.cross_entropy(
    student_logits.unsqueeze(0),
    torch.tensor([1234])
)
# 只关心位置1234的概率 ✗

# ===== 软标签训练(蒸馏)=====
with torch.no_grad():
    teacher_logits = teacher_model(input_text)

# 使用温度软化
temperature = 2.0
teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
# teacher_probs 形状: [50000]
# 例如:[0.0001, 0.0002, ..., 0.52(1234), 0.31(5678), ...]
#  ↑ 所有50000个位置都有概率值!✓

student_probs_soft = F.softmax(
    student_logits / temperature,
    dim=-1
)

# 软标签损失(KL散度)
loss_soft = F.kl_div(
    student_probs_soft.log(),
    teacher_probs,
    reduction='batchmean'
) * (temperature ** 2)

# 最终损失:软标签 + 硬标签
loss = 0.7 * loss_soft + 0.3 * loss_hard

软标签的实际价值总结

在LLM中,软标签传递了:

  1. 同义词信息

    • "很好"、"不错"、"真好"都是合理答案
    • 概率反映了它们的相似度
  2. 词语搭配知识

    • "春"后面接"风"、"光"、"雨"都合理
    • 概率反映了搭配的常见程度
  3. 上下文理解

    • 不同上下文下,同一个位置的词分布不同
    • 教师模型的理解被传递给学生
  4. 生成多样性

    • 学生知道多个选择都可行
    • 不会过拟合到单一答案

类比总结

硬标签:
告诉你 "2+2=4"
→ 只知道答案是4

软标签:
告诉你 "2+2=4(95%), 也可能是3.9(3%)或4.1(2%),
        但绝不是10(0.001%)"
→ 理解了整个数值空间的合理性

这就是为什么知识蒸馏在LLM中如此有效——软标签包含了教师模型对整个词表空间的深刻理解!

温度(Temperature)

为了让软标签更"软",使用温度参数 TT

softmax(zi,T)=exp(zi/T)jexp(zj/T)\text{softmax}(z_i, T) = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}

效果

温度 TT效果分布特征
T=1T = 1标准softmax正常分布
T>1T > 1软化分布更平滑,信息更丰富
TT \to \infty均匀分布所有类别概率接近
T<1T < 1锐化分布接近one-hot

示例

import torch
import torch.nn.functional as F

logits = torch.tensor([4.0, 2.0, 1.0, 0.1])

# T=1: 标准softmax
p_t1 = F.softmax(logits, dim=0)
print(f"T=1:   {p_t1}")
# 输出: [0.8360, 0.1131, 0.0416, 0.0092]

# T=3: 软化
p_t3 = F.softmax(logits / 3, dim=0)
print(f"T=3:   {p_t3}")
# 输出: [0.5021, 0.2447, 0.1512, 0.1020]

# T=10: 更软
p_t10 = F.softmax(logits / 10, dim=0)
print(f"T=10:  {p_t10}")
# 输出: [0.3174, 0.2631, 0.2370, 0.1825]

观察:温度越高,概率分布越平滑,包含的"暗知识(Dark Knowledge)"越多。


方法1:Logits蒸馏(经典方法)

原理

学生模型同时学习两个目标

  1. 与真实标签的匹配(硬标签损失)
  2. 与教师模型的匹配(软标签损失)

数据流可视化

输入:"这部电影还不错"
    |
    |  (同时输入两个模型)
    |
    ├──────────────────────┬──────────────────────┐
    ↓                      ↓                      ↓
教师模型                学生模型              真实标签
[4.5, 0.3, 1.2]       [3.8, 0.8, 1.0]        "正面"
    |                      |                      |
    | T=3                  | T=3                  |
    ↓                      ↓                      |
[0.65, 0.15, 0.20]    [0.55, 0.22, 0.23]        |
  (软标签)               (学生软预测)            |
    |                      |                      |
    └──────KL散度─────────┘                      |
          ↓                                       |
      软标签损失                                  |
       = 0.042                                    |
                                                  |
                            学生logits            |
                          [3.8, 0.8, 1.0]         |
                                |   T=1           |
                                ↓                 |
                          [0.85, 0.04, 0.11]      |
                                |                 |
                                └──交叉熵─────────┘
                                      ↓
                                  硬标签损失
                                   = 0.163
                                      |
        ┌─────────────────────────────┴─────────────────────────────┐
        ↓                                                             ↓
    α * T² * 软标签损失                              (1-α) * 硬标签损失
  = 0.7 * 9 * 0.042                                = 0.3 * 0.163
  = 0.265                                          = 0.049
        |                                                             |
        └─────────────────────────┬───────────────────────────────┘
                                  ↓
                            总损失 = 0.314
                                  ↓
                            反向传播,更新学生模型

关键理解

  • 学生模型只前向传播一次,得到一个logits输出
  • 这个输出参与两个损失计算
    • 软化后(T=3)与教师的软化输出比较
    • 标准化(T=1)与真实标签比较
  • 两个损失加权求和,共同指导学生学习

损失函数

LKD=αLsoft+(1α)Lhard\mathcal{L}_{\text{KD}} = \alpha \cdot \mathcal{L}_{\text{soft}} + (1-\alpha) \cdot \mathcal{L}_{\text{hard}}

其中:

软标签损失

Lsoft=T2KL(softmax(zST)softmax(zTT))\mathcal{L}_{\text{soft}} = T^2 \cdot \text{KL}\left( \text{softmax}\left(\frac{z^S}{T}\right) \, \Big\| \, \text{softmax}\left(\frac{z^T}{T}\right) \right)

硬标签损失

Lhard=CrossEntropy(zS,y)\mathcal{L}_{\text{hard}} = \text{CrossEntropy}(z^S, y)

参数:

  • zTz^T:教师模型的logits
  • zSz^S:学生模型的logits
  • yy:真实标签
  • TT:温度(通常2-10)
  • α\alpha:平衡系数(通常0.7-0.9)
  • T2T^2:补偿温度导致的梯度缩放

完整示例:理解损失计算

关键点:学生模型只有一个输出,但用两种方式评估

假设一个情感分类任务(正面/负面/中性),输入:"这部电影还不错"

# ===== 第1步:教师模型推理 =====
teacher_logits = teacher_model("这部电影还不错")
# 输出 logits(未归一化的分数):
# [正面: 4.5, 负面: 0.3, 中性: 1.2]

# 教师的概率分布(T=1):
teacher_probs_T1 = softmax(teacher_logits / 1)
# [正面: 0.91, 负面: 0.01, 中性: 0.08]
# → 教师非常确信是"正面"

# 教师的软化概率(T=3,用于蒸馏):
teacher_probs_T3 = softmax(teacher_logits / 3)
# [正面: 0.65, 负面: 0.15, 中性: 0.20]
# → 软化后,"中性"和"负面"的概率提高了,包含更多信息


# ===== 第2步:学生模型推理(同一个输入) =====
student_logits = student_model("这部电影还不错")
# 输出 logits:
# [正面: 3.8, 负面: 0.8, 中性: 1.0]

# 学生的概率分布(T=1):
student_probs_T1 = softmax(student_logits / 1)
# [正面: 0.85, 负面: 0.04, 中性: 0.11]

# 学生的软化概率(T=3):
student_probs_T3 = softmax(student_logits / 3)
# [正面: 0.55, 负面: 0.22, 中性: 0.23]


# ===== 第3步:计算两个损失 =====

# 损失1:硬标签损失(与真实标签比较)
true_label = "正面"  # 人类标注的真实答案
hard_loss = CrossEntropy(student_probs_T1, true_label)
# = -log(0.85) = 0.163
# 评估:学生对正确答案(正面)的预测准确性

# 损失2:软标签损失(与教师的思考过程比较)
soft_loss = KL_Divergence(student_probs_T3, teacher_probs_T3)
# = 0.65*log(0.65/0.55) + 0.15*log(0.15/0.22) + 0.20*log(0.20/0.23)
# ≈ 0.042
# 评估:学生的思考过程与教师的相似度

# 损失3:总损失(加权组合)
alpha = 0.7  # 软标签权重
total_loss = alpha * (3^2) * soft_loss + (1 - alpha) * hard_loss
#          = 0.7 * 9 * 0.042 + 0.3 * 0.163
#          = 0.265 + 0.049
#          = 0.314

重点理解

  1. 学生模型只输出一次student_logits = [3.8, 0.8, 1.0]

  2. 这一个输出被两种方式使用

    • T=1 的概率与真实标签比较 → 硬标签损失
    • T=3 的概率与教师比较 → 软标签损失
  3. 为什么两个都需要?

    只用硬标签损失:
      学生只学会"这是正面评价"(结果)
      ✗ 不知道为什么不是中性或负面
    
    只用软标签损失:
      学生学会了教师的思考模式(过程)
      ✗ 但可能偏离真实标签(如果教师也会犯错)
    
    两者结合:
      ✓ 既学会了正确答案(硬标签)
      ✓ 又学会了思考过程(软标签)
    
  4. 软标签包含的"额外信息"

    硬标签告诉学生:
      "答案是正面" [1, 0, 0]
    
    软标签告诉学生:
      "正面可能性65%,但也有20%可能是中性(因为'还不错'
       比较温和),15%可能有些负面情绪"
      [0.65, 0.15, 0.20]
    
    学生学到:
      - 主要特征:积极词汇 → 正面
      - 细微差异:"还不错""很棒"更温和 → 也有中性成分
      - 边界情况:如何区分"温和正面""中性"
    

常见疑问解答

Q1: 为什么不只用软标签损失?教师已经学到了正确答案。

A: 因为教师也可能犯错,或者在某些样本上不够自信。硬标签提供了"ground truth",确保学生不会被教师的错误误导。

# 例子:教师在困难样本上可能不确定
教师预测: [正面: 0.48, 负面: 0.52]  # 教师认为略偏负面
真实标签: "正面"                    # 实际是正面

只用软标签 → 学生学到"这是负面的"(错误)
加上硬标签 → 学生知道真实答案是正面,但也学到这是个"接近边界"的案例

Q2: 为什么不只用硬标签损失?传统训练不就是这样吗?

A: 硬标签只提供了"对/错"的信息,丢失了很多细节:

硬标签: [1, 0, 0]  # 只知道第一个类是对的
软标签: [0.85, 0.10, 0.05]  # 知道第一个类最可能,第二个类也有点像,第三个类完全不像

学生从软标签学到:
- 类别之间的相似性(哪些类容易混淆)
- 决策边界的位置(多接近才算"像")
- 不确定性的估计(模型有多自信)

Q3: 学生模型到底输出什么?

A: 学生模型只输出一个东西:logits向量

# 完整流程
input = "这部电影还不错"

# 学生只做一次前向传播
student_logits = student_model(input)
# → [3.8, 0.8, 1.0]  (就这一个输出!)

# 然后这个输出被用于两个损失计算:
# 用法1:软化后与教师比较
student_soft = softmax(student_logits / 3)  # T=3
loss_soft = KL(student_soft, teacher_soft)

# 用法2:标准化后与标签比较
student_normal = softmax(student_logits / 1)  # T=1
loss_hard = CE(student_normal, true_label)

# 两个损失加权求和
total_loss = 0.7 * loss_soft + 0.3 * loss_hard

# 反向传播,更新 student_model 的参数
total_loss.backward()

Q4: 为什么软标签损失要乘以 T²?

A: 因为温度会缩放梯度,T² 是为了补偿:

# 当T增大时,softmax的输出变化变小
# 导致梯度也变小
# 乘以T²可以保持梯度的尺度与硬标签损失相当

# 数学上:
∂L_soft/∂z ∝ 1/T²  (温度缩放导致梯度变小)
L_soft × T² → 梯度恢复正常尺度

代码实现

import torch
import torch.nn as nn
import torch.nn.functional as F

class KnowledgeDistillationLoss(nn.Module):
    def __init__(self, temperature=3.0, alpha=0.7):
        """
        Args:
            temperature: 软化温度
            alpha: 软标签损失的权重(硬标签权重为 1-alpha)
        """
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.kl_div = nn.KLDivLoss(reduction='batchmean')
        self.ce = nn.CrossEntropyLoss()

    def forward(self, student_logits, teacher_logits, labels):
        """
        Args:
            student_logits: 学生模型输出 [batch, num_classes]
            teacher_logits: 教师模型输出 [batch, num_classes]
            labels: 真实标签 [batch]

        关键:student_logits是同一个输出,被两种方式使用:
            - 软化后与教师比较(学习思考过程)
            - 直接与真实标签比较(学习正确答案)
        """
        # 1. 软标签损失(KL散度)
        # 用高温softmax软化,让概率分布更平滑
        student_soft = F.log_softmax(student_logits / self.temperature, dim=1)
        teacher_soft = F.softmax(teacher_logits / self.temperature, dim=1)

        soft_loss = self.kl_div(student_soft, teacher_soft) * (self.temperature ** 2)
        # 目标:让学生的概率分布接近教师的概率分布

        # 2. 硬标签损失(交叉熵)
        # 用标准softmax(T=1),与真实标签比较
        hard_loss = self.ce(student_logits, labels)
        # 目标:让学生预测正确的类别

        # 3. 总损失(加权组合)
        loss = self.alpha * soft_loss + (1 - self.alpha) * hard_loss
        # alpha=0.7 表示:70%关注教师的思考过程,30%关注正确答案

        return loss, soft_loss, hard_loss

# 使用示例
teacher_model = load_large_model()  # 70B模型
student_model = create_small_model()  # 7B模型

kd_loss = KnowledgeDistillationLoss(temperature=3.0, alpha=0.8)
optimizer = torch.optim.AdamW(student_model.parameters(), lr=1e-5)

for batch in dataloader:
    inputs, labels = batch

    # 教师模型推理(不计算梯度)
    with torch.no_grad():
        teacher_logits = teacher_model(inputs)

    # 学生模型推理
    student_logits = student_model(inputs)

    # 计算蒸馏损失
    loss, soft_loss, hard_loss = kd_loss(
        student_logits, teacher_logits, labels
    )

    # 反向传播
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    print(f"Total: {loss:.4f}, Soft: {soft_loss:.4f}, Hard: {hard_loss:.4f}")

为什么有效?

信息量对比

硬标签:
  "北京是中国的首都" → 标签: 正确 (1.0)
  信息量:1 bit

软标签:
  "北京是中国的首都" → 教师模型输出:
    - 正确: 0.95
    - 错误但相关: 0.03"北京在中国")
    - 错误且无关: 0.02"北京不存在")
  信息量:~5 bits(更丰富)

学生学到

  • ✅ 北京确实是首都(主要信息)
  • ✅ 一些接近正确的表述也有道理(细微差异)
  • ✅ 某些表述明显错误(边界信息)

方法2:特征蒸馏(白盒方法)

原理

不仅学习最终输出,还学习中间层的特征表示

架构对比

Logits蒸馏:
  Teacher: InputTransformerLogitsStudent: InputTransformerLogits
           (只匹配最后输出)

特征蒸馏:
  Teacher: InputLayer1Layer2 → ... → Logits
                     ↓        ↓              ↓
  Student: InputLayer1Layer2 → ... → Logits
           (匹配多个中间层)

损失函数

LFeature=lLλlHlSProj(HlT)2\mathcal{L}_{\text{Feature}} = \sum_{l \in \mathcal{L}} \lambda_l \cdot \| H_l^S - \text{Proj}(H_l^T) \|^2

其中:

  • HlTH_l^T:教师模型第 ll 层的隐藏状态
  • HlSH_l^S:学生模型第 ll 层的隐藏状态
  • Proj\text{Proj}:投影层(因为教师和学生的维度可能不同)
  • L\mathcal{L}:选择的层(通常选几个关键层)

代码实现

class FeatureDistillationLoss(nn.Module):
    def __init__(self, teacher_dim, student_dim, num_layers=4):
        super().__init__()
        # 投影层:将教师特征投影到学生维度
        self.projections = nn.ModuleList([
            nn.Linear(teacher_dim, student_dim)
            for _ in range(num_layers)
        ])
        self.mse = nn.MSELoss()

    def forward(self, student_features, teacher_features):
        """
        Args:
            student_features: List of [batch, seq_len, student_dim]
            teacher_features: List of [batch, seq_len, teacher_dim]
        """
        total_loss = 0

        for i, (s_feat, t_feat) in enumerate(
            zip(student_features, teacher_features)
        ):
            # 投影教师特征
            t_feat_proj = self.projections[i](t_feat)

            # MSE损失
            loss = self.mse(s_feat, t_feat_proj)
            total_loss += loss

        return total_loss / len(student_features)

# 使用示例
class StudentModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([...])
        self.lm_head = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, return_features=False):
        features = []
        hidden = x

        for layer in self.layers:
            hidden = layer(hidden)
            if return_features:
                features.append(hidden)

        logits = self.lm_head(hidden)

        if return_features:
            return logits, features
        return logits

# 训练
feature_loss_fn = FeatureDistillationLoss(
    teacher_dim=4096,  # 教师隐藏维度
    student_dim=2048   # 学生隐藏维度
)

for batch in dataloader:
    inputs, labels = batch

    # 教师前向(获取中间特征)
    with torch.no_grad():
        teacher_logits, teacher_features = teacher_model(
            inputs, return_features=True
        )

    # 学生前向
    student_logits, student_features = student_model(
        inputs, return_features=True
    )

    # 特征蒸馏损失
    feature_loss = feature_loss_fn(student_features, teacher_features)

    # Logits蒸馏损失
    logit_loss = kd_loss(student_logits, teacher_logits, labels)

    # 总损失
    loss = logit_loss + 0.5 * feature_loss

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

优势

  • ✅ 学习更深层的知识表示
  • ✅ 更好的泛化能力
  • ✅ 训练更稳定

劣势

  • ⚠️ 需要访问教师模型内部(白盒)
  • ⚠️ 计算开销更大
  • ⚠️ 架构设计更复杂

方法3:响应蒸馏(黑盒方法)

原理

完全不需要教师模型的内部结构,只使用教师模型的文本输出。

适用场景

  • 教师模型是API(如GPT-4、Claude)
  • 教师模型不开源
  • 无法获取教师模型的logits

流程

步骤1:用教师模型生成高质量数据
  输入: "解释量子计算"
  教师输出: "量子计算是利用量子力学原理进行计算的技术..."

步骤2:学生模型学习模仿教师的输出
  训练数据: (输入, 教师输出)
  目标: 最小化学生输出与教师输出的差异

实现方法

方法A:直接监督学习

# 1. 收集教师响应
teacher_responses = []

for prompt in prompts:
    # 调用API
    response = teacher_model.generate(prompt)
    teacher_responses.append({
        "prompt": prompt,
        "response": response
    })

# 2. 用教师响应训练学生
for batch in create_dataloader(teacher_responses):
    prompts, responses = batch

    # 学生模型前向
    student_outputs = student_model(prompts)

    # 标准语言模型损失
    loss = cross_entropy_loss(student_outputs, responses)

    loss.backward()
    optimizer.step()

方法B:排序蒸馏(Ranking Distillation)

让学生学习教师对多个回答的偏好排序

class RankingDistillation:
    def __init__(self, teacher_model, student_model):
        self.teacher = teacher_model
        self.student = student_model

    def collect_ranked_data(self, prompts):
        dataset = []

        for prompt in prompts:
            # 生成多个候选回答
            candidates = []
            for _ in range(4):
                response = self.student.generate(prompt, temperature=0.8)
                candidates.append(response)

            # 教师评分(使用reward model或直接打分)
            scores = []
            for candidate in candidates:
                score = self.teacher.score(prompt, candidate)
                scores.append(score)

            # 排序
            ranked = sorted(
                zip(candidates, scores),
                key=lambda x: x[1],
                reverse=True
            )

            dataset.append({
                "prompt": prompt,
                "best": ranked[0][0],
                "worst": ranked[-1][0]
            })

        return dataset

    def train(self, dataset):
        for batch in dataset:
            prompt = batch["prompt"]
            best = batch["best"]
            worst = batch["worst"]

            # 学生模型的log概率
            logp_best = self.student.log_prob(prompt, best)
            logp_worst = self.student.log_prob(prompt, worst)

            # 排序损失(类似DPO)
            loss = -F.logsigmoid(logp_best - logp_worst).mean()

            loss.backward()
            optimizer.step()

优缺点

优点

  • ✅ 完全黑盒,无需访问教师内部
  • ✅ 可以利用API服务
  • ✅ 简单易实现

缺点

  • ❌ 信息损失大(只有文本,没有概率分布)
  • ❌ 需要大量生成数据
  • ❌ 效果通常不如白盒方法

蒸馏效果对比

实验设置

  • 教师:Llama 3 70B
  • 学生:Llama 3 8B
  • 任务:MMLU(通用知识问答)
方法准确率相对教师训练时间需要访问
学生原始62.3%-12.7%--
Logits蒸馏68.5%-6.5%2天Logits
特征蒸馏69.8%-5.2%3天内部特征
响应蒸馏65.7%-9.3%1.5天仅文本
教师模型75.0%---

观察

  • 特征蒸馏效果最好,但需要白盒访问
  • Logits蒸馏是性价比最高的方法
  • 响应蒸馏适合API场景,效果略差

第二部分:灾难性遗忘(Catastrophic Forgetting)

什么是灾难性遗忘?

定义:神经网络在学习新任务时,会急剧遗忘之前学习的任务知识。

类比:学生学习

传统学生:
  学数学 → 数学能力 ↑
  学物理 → 数学能力保持,物理能力 ↑
  学化学 → 数学、物理保持,化学能力 ↑

神经网络:
  学数学 → 数学能力 ↑
  学物理 → 数学能力 ↓↓,物理能力 ↑
  学化学 → 数学、物理能力 ↓↓↓,化学能力 ↑

实验:观察灾难性遗忘

实验设计

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# 任务A:情感分类(电影评论)
# 任务B:主题分类(新闻文章)

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(10000, 128)
        self.lstm = nn.LSTM(128, 256, batch_first=True)
        self.fc = nn.Linear(256, 2)

    def forward(self, x):
        emb = self.embedding(x)
        _, (hidden, _) = self.lstm(emb)
        return self.fc(hidden[-1])

# 评估函数
def evaluate(model, dataloader):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in dataloader:
            outputs = model(inputs)
            pred = outputs.argmax(dim=1)
            correct += (pred == labels).sum().item()
            total += labels.size(0)

    return correct / total

# 实验流程
model = SimpleModel()

# 阶段1:训练任务A
print("=== 训练任务A(情感分类)===")
train_on_task_a(model, task_a_data, epochs=5)
acc_a_after_a = evaluate(model, task_a_test)
print(f"任务A准确率: {acc_a_after_a:.2%}")

# 阶段2:训练任务B
print("\n=== 训练任务B(主题分类)===")
train_on_task_b(model, task_b_data, epochs=5)
acc_a_after_b = evaluate(model, task_a_test)  # 再次评估任务A
acc_b_after_b = evaluate(model, task_b_test)
print(f"任务A准确率: {acc_a_after_b:.2%} (下降 {acc_a_after_a - acc_a_after_b:.2%})")
print(f"任务B准确率: {acc_b_after_b:.2%}")

# 阶段3:训练任务C
print("\n=== 训练任务C(实体识别)===")
train_on_task_c(model, task_c_data, epochs=5)
acc_a_after_c = evaluate(model, task_a_test)
acc_b_after_c = evaluate(model, task_b_test)
acc_c_after_c = evaluate(model, task_c_test)
print(f"任务A准确率: {acc_a_after_c:.2%} (下降 {acc_a_after_a - acc_a_after_c:.2%})")
print(f"任务B准确率: {acc_b_after_c:.2%} (下降 {acc_b_after_b - acc_b_after_c:.2%})")
print(f"任务C准确率: {acc_c_after_c:.2%}")

典型结果

=== 训练任务A(情感分类)===
任务A准确率: 89.5%

=== 训练任务B(主题分类)===
任务A准确率: 67.2% (下降 22.3%) ⚠️
任务B准确率: 86.3%

=== 训练任务C(实体识别)===
任务A准确率: 52.1% (下降 37.4%) ⚠️⚠️
任务B准确率: 63.8% (下降 22.5%) ⚠️
任务C准确率: 84.7%

可视化

准确率
  ^
  |  任务A
90|  ●━━━━╮
  |        ╰━━━━╮
  |              ╰━━━━━━━━●  任务C
  |                   任务B
60|                   ●━━━━╮
  |                        ╰━━━━●
  |
30|
  |
  +─────────────────────────────────> 训练阶段
     训练A    训练B         训练C

为什么会发生灾难性遗忘?

原因1:参数重写

神经网络的参数是共享的

任务A学习:
  参数θ初始 → 参数θ_A(优化到任务A最优)

任务B学习:
  参数θ_A → 参数θ_B(优化到任务B最优)
  但θ_B可能对任务A很差!

数学解释

假设任务A和B的损失函数分别为 LA(θ)\mathcal{L}_A(\theta)LB(θ)\mathcal{L}_B(\theta)

  • 训练任务A后:θA=argminLA(θ)\theta_A = \arg\min \mathcal{L}_A(\theta)
  • 训练任务B后:θB=argminLB(θ)\theta_B = \arg\min \mathcal{L}_B(\theta)

问题:θB\theta_B 不保证 LA(θB)\mathcal{L}_A(\theta_B) 小!

原因2:梯度冲突

任务A和B的梯度可能相反

参数w的梯度:
  任务A: ∂L_A/∂w = +2.5  (希望增大w)
  任务B: ∂L_B/∂w = -3.0  (希望减小w)

结果:训练任务B时,w减小,损害任务A性能

原因3:分布偏移

新任务的数据分布不同

任务A(医疗):
  - 词汇:"症状""诊断""治疗"
  - 句式:正式、专业

任务B(社交媒体):
  - 词汇:"点赞""转发""吐槽"
  - 句式:口语、非正式

模型适应B后,A的特征激活模式被改变

遗忘的数学分析

Fisher信息矩阵

核心思想:某些参数对旧任务"更重要",不应该被大幅修改。

Fisher信息矩阵定义:

Fi=ExDA[(logp(yx;θA)θi)2]F_i = \mathbb{E}_{x \sim \mathcal{D}_A} \left[ \left( \frac{\partial \log p(y|x; \theta_A)}{\partial \theta_i} \right)^2 \right]

含义

  • FiF_i 大:参数 θi\theta_i 对任务A很重要(梯度大且稳定)
  • FiF_i 小:参数 θi\theta_i 对任务A不太重要(可以安全修改)

可视化

参数重要性:

θ₁: ████████████ F₁=12.5  (非常重要,不能改)
θ₂: ████████     F₂=8.0   (重要)
θ₃: ███          F₃=3.0   (一般)
θ₄: █            F₄=1.0   (不重要,可以改)

启发:训练新任务时,重要参数应该少改或不改


第三部分:回放机制(Experience Replay)

什么是回放?

核心思想:在学习新任务时,同时回放旧任务的数据,让模型"不忘初心"。

类比:学生复习

不使用回放:
  第1天:学数学(专注数学,100%时间)2天:学物理(专注物理,100%时间)← 忘记数学3天:学化学(专注化学,100%时间)← 忘记数学和物理

使用回放:
  第1天:学数学(100%时间)2天:学物理(70%时间) + 复习数学(30%时间)3天:学化学(60%时间) + 复习数学、物理(40%时间)

回放方法1:原始数据回放(Naive Replay)

原理

保存旧任务的训练数据,与新任务数据混合训练。

实现

class ExperienceReplayBuffer:
    def __init__(self, max_size=10000):
        self.buffer = []
        self.max_size = max_size

    def add_task_data(self, task_data):
        """
        添加新任务的数据到缓冲区

        Args:
            task_data: List of (input, label) tuples
        """
        # 如果缓冲区满了,随机替换旧数据
        for sample in task_data:
            if len(self.buffer) < self.max_size:
                self.buffer.append(sample)
            else:
                # 随机替换
                idx = random.randint(0, self.max_size - 1)
                self.buffer[idx] = sample

    def sample(self, batch_size):
        """随机采样"""
        return random.sample(self.buffer, min(batch_size, len(self.buffer)))

# 使用示例
replay_buffer = ExperienceReplayBuffer(max_size=5000)

# 训练任务A
for epoch in range(5):
    for batch in task_a_dataloader:
        inputs, labels = batch

        # 训练
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # 保存数据到回放缓冲区
        replay_buffer.add_task_data(list(zip(inputs, labels)))

print(f"缓冲区大小: {len(replay_buffer.buffer)}")

# 训练任务B(同时回放任务A)
for epoch in range(5):
    for batch in task_b_dataloader:
        new_inputs, new_labels = batch

        # 1. 训练新任务数据
        outputs = model(new_inputs)
        loss_new = criterion(outputs, new_labels)

        # 2. 回放旧任务数据
        if len(replay_buffer.buffer) > 0:
            replay_samples = replay_buffer.sample(batch_size=32)
            replay_inputs, replay_labels = zip(*replay_samples)
            replay_inputs = torch.stack(replay_inputs)
            replay_labels = torch.tensor(replay_labels)

            replay_outputs = model(replay_inputs)
            loss_replay = criterion(replay_outputs, replay_labels)
        else:
            loss_replay = 0

        # 3. 总损失
        loss = loss_new + 0.5 * loss_replay  # 可调整权重

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

效果对比

方法任务A (训练B后)任务B总平均
无回放67.2% (-22.3%)86.3%76.8%
回放 (50%)82.5% (-7.0%)84.1%83.3% ✅
回放 (100%)87.2% (-2.3%)82.8%85.0% ✅✅

观察

  • 回放显著减少遗忘
  • 回放比例越高,旧任务保留越好(但新任务可能略有下降)

优缺点

优点

  • ✅ 简单有效
  • ✅ 理论上最优(如果数据充足)

缺点

  • 存储开销大(需要保存原始数据)
  • 隐私问题(无法删除用户数据)
  • 数据不平衡(多任务时缓冲区有限)

回放方法2:生成式回放(Generative Replay)

原理

不保存原始数据,而是用生成模型"回忆"旧数据

流程

步骤1:训练任务A
  → 训练主模型(分类器)
  → 同时训练生成器G_A(学习生成任务A的数据)

步骤2:训练任务B
  → 用G_A生成"伪"任务A数据
  → 与真实任务B数据混合训练
  → 训练生成器G_B

实现

class GenerativeReplay:
    def __init__(self, main_model, generator):
        """
        Args:
            main_model: 主任务模型(如分类器)
            generator: 生成模型(如VAE或diffusion)
        """
        self.main_model = main_model
        self.generator = generator

    def train_task(self, new_data, prev_generators=None):
        """
        训练新任务

        Args:
            new_data: 新任务的真实数据
            prev_generators: 之前任务的生成器列表
        """
        optimizer_main = torch.optim.Adam(self.main_model.parameters())
        optimizer_gen = torch.optim.Adam(self.generator.parameters())

        for epoch in range(num_epochs):
            for batch in new_data:
                real_inputs, real_labels = batch

                # ===== 训练主模型 =====
                # 1. 新任务数据
                outputs = self.main_model(real_inputs)
                loss_new = criterion(outputs, real_labels)

                # 2. 生成旧任务数据(回放)
                if prev_generators:
                    loss_replay = 0
                    for old_gen in prev_generators:
                        # 生成伪数据
                        with torch.no_grad():
                            fake_inputs = old_gen.generate(batch_size=32)

                        # 用当前模型预测(目标是保持旧模型的预测)
                        outputs = self.main_model(fake_inputs)

                        # 用旧模型的预测作为伪标签
                        with torch.no_grad():
                            pseudo_labels = old_main_model(fake_inputs).argmax(dim=1)

                        loss_replay += criterion(outputs, pseudo_labels)

                    loss_main = loss_new + 0.5 * loss_replay
                else:
                    loss_main = loss_new

                # 更新主模型
                optimizer_main.zero_grad()
                loss_main.backward()
                optimizer_main.step()

                # ===== 训练生成器 =====
                # 让生成器学习生成当前任务的数据
                fake_data = self.generator.generate(batch_size=32)
                gen_loss = generator_loss(fake_data, real_inputs)

                optimizer_gen.zero_grad()
                gen_loss.backward()
                optimizer_gen.step()

        return self.generator  # 返回当前生成器,供后续任务使用

# 使用
replay = GenerativeReplay(
    main_model=classifier,
    generator=VAE()
)

# 训练任务A
gen_a = replay.train_task(task_a_data, prev_generators=None)

# 训练任务B(使用gen_a回放)
gen_b = replay.train_task(task_b_data, prev_generators=[gen_a])

# 训练任务C
gen_c = replay.train_task(task_c_data, prev_generators=[gen_a, gen_b])

优缺点

优点

  • 不需要存储原始数据(解决隐私问题)
  • 内存需求小(只存生成器参数)
  • 可扩展(生成器可压缩)

缺点

  • 生成质量影响效果(生成器不好则回放失效)
  • 训练成本高(需要额外训练生成器)
  • 不适合复杂数据(如高分辨率图像、长文本)

回放方法3:知识蒸馏回放(Distillation Replay)

原理

结合知识蒸馏和回放:用旧模型的软标签作为回放目标。

关键洞察

  • 不需要保存原始数据
  • 不需要训练生成器
  • 当前模型在旧任务上的预测作为回放信号

实现

class DistillationReplay:
    def __init__(self, model, temperature=2.0):
        self.model = model
        self.temperature = temperature
        self.old_model = None  # 保存旧模型

    def train_new_task(self, new_data, replay_data=None):
        """
        训练新任务,同时通过蒸馏回放旧知识

        Args:
            new_data: 新任务数据 (inputs, labels)
            replay_data: 回放数据(只需输入,不需要标签)
        """
        # 复制当前模型作为"旧模型"(冻结)
        if self.old_model is None:
            self.old_model = copy.deepcopy(self.model)
        self.old_model.eval()

        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-5)

        for epoch in range(num_epochs):
            for new_batch in new_data:
                new_inputs, new_labels = new_batch

                # ===== 新任务损失 =====
                new_outputs = self.model(new_inputs)
                loss_new = F.cross_entropy(new_outputs, new_labels)

                # ===== 回放损失(蒸馏) =====
                if replay_data and len(replay_data) > 0:
                    # 采样回放数据
                    replay_inputs = random.sample(replay_data, k=32)
                    replay_inputs = torch.stack(replay_inputs)

                    # 当前模型的输出
                    curr_outputs = self.model(replay_inputs)

                    # 旧模型的输出(软标签)
                    with torch.no_grad():
                        old_outputs = self.old_model(replay_inputs)

                    # 蒸馏损失
                    curr_soft = F.log_softmax(
                        curr_outputs / self.temperature, dim=1
                    )
                    old_soft = F.softmax(
                        old_outputs / self.temperature, dim=1
                    )

                    loss_distill = F.kl_div(
                        curr_soft, old_soft, reduction='batchmean'
                    ) * (self.temperature ** 2)
                else:
                    loss_distill = 0

                # ===== 总损失 =====
                loss = loss_new + 0.5 * loss_distill

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        # 更新旧模型
        self.old_model = copy.deepcopy(self.model)

# 使用
replay = DistillationReplay(model, temperature=3.0)

# 训练任务A
replay.train_new_task(
    new_data=task_a_data,
    replay_data=None  # 第一个任务无需回放
)

# 训练任务B(蒸馏回放任务A)
# 只需要任务A的输入(不需要标签)
task_a_inputs = [x for x, y in task_a_data]

replay.train_new_task(
    new_data=task_b_data,
    replay_data=task_a_inputs  # 回放任务A的输入
)

# 训练任务C
task_ab_inputs = task_a_inputs + [x for x, y in task_b_data]

replay.train_new_task(
    new_data=task_c_data,
    replay_data=task_ab_inputs  # 回放A和B
)

为什么有效?

旧模型的软标签包含了"如何解决旧任务"的知识

任务A:情感分类
  输入: "这部电影很棒"
  旧模型输出: [正面: 0.92, 负面: 0.08]

学习任务B时:
  在相同输入上,新模型应该输出相似的概率分布
  → 保持了任务A的知识

优缺点

优点

  • 不需要存储数据(只需少量输入样本)
  • 不需要训练生成器
  • 简单高效
  • 内存友好

缺点

  • ⚠️ 需要保存旧模型副本(但可以定期合并)
  • ⚠️ 回放数据需要覆盖旧任务的分布

回放方法对比

方法存储需求隐私友好效果复杂度适用场景
原始数据回放高(原始数据)❌ 低⭐⭐⭐⭐⭐数据可存储
生成式回放中(生成器)✅ 高⭐⭐⭐隐私敏感
蒸馏回放低(模型副本)✅ 高⭐⭐⭐⭐推荐

第四部分:蒸馏与回放的结合

为什么结合?

蒸馏和回放解决不同问题

技术解决的问题典型场景
知识蒸馏模型太大,需要压缩部署到边缘设备
回放机制持续学习时遗忘多任务增量学习

结合场景

场景:在线学习系统
  - 大模型(教师):在云端,持续学习新任务
  - 小模型(学生):在设备端,需要同步云端知识

挑战:
  1. 学生模型太小,无法直接学习所有任务
  2. 教师模型学习新任务时会遗忘

解决方案:蒸馏 + 回放

方法1:逐步蒸馏(Progressive Distillation)

原理

每学习一个新任务,就蒸馏一次

任务A → 训练Teacher_A → 蒸馏到Student_A
任务B → 训练Teacher_B(回放A)→ 蒸馏到Student_B(回放A)
任务C → 训练Teacher_C(回放A+B)→ 蒸馏到Student_C(回放A+B

实现

class ProgressiveDistillation:
    def __init__(self, teacher_model, student_model, temperature=3.0):
        self.teacher = teacher_model
        self.student = student_model
        self.temperature = temperature
        self.replay_buffer = []

    def learn_task(self, task_data, task_id):
        """
        学习新任务,同时回放旧任务

        Args:
            task_data: 新任务数据
            task_id: 任务ID
        """
        print(f"\n=== 学习任务 {task_id} ===")

        # ===== 步骤1:训练教师模型(带回放) =====
        print("步骤1: 训练教师模型")
        self._train_teacher(task_data)

        # ===== 步骤2:蒸馏到学生模型(带回放) =====
        print("步骤2: 蒸馏到学生模型")
        self._distill_to_student(task_data)

        # ===== 步骤3:保存回放数据 =====
        self.replay_buffer.extend(
            random.sample(task_data, k=min(1000, len(task_data)))
        )
        print(f"回放缓冲区大小: {len(self.replay_buffer)}")

    def _train_teacher(self, new_data):
        """训练教师模型(带回放)"""
        optimizer = torch.optim.Adam(self.teacher.parameters(), lr=1e-5)

        for epoch in range(5):
            # 混合新数据和回放数据
            if self.replay_buffer:
                replay_sample = random.sample(
                    self.replay_buffer,
                    k=min(len(new_data), len(self.replay_buffer))
                )
                mixed_data = list(new_data) + replay_sample
            else:
                mixed_data = new_data

            random.shuffle(mixed_data)

            for inputs, labels in DataLoader(mixed_data, batch_size=32):
                outputs = self.teacher(inputs)
                loss = F.cross_entropy(outputs, labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    def _distill_to_student(self, new_data):
        """蒸馏到学生模型(带回放)"""
        optimizer = torch.optim.Adam(self.student.parameters(), lr=1e-5)

        # 混合数据
        if self.replay_buffer:
            all_data = list(new_data) + self.replay_buffer
        else:
            all_data = new_data

        for epoch in range(3):
            for inputs, labels in DataLoader(all_data, batch_size=32):
                # 教师预测
                with torch.no_grad():
                    teacher_logits = self.teacher(inputs)

                # 学生预测
                student_logits = self.student(inputs)

                # 蒸馏损失
                loss = self._kd_loss(
                    student_logits, teacher_logits, labels
                )

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    def _kd_loss(self, student_logits, teacher_logits, labels):
        """知识蒸馏损失"""
        # 软标签损失
        student_soft = F.log_softmax(
            student_logits / self.temperature, dim=1
        )
        teacher_soft = F.softmax(
            teacher_logits / self.temperature, dim=1
        )
        soft_loss = F.kl_div(student_soft, teacher_soft, reduction='batchmean')
        soft_loss *= (self.temperature ** 2)

        # 硬标签损失
        hard_loss = F.cross_entropy(student_logits, labels)

        # 组合
        return 0.7 * soft_loss + 0.3 * hard_loss

# 使用
teacher = LargeModel()
student = SmallModel()

progressive = ProgressiveDistillation(teacher, student, temperature=3.0)

# 逐步学习多个任务
for task_id, task_data in enumerate(all_tasks):
    progressive.learn_task(task_data, task_id)

    # 评估
    print(f"任务 {task_id} 完成后的性能:")
    for eval_id in range(task_id + 1):
        acc = evaluate(student, eval_tasks[eval_id])
        print(f"  任务 {eval_id}: {acc:.2%}")

效果

任务标准微调蒸馏(无回放)蒸馏+回放
任务A89%81%87% ✅
任务B62% (-27%)78% (-3%)84% ✅
任务C48% (-41%)76% (-5%)82% ✅
平均66.3%78.3%84.3% ✅✅

方法2:自蒸馏回放(Self-Distillation Replay)

原理

模型作为自己的教师

不需要单独的大模型教师
  ↓
模型A(任务A训练后)→ 作为教师
  ↓
模型B(学习任务B)← 从模型A蒸馏 + 学习新任务
  ↓
模型B → 作为教师
  ↓
模型C(学习任务C)← 从模型B蒸馏 + 学习新任务

实现

class SelfDistillationReplay:
    def __init__(self, model, temperature=2.0, alpha=0.5):
        self.model = model
        self.temperature = temperature
        self.alpha = alpha  # 蒸馏权重
        self.old_model = None
        self.replay_inputs = []

    def learn_task(self, new_data):
        """学习新任务,通过自蒸馏保持旧知识"""

        # 保存旧模型
        if self.old_model is not None:
            # 更新回放输入
            new_inputs = [x for x, _ in random.sample(new_data, k=100)]
            self.replay_inputs.extend(new_inputs)

        self.old_model = copy.deepcopy(self.model)
        self.old_model.eval()

        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-5)

        for epoch in range(num_epochs):
            for inputs, labels in DataLoader(new_data, batch_size=32):
                # ===== 新任务损失 =====
                outputs = self.model(inputs)
                loss_new = F.cross_entropy(outputs, labels)

                # ===== 自蒸馏损失 =====
                if self.old_model and self.replay_inputs:
                    # 采样回放输入
                    replay_batch = random.sample(
                        self.replay_inputs,
                        k=min(32, len(self.replay_inputs))
                    )
                    replay_batch = torch.stack(replay_batch)

                    # 当前模型输出
                    curr_logits = self.model(replay_batch)

                    # 旧模型输出
                    with torch.no_grad():
                        old_logits = self.old_model(replay_batch)

                    # 蒸馏损失
                    loss_distill = self._compute_kd_loss(
                        curr_logits, old_logits
                    )
                else:
                    loss_distill = 0

                # ===== 总损失 =====
                loss = (1 - self.alpha) * loss_new + self.alpha * loss_distill

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    def _compute_kd_loss(self, student_logits, teacher_logits):
        """计算KD损失"""
        student_soft = F.log_softmax(
            student_logits / self.temperature, dim=1
        )
        teacher_soft = F.softmax(
            teacher_logits / self.temperature, dim=1
        )
        return F.kl_div(
            student_soft, teacher_soft, reduction='batchmean'
        ) * (self.temperature ** 2)

# 使用
model = MyModel()
self_distill = SelfDistillationReplay(model, temperature=2.0, alpha=0.5)

# 持续学习
for task_data in all_tasks:
    self_distill.learn_task(task_data)

优势

  • 不需要单独的教师模型(节省内存)
  • 简单易实现
  • 适合资源受限场景

第五部分:实践建议与总结

选择指南

场景1:模型压缩

目标:将70B模型压缩到7B

推荐方案

1. Logits蒸馏(首选)
   - 温度: T=3-5
   - alpha: 0.7-0.9
   - 训练: 2-3 epochs

2. 如果可以访问内部:特征蒸馏
   - 选择关键层(每隔4层选一个)
   - 特征损失权重: 0.5

3. 如果只有API:响应蒸馏
   - 收集大量高质量生成数据
   - 使用排序蒸馏提高效果

场景2:持续学习

目标:模型不断学习新任务,不遗忘旧任务

推荐方案

任务数量少(2-5个):
  → 原始数据回放(效果最好)
  → 每个任务保存1000-5000样本

任务数量中等(5-20个):
  → 蒸馏回放(推荐)
  → 保存旧模型副本 + 少量输入样本

任务数量多(20+个):
  → 生成式回放 或 压缩回放
  → 定期合并模型

场景3:在线学习系统

目标:云端大模型 + 边缘小模型,持续更新

推荐方案

逐步蒸馏 + 回放:
  1. 云端:大模型学习新任务(带回放)
  2. 蒸馏:定期蒸馏到小模型
  3. 部署:推送小模型到边缘设备

周期:每周或每月一次

超参数建议

知识蒸馏

参数推荐值说明
温度 T2-5越大越软,信息越丰富
alpha0.7-0.9软标签权重
学习率1e-5 ~ 5e-5比正常训练小
Epochs2-5不要过拟合

回放机制

参数推荐值说明
缓冲区大小1000-10000/任务取决于内存
回放比例0.3-0.5旧数据占比
回放权重0.5-1.0回放损失的权重

监控指标

蒸馏训练

def monitor_distillation(teacher, student, val_data):
    metrics = {}

    # 1. 性能差距
    teacher_acc = evaluate(teacher, val_data)
    student_acc = evaluate(student, val_data)
    metrics['performance_gap'] = teacher_acc - student_acc
    print(f"性能差距: {metrics['performance_gap']:.2%}")
    # 期望: <10%

    # 2. 输出分布相似度
    kl_divs = []
    for inputs, _ in val_data:
        with torch.no_grad():
            t_logits = teacher(inputs)
            s_logits = student(inputs)
            kl = F.kl_div(
                F.log_softmax(s_logits, dim=1),
                F.softmax(t_logits, dim=1),
                reduction='batchmean'
            )
            kl_divs.append(kl.item())

    metrics['avg_kl'] = sum(kl_divs) / len(kl_divs)
    print(f"平均KL散度: {metrics['avg_kl']:.4f}")
    # 期望: <0.5

    return metrics

持续学习

def monitor_continual_learning(model, all_task_data, current_task):
    """监控灾难性遗忘"""
    print(f"\n=== 完成任务 {current_task} 后的评估 ===")

    accuracies = []
    for task_id in range(current_task + 1):
        acc = evaluate(model, all_task_data[task_id])
        accuracies.append(acc)
        print(f"任务 {task_id}: {acc:.2%}")

    # 平均准确率
    avg_acc = sum(accuracies) / len(accuracies)
    print(f"平均准确率: {avg_acc:.2%}")

    # 遗忘程度(旧任务性能下降)
    if current_task > 0:
        old_tasks_acc = sum(accuracies[:-1]) / (len(accuracies) - 1)
        print(f"旧任务平均: {old_tasks_acc:.2%}")

        # 与初始训练时对比
        forgetting = initial_accs[current_task-1] - old_tasks_acc
        print(f"遗忘程度: {forgetting:.2%}")
        # 期望: <10%

常见问题与解决

问题1:蒸馏效果差

症状:学生模型准确率比教师低很多(>15%)
原因:
  1. 温度太低 → 软标签不够软
  2. alpha太小 → 软标签权重不够
  3. 学生太小 → 容量不足

解决:
  1. 增大温度(T: 3→5)
  2. 增大alpha(0.7→0.9)
  3. 增大学生模型(如果可能)
  4. 使用特征蒸馏(如果是白盒)

问题2:持续学习仍然遗忘

症状:即使用了回放,旧任务性能仍下降>20%
原因:
  1. 回放数据太少
  2. 回放权重太小
  3. 任务差异太大(分布偏移严重)

解决:
  1. 增大回放缓冲区(1000→5000)
  2. 增大回放权重(0.3→0.7)
  3. 使用更多回放策略(混合多种方法)
  4. 降低学习率(减缓参数变化)

问题3:训练时间过长

症状:蒸馏或回放训练时间是正常的2倍以上
原因:
  1. 回放数据太多
  2. 教师模型推理慢
  3. 没有并行优化

解决:
  1. 减小回放比例(50%→30%)
  2. 缓存教师模型的输出(预计算)
  3. 使用混合精度训练(FP16)
  4. 批量蒸馏(一次蒸馏多个样本)

小结

核心要点

知识蒸馏:让小模型学习大模型的"思考过程"

  • Logits蒸馏:学习输出概率分布(最常用)
  • 特征蒸馏:学习中间层表示(效果更好)
  • 响应蒸馏:学习文本输出(黑盒场景)

灾难性遗忘:持续学习新任务时,模型会快速遗忘旧知识

  • 原因:参数共享、梯度冲突、分布偏移
  • 表现:旧任务性能急剧下降(20-40%)

回放机制:通过"复习"旧任务防止遗忘

  • 原始回放:保存真实数据(效果最好,但有隐私问题)
  • 生成式回放:用生成器重建数据(隐私友好)
  • 蒸馏回放:用旧模型的软标签(推荐,简单高效)

结合使用:蒸馏 + 回放 = 小而不忘的模型

  • 逐步蒸馏:每学习新任务就蒸馏一次
  • 自蒸馏回放:模型作为自己的教师

核心公式

知识蒸馏损失

LKD=αT2KL(softmax(zS/T)softmax(zT/T))+(1α)CE(zS,y)\mathcal{L}_{\text{KD}} = \alpha \cdot T^2 \cdot \text{KL}\left(\text{softmax}(z^S/T) \| \text{softmax}(z^T/T)\right) + (1-\alpha) \cdot \text{CE}(z^S, y)

蒸馏回放损失

Ltotal=Lnew+λKL(πcurr(xold)πold(xold))\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{new}} + \lambda \cdot \text{KL}(\pi_{\text{curr}}(x_{\text{old}}) \| \pi_{\text{old}}(x_{\text{old}}))

实践检查清单

蒸馏前

  • 确定教师和学生模型架构
  • 选择蒸馏类型(白盒/黑盒)
  • 准备蒸馏数据(教师的输出)
  • 设置超参数(T, alpha)

持续学习前

  • 评估基线性能(每个任务单独训练)
  • 选择回放策略(数据/生成/蒸馏)
  • 分配回放缓冲区大小
  • 设计评估指标(监控遗忘)

训练中

  • 监控性能差距(学生 vs 教师)
  • 监控旧任务性能(防止遗忘)
  • 调整回放比例(平衡新旧任务)

训练后

  • 全面评估所有任务
  • 计算平均遗忘率
  • 对比压缩率(模型大小 vs 性能)

知识蒸馏让模型更小,回放机制让模型不忘 —— 两者结合,打造高效、持续学习的AI系统!