训练分类器(Training a Classifier)

3 阅读9分钟

训练分类器(Training a Classifier)

至此,你已经学会了如何定义神经网络、计算损失值以及更新网络权重。

现在你可能会思考:

数据该如何处理?

通常,当你需要处理图像、文本、音频或视频数据时,可以使用标准的 Python 包将数据加载为 NumPy 数组,然后将该数组转换为 torch.*Tensor

  • 图像数据:可使用 Pillow、OpenCV 等包
  • 音频数据:可使用 scipy、librosa 等包
  • 文本数据:可使用原生 Python/Cython 加载,或 NLTK、SpaCy 等包

针对计算机视觉领域,我们专门开发了 torchvision 包,它包含了常用数据集(如 ImageNet、CIFAR10、MNIST 等)的数据加载器,以及图像数据转换器(torchvision.datasetstorch.utils.data.DataLoader)。

这极大地提升了开发效率,避免了编写重复的样板代码。

本教程将使用 CIFAR10 数据集,它包含以下类别:airplane(飞机)、automobile(汽车)、bird(鸟类)、cat(猫)、deer(鹿)、dog(狗)、frog(青蛙)、horse(马)、ship(船)、truck(卡车)。CIFAR-10 中的图像尺寸为 3x32x32,即 3 通道彩色图像,分辨率为 32x32 像素。

cifar10
cifar10

训练图像分类器

我们将按以下步骤进行:

  1. 使用 torchvision 加载并归一化 CIFAR10 训练集和测试集
  2. 定义一个卷积神经网络
  3. 定义损失函数
  4. 在训练集上训练网络
  5. 在测试集上测试网络

1. 加载并归一化 CIFAR10

使用 torchvision 加载 CIFAR10 非常简单:

import torch
import torchvision
import torchvision.transforms as transforms

torchvision 数据集的输出是范围在 [0, 1] 之间的 PILImage 图像,我们需要将其转换为归一化到 [-1, 1] 范围的张量。

注意
如果你在 Windows 或 MacOS 上运行本教程时遇到与多进程相关的 BrokenPipeError 或 RuntimeError,请尝试将 torch.utils.data.DataLoader()num_worker 参数设置为 0。

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 4

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
数据集下载进度输出
  0%|          | 0.00/170M [00:00<?, ?B/s]
  0%|          | 459k/170M [00:00<00:38, 4.44MB/s]
  3%|▎         | 4.69M/170M [00:00<00:06, 26.3MB/s]
  5%|▌         | 9.04M/170M [00:00<00:04, 34.1MB/s]
  8%|▊         | 14.3M/170M [00:00<00:03, 41.3MB/s]
 11%|█         | 18.4M/170M [00:00<00:03, 40.4MB/s]
 13%|█▎        | 22.7M/170M [00:00<00:03, 41.0MB/s]
 16%|█▌        | 26.8M/170M [00:00<00:03, 39.1MB/s]
 18%|█▊        | 31.1M/170M [00:00<00:03, 40.3MB/s]
 21%|██        | 35.2M/170M [00:00<00:03, 38.8MB/s]
 23%|██▎       | 39.5M/170M [00:01<00:03, 40.0MB/s]
 26%|██▌       | 43.5M/170M [00:01<00:03, 39.3MB/s]
 28%|██▊       | 47.6M/170M [00:01<00:03, 39.7MB/s]
 30%|███       | 51.6M/170M [00:01<00:03, 37.7MB/s]
 33%|███▎      | 55.8M/170M [00:01<00:02, 38.7MB/s]
 35%|███▍      | 59.7M/170M [00:01<00:02, 38.5MB/s]
 37%|███▋      | 63.5M/170M [00:01<00:02, 38.2MB/s]
 40%|███▉      | 67.4M/170M [00:01<00:02, 37.2MB/s]
 42%|████▏     | 71.2M/170M [00:01<00:02, 37.4MB/s]
 44%|████▍     | 75.1M/170M [00:01<00:02, 37.9MB/s]
 46%|████▋     | 78.9M/170M [00:02<00:02, 36.9MB/s]
 48%|████▊     | 82.7M/170M [00:02<00:02, 37.0MB/s]
 52%|█████▏    | 88.3M/170M [00:02<00:01, 42.7MB/s]
 56%|█████▋    | 96.3M/170M [00:02<00:01, 53.4MB/s]
 60%|██████    | 103M/170M [00:02<00:01, 56.7MB/s]
 65%|██████▍   | 110M/170M [00:02<00:00, 62.6MB/s]
 69%|██████▊   | 117M/170M [00:02<00:00, 63.6MB/s]
 74%|███████▍  | 126M/170M [00:02<00:00, 70.6MB/s]
 78%|███████▊  | 133M/170M [00:02<00:00, 71.6MB/s]
 82%|████████▏ | 140M/170M [00:02<00:00, 71.4MB/s]
 87%|████████▋ | 148M/170M [00:03<00:00, 70.8MB/s]
 91%|█████████ | 155M/170M [00:03<00:00, 72.2MB/s]
 96%|█████████▌| 163M/170M [00:03<00:00, 75.1MB/s]
100%|██████████| 170M/170M [00:03<00:00, 50.2MB/s]

为了直观展示,我们来显示一些训练图像:

import matplotlib.pyplot as plt
import numpy as np

# 图像显示函数
def imshow(img):
    img = img / 2 + 0.5     # 反归一化
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# 获取一些随机的训练图像
dataiter = iter(trainloader)
images, labels = next(dataiter)

# 显示图像
imshow(torchvision.utils.make_grid(images))
# 打印标签
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(batch_size)))
输出结果
cifar10 tutorial
car   cat   cat   car

2. 定义卷积神经网络

复制「神经网络」章节中的网络结构,并修改为接收 3 通道图像(原网络接收 1 通道图像):

import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # 展平除批次维度外的所有维度
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()

3. 定义损失函数和优化器

使用交叉熵损失(Classification Cross-Entropy)和带动量的随机梯度下降(SGD):

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

4. 训练网络

这是最有趣的部分!我们只需遍历数据迭代器,将输入数据送入网络并执行优化:

for epoch in range(2):  # 多次遍历数据集

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # 获取输入;data 是 [inputs, labels] 列表
        inputs, labels = data

        # 清零参数梯度
        optimizer.zero_grad()

        # 前向传播 + 反向传播 + 优化
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # 打印统计信息
        running_loss += loss.item()
        if i % 2000 == 1999:    # 每 2000 个小批次打印一次
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0

print('Finished Training')
训练输出结果
[1,  2000] loss: 2.167
[1,  4000] loss: 1.817
[1,  6000] loss: 1.677
[1,  8000] loss: 1.566
[1, 10000] loss: 1.520
[1, 12000] loss: 1.451
[2,  2000] loss: 1.391
[2,  4000] loss: 1.375
[2,  6000] loss: 1.336
[2,  8000] loss: 1.316
[2, 10000] loss: 1.286
[2, 12000] loss: 1.291
Finished Training

快速保存训练好的模型:

PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)

有关保存 PyTorch 模型的更多细节,请参考官方文档

5. 在测试集上测试网络

我们已经在训练集上遍历了 2 次,但需要验证网络是否真的学到了东西。

我们将通过预测神经网络输出的类别标签,并与真实标签对比来验证。如果预测正确,就将该样本加入正确预测列表。

首先,显示测试集中的一张图像:

dataiter = iter(testloader)
images, labels = next(dataiter)

# 打印图像
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))
输出结果
cifar10 tutorial
GroundTruth:  cat   ship  ship  plane

然后加载保存的模型(注:此处保存和重新加载模型并非必需,仅为演示方法):

net = Net()
net.load_state_dict(torch.load(PATH, weights_only=True))
加载输出结果
<All keys matched successfully>

查看神经网络对这些示例的预测结果:

outputs = net(images)

输出是 10 个类别的能量值,某个类别的能量值越高,网络认为该图像属于该类别的概率越大。我们取能量值最高的类别索引:

_, predicted = torch.max(outputs, 1)

print('Predicted: ', ' '.join(f'{classes[predicted[j]]:5s}'
                              for j in range(4)))
输出结果
Predicted:  dog   car   ship  plane

结果看起来不错!

接下来查看网络在整个测试集上的表现:

correct = 0
total = 0
# 不训练时,无需计算输出的梯度
with torch.no_grad():
    for data in testloader:
        images, labels = data
        # 通过网络运行图像以获取输出
        outputs = net(images)
        # 取能量值最高的类别作为预测结果
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'网络在 10000 张测试图像上的准确率: {100 * correct // total} %')
输出结果
网络在 10000 张测试图像上的准确率: 54 %

这比随机猜测(10% 准确率,从 10 个类别中随机选择)要好得多,说明网络确实学到了东西。

查看各类别的准确率表现:

# 准备统计每个类别的预测结果
correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}

# 无需计算梯度
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predictions = torch.max(outputs, 1)
        # 收集每个类别的正确预测数
        for label, prediction in zip(labels, predictions):
            if label == prediction:
                correct_pred[classes[label]] += 1
            total_pred[classes[label]] += 1


# 打印每个类别的准确率
for classname, correct_count in correct_pred.items():
    accuracy = 100 * float(correct_count) / total_pred[classname]
    print(f'类别: {classname:5s} 的准确率: {accuracy:.1f} %')
输出结果
类别: plane 的准确率: 53.1 %
类别: car   的准确率: 72.6 %
类别: bird  的准确率: 19.5 %
类别: cat   的准确率: 23.1 %
类别: deer  的准确率: 55.7 %
类别: dog   的准确率: 68.4 %
类别: frog  的准确率: 50.0 %
类别: horse 的准确率: 69.2 %
类别: ship  的准确率: 67.9 %
类别: truck 的准确率: 65.3 %

在 GPU 上训练

与将张量移至 GPU 类似,我们也可以将神经网络移至 GPU。

首先定义设备为第一个可用的 CUDA 设备(如果有 CUDA):

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# 如果在 CUDA 机器上运行,此处应输出 CUDA 设备:
print(device)
输出结果
cuda:0

本节后续内容假设 device 是 CUDA 设备。

使用以下方法递归遍历所有模块,将参数和缓冲区转换为 CUDA 张量:

net.to(device)

注意,每次迭代都需要将输入和目标也发送到 GPU:

inputs, labels = data[0].to(device), data[1].to(device)

为什么没有看到相比 CPU 的巨大速度提升?因为你的网络规模太小了。

练习:尝试增加网络的宽度(第一个 nn.Conv2d 的第二个参数,以及第二个 nn.Conv2d 的第一个参数——这两个值需要相同),看看能获得多大的速度提升。

已达成目标

  1. 从宏观层面理解 PyTorch 的张量库和神经网络
  2. 训练一个小型神经网络完成图像分类任务

在多 GPU 上训练

如果想要利用所有 GPU 获得更大的速度提升,请参考可选:数据并行

下一步可以学习什么?

总结

  1. PyTorch 提供 torchvision 库简化视觉数据集加载,支持自动下载、预处理和批量加载,可轻松处理 CIFAR10 等标准数据集;
  2. 图像分类器训练核心流程:加载数据→定义卷积网络→设置损失函数/优化器→训练网络→测试评估,需注意将数据和模型同步到 GPU 以加速训练;
  3. 简单卷积网络在 CIFAR10 上可达到约 54% 的整体准确率,不同类别表现差异较大,可通过增加网络宽度/深度进一步提升性能。