从 CNN 的角度理解蒸馏学习

71 阅读5分钟

一、前言

在 DeepSeek-r1 火爆后,蒸馏学习也迎来了新的关注。蒸馏学习用于让学生模型从已经训练好的老师模型学习能力,通常学生的参数量会小于老师模型。DeepSeek-r1 开源的时候,就开源了不同大小的蒸馏模型,这些蒸馏模型都是从 671B 的 DeepSeek-r1 学习而来的。

蒸馏学习早先由于蒸馏 Bert 模型,而我们今天将以 CNN 为例,学习蒸馏学习算法。

二、蒸馏学习

2.1 什么是蒸馏学习

蒸馏学习有很多方式。比如输入 x,让老师模型输出 y,然后用 x、y 作为训练数据来训练学生模型。这种方式简单粗暴,无需对学生模型训练部分做任何修改。

而另一种方式则是让学生模型学习老师模型的输出 logits。这种方式也叫soft label,今天我们要实现的也是这种方式。

2.2 计算损失

蒸馏学习和普通的深度学习训练没有太大区别,只不过在原来的基础上添加了一个额外的损失。

我们以 MNIST 为例,这是一个多分类问题,我们会使用 cross entropy 作为损失。现在假设我们有一个老师模型 Teacher,我们要利用蒸馏学习让 Student 模型学习 Teacher 的能力。这里我们使用 soft label 方式,那么在要求 Student 的 cross entropy loss 低的情况下,我们还要求 Teacher(x)输出的 logits_t 和 Student(x) 输出的logits_s 相似。

在多分类中,logits 经过 softmax 可以看作是一个概率分布,那么我们要对比的其实是两个分布的相似性,这很容易就想到 KL 散度。而这就是蒸馏学习的关键所在。

下面的计算 KL 散度的伪代码:

teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
distill_loss = kl_div(student_log_probs, teacher_probs)

这个也被称为蒸馏损失。实际上我们只需要在训练的时候加上这个损失就能达到蒸馏学习的效果。

三、代码实现

下面我们来实际实现一下蒸馏学习,这里使用 MNIST 来测试。

3.1 老师和学生模型

MNIST 是一个手写数字识别的任务,我们可以用 CNN 作为我们的模型。为了区分学生模型和老师模型,这里老师模型使用一个参数量稍大的模型,而学生模型使用参数量较小的模型。在结构上,两个模型差别不大,最后的输出都是长度为 10 的向量 logits,具体代码如下:


# 构建 Teacher 和 Student 模型
class TeacherNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Linear(256, 10)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(-1, 1024)
        return self.classifier(x)


class StudentNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(16 * 12 * 12, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(-1, 16 * 12 * 12)
        return self.classifier(x)

3.2 数据准备

数据准备部分我们将用整个训练集训练老师模型,然后在训练集中使用 5% 来训练学生模型,当然还会配合学习好的老师模型进行蒸馏学习。至于测试则都使用完整的测试集测试。具体代码如下:

# 固定随机种子
def set_seed(seed=42):
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


set_seed()

# 数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# 准备数据:只取 5% 的训练数据用于 Student
num_train = len(train_dataset)
indices = list(range(num_train))
np.random.shuffle(indices)
subset_indices = indices[:int(0.05 * num_train)]  # 取前 5%

subset_sampler = SubsetRandomSampler(subset_indices)
subset_loader = DataLoader(train_dataset, batch_size=64, sampler=subset_sampler)

full_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

3.3 训练老师模型

老师模型的训练我们使用 cross entropy 作为损失函数:

teacher_model = TeacherNet().to(device)


# 训练 Teacher 模型(使用全部数据)
def train_teacher(model, train_loader, epochs=5):
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        total_loss = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"[Teacher] Epoch {epoch + 1} Loss: {total_loss:.4f}")
 
 
# Step 1: 先训练 Teacher 模型
print("Training Teacher Model...")
train_teacher(teacher_model, full_loader, epochs=5)

3.4 训练学生模型

训练学生模型和训练老师模型代码很像,但是我们要加上 KL 散度损失:

student_model = StudentNet().to(device)

# 用 5% 的数据训练 Student 模型 + 蒸馏
def train_student_with_distillation(student, teacher, train_loader, test_loader, epochs=10, temperature=3.0, alpha=0.5):
    student.train()
    optimizer = optim.Adam(student.parameters(), lr=1e-3)
    ce_loss = nn.CrossEntropyLoss()
    kl_div = nn.KLDivLoss(reduction='batchmean')

    for epoch in range(epochs):
        student.train()
        total_loss = 0
        correct = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            
            # 计算teacher输出的logits,这里teacher 不需要计算梯度
            with torch.no_grad():
                teacher_logits = teacher(images)
            # 计算student输出的logits
            student_logits = student(images)

            # Soft target loss (KL Divergence)
            teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
            student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
            distill_loss = kl_div(student_log_probs, teacher_probs)

            # Label loss
            label_loss = ce_loss(student_logits, labels)

            # 合并损失
            loss = alpha * (temperature ** 2) * distill_loss + (1 - alpha) * label_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            pred = student_logits.argmax(dim=1)
            correct += pred.eq(labels).sum().item()

        acc = correct / len(train_loader.dataset)
        print(f"[Student] Epoch {epoch + 1} Loss: {total_loss:.4f}, Train Acc: {acc:.4f}")
        

# Step 2: 使用 5% 数据 + 蒸馏训练 Student
print("\nTraining Student Model with Distillation...")
train_student_with_distillation(student_model, teacher_model, subset_loader, test_loader, epochs=10, temperature=3.0,
                                alpha=0.7)

这里使用 no_grad 计算 teacher 的 logits,因为在计算 KL 散度时,我们把他当做常数,而且我们不需要更新 teacher 模型。

然后使用和 teacher 一样的 cross entropy 损失,最后对两个损失加权和。这样我们可以在学习原始数据的同时,从老师模型中学到一些知识。

现在我们只需要先充分训练老师模型,然后用蒸馏学习训练学生模型即可。

四、总结

蒸馏学习是一种从其它模型学习知识的算法。虽然蒸馏学习好用,但是通常情况学生模型性能会不如老师模型,因此不是所有情况都需要采用蒸馏学习。

如果你需要在一个配置较低的设备上运行模型,且对模型准确率要求没有那么高。这个时候就可以考虑蒸馏学习。或者你要在没有足够数据但有现成模型的情况下,训练另一个模型,这个时候也可以选择蒸馏学习。