开发也能看懂的大模型:半监督学习-FixMatch

547 阅读9分钟

1. FixMatch 简介

FixMatch 是一种半监督学习算法,结合了以下两个核心思想:

  1. 伪标签(Pseudo-Labeling) :利用模型的 高置信度预测 为未标注样本生成伪标签。
  2. 一致性正则化(Consistency Regularization) :通过对同一输入施加不同增强(如旋转、裁剪),约束模型输出的一致性。

FixMatch 的目标是通过利用大量未标注数据,在少量标注数据的情况下显著提升模型性能。


2. FixMatch 的工作流程

  1. 数据增强

    • 对未标注数据分别施加 弱增强强增强
    • 弱增强:轻微的变换(如随机裁剪、水平翻转)。
    • 强增强:更激烈的变换(如 RandAugment,加入颜色抖动、旋转等)。
  2. 伪标签生成

    • 用模型预测弱增强数据的类别,并仅保留置信度高于阈值的伪标签(如大于 0.95 的预测)。
    • 伪标签赋值为预测的类别。
  3. 一致性损失

    • 对强增强数据,使用生成的伪标签作为目标计算损失(交叉熵)。
    • 损失函数鼓励模型在强增强样本上输出与伪标签一致的预测。
  4. 标注数据监督训练

    • 同时对标注数据计算标准的监督损失。
  5. 联合优化

    • 损失函数为标注数据损失和未标注数据一致性损失的加权和。

3. FixMatch 的案例实现:CIFAR-10 图像分类

数据集:

  • CIFAR-10:包含 10 类,每类 6000 张图片(32x32 分辨率)。
  • 使用 40 张标注数据,其他 59960 张作为未标注数据。

实现步骤:

以下是 Python + PyTorch 的 FixMatch 实现框架(关键代码片段)。

(1) 数据加载与增强
from torchvision import transforms, datasets

# 数据增强:弱增强和强增强
weak_augment = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
])

strong_augment = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
])

# CIFAR-10 数据集
train_labeled = datasets.CIFAR10(root='./data', train=True, transform=weak_augment, download=True)
train_unlabeled = datasets.CIFAR10(root='./data', train=True, transform=weak_augment, download=True)
test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transforms.ToTensor(), download=True)

(2) 模型设计

使用一个简单的卷积神经网络(CNN)模型。

import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(-1, 64 * 8 * 8)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

(3) FixMatch 训练逻辑
import torch
import torch.optim as optim

# 损失函数
ce_loss = nn.CrossEntropyLoss()
threshold = 0.95  # 高置信度伪标签阈值

# FixMatch 训练步骤
def train_fixmatch(model, labeled_loader, unlabeled_loader, optimizer, device):
    model.train()
    for (x_labeled, y_labeled), (x_unlabeled, _) in zip(labeled_loader, unlabeled_loader):
        # 转移到设备
        x_labeled, y_labeled = x_labeled.to(device), y_labeled.to(device)
        x_unlabeled = x_unlabeled.to(device)
        
        # 弱增强预测伪标签
        with torch.no_grad():
            pseudo_labels = torch.softmax(model(x_unlabeled), dim=1)
            max_probs, targets = torch.max(pseudo_labels, dim=1)
            mask = max_probs.ge(threshold)  # 置信度过滤

        # 强增强样本
        x_unlabeled_strong = strong_augment(x_unlabeled)

        # 计算损失
        labeled_loss = ce_loss(model(x_labeled), y_labeled)
        if mask.sum() > 0:  # 有伪标签通过过滤
            unlabeled_loss = ce_loss(model(x_unlabeled_strong[mask]), targets[mask])
        else:
            unlabeled_loss = torch.tensor(0.0).to(device)

        # 总损失
        loss = labeled_loss + unlabeled_loss

        # 反向传播与优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

4. FixMatch 性能分析

训练结果:

  • CIFAR-10 上 40 张标注数据的表现

    • 标注数据+FixMatch:精度可达 94.8%。
    • 仅有监督学习:精度低于 70%。

FixMatch 的效果:

  • 在极低标注数据的情况下,未标注数据的有效利用显著提升了模型性能。
  • 一致性正则化限制了模型的过拟合,增强了泛化能力。

5. 优势与局限性

优势:

  1. 低标注需求:用少量标注数据即可获得接近全监督的性能。
  2. 简单高效:无需复杂模型结构,仅需数据增强和置信度筛选。
  3. 通用性强:适用于图像分类、语音识别、文本分类等任务。

局限性:

  1. 依赖未标注数据的分布一致性:未标注数据需与标注数据分布相似。
  2. 伪标签的依赖:初始模型性能较差时,伪标签可能不够准确。
  3. 数据增强的敏感性:强增强策略对性能有较大影响,需合理设计。

6. 总结

FixMatch 通过结合伪标签和一致性正则化,在少量标注数据下充分挖掘未标注数据的潜力,展现了强大的性能。

应用场景

  • 标注成本高但未标注数据丰富的任务(如医学影像、自动驾驶、工业检测)。
  • 数据资源有限的新兴领域(如遥感、科学研究)。

思考题1: 为什么一致性方法采用旋转、裁剪等方法,就能实现半监督学习,提高模型精度

1. 数据增强模拟了“多样性”

现实中的数据千变万化,比如拍摄角度不同(旋转)、裁剪到不同区域、光照变化等等。

  • 数据增强(如旋转、裁剪)人为制造了这些“变化”,相当于让模型提前适应多样化的情况。
  • 如果模型对同一物体的不同表现形式(如一张旋转的猫图片)都能做出正确预测,它在现实场景中的表现也会更稳健。

2. 一致性正则化的核心:让模型“自我校准”

一致性正则化的思想是: “无论输入如何变化(旋转、裁剪等增强),模型的输出应该尽量保持一致。”

通俗地说,这就像训练模型去认出一个人,不管他是站着还是坐着、近一点还是远一点,模型都应该认出是同一个人。

  • 对模型的约束:当同一张图片(如一辆车)被裁剪或者旋转时,模型输出的类别应该是相同的(仍然是车)。
  • 自我校准:通过对这些增强数据计算一致性损失,模型逐步调整自己,避免因为图片的变化而“犯糊涂”。

3. 未标注数据的潜力被挖掘

在半监督学习中,未标注数据非常丰富,但直接用它们来训练模型,可能不可靠,因为没有明确标签。

一致性方法通过增强和模型预测:

  • 模型自己生成伪标签:比如弱增强的车图,模型预测为“车”。
  • 强增强来校验伪标签:旋转后的车图,模型的输出仍然应该是“车”。

这种方式用未标注数据逼着模型“自己教学”,即使没有人工标注,也能从中学到更多知识。

思考题2: 伪标签法是所有数据都打标签并可用吗

1. 伪标签法的基本流程

  1. 已标注数据先教会模型:用少量的标注数据训练一个初步模型,就像一个学生已经学会了基础知识。

  2. 用模型标记未标注数据:把模型当成“老师”,让它对未标注数据进行预测,生成伪标签。

    • 比如:一张未标注的猫图片,模型说“这应该是只猫”。
  3. 用伪标签再训练模型:将这些伪标签和未标注数据当成新数据,再次用于训练模型,帮助模型进一步提升性能。


2. 直观类比:学生自学

想象一个学生正在学习数学,他的老师只教了他基础的加法运算(标注数据)。现在有一堆新的练习题(未标注数据),但没有答案。

  • 学生试着自己做题:学生根据自己已经掌握的知识,给每道题填上答案(伪标签)。
  • 老师检查答案,改正错误:学生会根据自己的答案继续复习。虽然一开始答案可能有错,但随着复习和训练,他的水平会不断提升。

伪标签法的过程就像学生在未标注数据上自学,同时通过模型自身的改进逐渐提升学习效果。

3. 伪标签法的一个关键:高置信度伪标签

伪标签法不会直接使用所有的预测结果,而是只用模型非常自信的预测
比如:

  • 如果模型说“这张图片99%是猫”,那这个伪标签就很可能是正确的。
  • 如果模型说“这张图片可能是猫,也可能是狗(50%/50%)”,这样的伪标签可能会引入错误,反而影响训练。

所以,伪标签法通常会设置一个“置信度阈值”(比如 95%),只用高置信度的伪标签。

附录:完整代码

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import random

# 设置随机种子
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# 超参数
BATCH_SIZE = 64
LEARNING_RATE = 0.03
EPOCHS = 50
CONFIDENCE_THRESHOLD = 0.95
NUM_LABELED = 40  # 标注数据数量

# 数据增强
weak_augment = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
])

strong_augment = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
])

# CIFAR-10 数据集加载
def get_cifar10_datasets(num_labeled):
    transform = transforms.ToTensor()
    full_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
    test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)
    
    # 随机选择少量标注数据
    indices = list(range(len(full_dataset)))
    random.shuffle(indices)
    labeled_indices = indices[:num_labeled]
    unlabeled_indices = indices[num_labeled:]
    
    labeled_dataset = Subset(full_dataset, labeled_indices)
    unlabeled_dataset = Subset(full_dataset, unlabeled_indices)
    
    return labeled_dataset, unlabeled_dataset, test_dataset

# 加载数据
labeled_dataset, unlabeled_dataset, test_dataset = get_cifar10_datasets(NUM_LABELED)
labeled_loader = DataLoader(labeled_dataset, batch_size=BATCH_SIZE, shuffle=True)
unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# 模型
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = nn.ReLU()(self.conv1(x))
        x = nn.MaxPool2d(2)(x)
        x = nn.ReLU()(self.conv2(x))
        x = nn.MaxPool2d(2)(x)
        x = x.view(-1, 64 * 8 * 8)
        x = nn.ReLU()(self.fc1(x))
        return self.fc2(x)

# 损失函数
criterion = nn.CrossEntropyLoss()

# 优化器
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SimpleCNN().to(device)
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)

# FixMatch 训练函数
def train_fixmatch(epoch):
    model.train()
    total_loss = 0.0
    total_labeled_loss = 0.0
    total_unlabeled_loss = 0.0
    
    labeled_iter = iter(labeled_loader)
    unlabeled_iter = iter(unlabeled_loader)
    
    for batch_idx in range(len(unlabeled_loader)):
        try:
            labeled_data, labeled_target = next(labeled_iter)
        except StopIteration:
            labeled_iter = iter(labeled_loader)
            labeled_data, labeled_target = next(labeled_iter)
        
        try:
            unlabeled_data, _ = next(unlabeled_iter)
        except StopIteration:
            unlabeled_iter = iter(unlabeled_loader)
            unlabeled_data, _ = next(unlabeled_iter)
        
        # 移动到设备
        labeled_data, labeled_target = labeled_data.to(device), labeled_target.to(device)
        unlabeled_data = unlabeled_data.to(device)
        
        # 弱增强的伪标签
        with torch.no_grad():
            pseudo_labels = torch.softmax(model(unlabeled_data), dim=1)
            max_probs, targets = torch.max(pseudo_labels, dim=1)
            mask = max_probs.ge(CONFIDENCE_THRESHOLD).float()
        
        # 强增强后的未标注样本
        unlabeled_data_strong = torch.stack([strong_augment(img.cpu()).to(device) for img in unlabeled_data])
        
        # 计算损失
        labeled_loss = criterion(model(labeled_data), labeled_target)
        if mask.sum() > 0:
            unlabeled_loss = criterion(model(unlabeled_data_strong[mask.bool()]), targets[mask.bool()])
        else:
            unlabeled_loss = torch.tensor(0.0).to(device)
        
        loss = labeled_loss + unlabeled_loss
        
        # 优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        total_labeled_loss += labeled_loss.item()
        total_unlabeled_loss += unlabeled_loss.item()
    
    print(f"Epoch {epoch}: Total Loss: {total_loss:.4f}, Labeled Loss: {total_labeled_loss:.4f}, Unlabeled Loss: {total_unlabeled_loss:.4f}")

# 测试函数
def test():
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == target).sum().item()
            total += target.size(0)
    print(f"Test Accuracy: {100 * correct / total:.2f}%")

# 训练与测试
for epoch in range(1, EPOCHS + 1):
    train_fixmatch(epoch)
    test()