基于 MNIST 数据集的图像分类:从基础到实践

416 阅读10分钟

基于 MNIST 数据集的图像分类:从基础到实践

一、引言

在深度学习中,图像分类是一项极具代表性的任务。传统编程在面对图像分类问题时往往显得力不从心,因为程序员难以定义详尽的规则来准确分类从未见过的图像。而深度学习则通过大量数据训练神经网络,使其能够自行学习并确定分类规则,从而在图像分类领域大放异彩。本文将以 MNIST 数据集为例,详细阐述如何构建一个简单的神经网络来实现手写体数字的图像分类,深入了解深度学习在图像分类任务中的应用流程。

二、MNIST 数据集探秘

MNIST 数据集在深度学习历史上占据着重要地位,它包含了 70,000 张手写体数字的灰度图像。其中,训练集有 60,000 张图像,测试集有 10,000 张图像。在处理这些图像数据时,我们不仅需要图像本身(通常用 X 表示),还需要对应的正确标签(通常用 Y 表示)。因此,MNIST 数据集可分为 x_train(训练图像)、y_train(训练图像标签)、x_test(测试图像)和 y_test(测试图像标签)四个部分。 在 Python 中,我们可以借助 torchvision 库轻松加载 MNIST 数据集。例如:

import torchvision
train_set = torchvision.datasets.MNIST("./data/", train=True, download=True)
valid_set = torchvision.datasets.MNIST("./data/", train=False, download=True)

通过上述代码,我们可以将 MNIST 数据集加载到内存中,并查看其相关信息。如 train_set 包含 60,000 个数据点,valid_set 包含 10,000 个数据点,且每个图像是一个二维数组,尺寸为 28x28。

为了进一步了解数据结构,我们可以取出训练集中的第一个图像和标签对进行观察:

x_0, y_0 = train_set[0]

此时,x_0 是 PIL.Image.Image 类型,y_0 是 int 类型。我们可以从标签 y_0 中确认对应的图像数字,例如 y_0 的值为 5,则表示该图像对应的手写数字是 5。

三、张量:深度学习的数据基石

张量是深度学习框架中的重要概念,它可以表示任意多维度的数组。以计算机屏幕像素为例,其通常是一个三维张量,维度分别对应宽度、高度和颜色通道。在处理图像数据时,我们需要将图像转换为张量,以便神经网络进行处理。

torchvision 提供了 ToTensor 类来实现这一转换。例如:

import torchvision.transforms.v2 as transforms
trans = transforms.Compose([transforms.ToTensor()])
x_0_tensor = trans(x_0)

转换后的张量具有许多有用的属性和方法。我们可以查看其数据类型:

x_0_tensor.dtype

其结果为 torch.float32。还可以查看张量的最小值和最大值,由于 ToTensor 类将 PIL 图像的整数范围 [0, 255] 转换为 [0.0, 1.0] 的浮点数范围,所以 x_0_tensor.min() 的结果为 tensor(0.)x_0_tensor.max() 的结果为 tensor(1.)

此外,我们还可以查看张量各维度的大小,对于图像张量,通常采用 C x H x W 的表示方式,即第一个维度是颜色通道,第二个是高度,第三个是宽度。由于 MNIST 图像是黑白的,所以只有 1 个颜色通道,高度和宽度都是 28 像素,x_0_tensor.size() 的结果为 torch.Size([1, 28, 28])。 默认情况下,张量由 CPU 处理,但我们可以将其转移到 GPU 上以加速计算。例如:

import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x_0_tensor.to(device)

这样,张量就可以在 GPU 上进行处理,提高计算效率。

四、数据准备:训练前的关键步骤

(一)数据变换

在将数据加载到神经网络之前,我们通常需要对数据进行一些变换操作。torchvision 中的 transforms 模块提供了一系列函数来实现数据变换。例如,我们之前使用的 ToTensor 就是一种数据变换,它将 PIL 图像转换为张量。

我们可以使用 Compose 函数将多个变换组合在一起。例如:

trans = transforms.Compose([transforms.ToTensor()])

然后将这个变换应用到数据集上:

train_set.transform = trans
valid_set.transform = trans

这样,在数据加载过程中,每个图像都会先经过 ToTensor 变换,然后再被用于训练或验证。

(二)数据加载器

数据加载器(DataLoader)在深度学习中起着重要作用,它定义了如何从数据集中抽取数据来训练模型。我们可以一次性将整个数据集展示给模型,但这样不仅需要大量计算资源,而且研究表明,使用较小的数据批次(batch)对模型训练更有效。

例如,我们可以设置批大小为 32:

from torch.utils.data import DataLoader
batch_size = 32
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=batch_size)

这里,train_loader 在加载训练数据时会进行洗牌操作,以增加数据的随机性,有助于模型更好地学习。而 valid_loader 则不需要洗牌,因为在验证阶段不需要学习,只是评估模型的性能。

五、构建神经网络模型

神经网络由多个层组成,每个层在数据传递过程中执行特定的数学运算。对于我们的手写体数字图像分类任务,我们构建一个简单的神经网络模型,它包含以下几个组件:

(一)Flatten 层

Flatten 层用于将多维数据展平为向量。例如,对于图像数据,其原始维度为 C x H x W,通过 Flatten 层可以将其转换为一维向量。在 PyTorch 中,可以使用 nn.Flatten() 来创建 Flatten 层。

(二)输入层

输入层是神经网络的第一层神经元,它将展平后的图像数据连接到模型的其余部分。我们使用 nn.Linear 层来创建输入层,该层是密集连接的,意味着其中的每个神经元及其权重都会影响下一层的每个神经元。在创建输入层时,需要指定输入的大小和神经元的数量。由于我们已经将图像展平,输入的大小就是通道数、垂直像素数和水平像素数的乘积,即 1 * 28 * 28。对于神经元数量,这里先选择 512 个,当然,这个数字可以根据实际情况进行调整。同时,我们还使用 nn.ReLU 作为激活函数,它可以帮助网络学习更复杂的模式。

(三)隐藏层

隐藏层是位于输入层和输出层之间的一层神经元。与输入层类似,我们使用 nn.Linear 层来创建隐藏层,并指定输入和输出的神经元数量。这里我们同样设置 512 个神经元,并使用 nn.ReLU 作为激活函数。

(四)输出层

输出层是神经网络的最后一层,它返回模型的最终预测结果。由于我们的任务是对 10 个手写体数字进行分类,所以输出层有 10 个神经元,每个神经元对应一个数字类别。这里不需要使用激活函数,因为在后续的损失计算中会有相应的处理。

将上述各层组合起来,我们可以创建一个 Sequential 模型:

import torch.nn as nn
layers = [
    nn.Flatten(),
    nn.Linear(1 * 28 * 28, 512),  # 输入层
    nn.ReLU(),
    nn.Linear(512, 512),  # 隐藏层
    nn.ReLU(),
    nn.Linear(512, 10)  # 输出层
]
model = nn.Sequential(*layers)

为了提高计算效率,我们还可以将模型转移到 GPU 上(如果可用):

model.to(device)

并且,在 PyTorch 2.0 中,我们可以使用 torch.compile 对模型进行编译,以进一步提升性能:

model = torch.compile(model)

六、模型训练与验证

(一)损失函数与优化器

在训练模型时,我们需要一个损失函数来评估模型的预测结果与真实标签之间的差异。对于多类别分类任务,我们通常使用 CrossEntropyLoss 作为损失函数。例如:

loss_function = nn.CrossEntropyLoss()

同时,我们还需要一个优化器来根据损失函数的结果调整模型的参数,使模型能够不断优化。这里我们选择 Adam 优化器:

from torch.optim import Adam
optimizer = Adam(model.parameters())

(二)计算准确率

除了损失函数,我们还通常会关注模型的准确率等指标。为了计算准确率,我们需要比较模型的预测结果与真实标签。具体来说,我们可以定义一个函数来计算每个批次的准确率:

def get_batch_accuracy(output, y, N):
    pred = output.argmax(dim=1, keepdim=True)
    correct = pred.eq(y.view_as(pred)).sum().item()
    return correct / N

其中,output 是模型的输出,y 是真实标签,N 是数据集的大小。

(三)训练函数与验证函数

接下来,我们可以定义训练函数和验证函数。训练函数主要包括以下步骤:

  1. 将模型设置为训练模式(model.train())。
  2. 遍历训练数据加载器,获取每个批次的图像数据 x 和标签数据 y,并将它们转移到 GPU(如果可用)上。
  3. 将数据输入模型,得到模型的输出 output
  4. 优化器梯度清零(optimizer.zero_grad())。
  5. 计算损失(batch_loss = loss_function(output, y))。
  6. 反向传播计算梯度(batch_loss.backward())。
  7. 更新模型参数(optimizer.step())。
  8. 累计损失和准确率。

示例代码如下:

def train():
    loss = 0
    accuracy = 0

    model.train()
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        output = model(x)
        optimizer.zero_grad()
        batch_loss = loss_function(output, y)
        batch_loss.backward()
        optimizer.step()

        loss += batch_loss.item()
        accuracy += get_batch_accuracy(output, y, len(train_loader.dataset))
    print('Train - Loss: {:.4f} Accuracy: {:.4f}'.format(loss, accuracy))

验证函数与训练函数类似,但不需要进行梯度计算和参数更新,主要步骤如下:

  1. 将模型设置为评估模式(model.eval())。
  2. 遍历验证数据加载器,获取每个批次的图像数据 x 和标签数据 y,并将它们转移到 GPU(如果可用)上。
  3. 将数据输入模型,得到模型的输出 output
  4. 计算损失并累计准确率。 示例代码如下:
def validate():
    loss = 0
    accuracy = 0

    model.eval()
    with torch.no_grad():
        for x, y in valid_loader:
            x, y = x.to(device), y.to(device)
            output = model(x)

            loss += loss_function(output, y).item()
            accuracy += get_batch_accuracy(output, y, len(valid_loader.dataset))
    print('Valid - Loss: {:.4f} Accuracy: {:.4f}'.format(loss, accuracy))

(四)训练循环

最后,我们可以通过多次循环训练和验证来观察模型的学习过程。例如,我们可以设置训练 5 个 epoch

epochs = 5
for epoch in range(epochs):
    print('Epoch: {}'.format(epoch))
    train()
    validate()

在每个 epoch 中,模型都会对整个训练数据集进行一次完整的遍历,并在验证数据集上进行评估。通过观察损失和准确率的变化,我们可以了解模型的学习效果和性能表现。

七、总结与展望

通过以上步骤,我们成功构建了一个用于 MNIST 数据集手写体数字图像分类的神经网络模型,并对其进行了训练和验证。从结果来看,模型的准确率能够迅速接近 100%,表现出了较好的性能。

MNIST 数据集不仅在深度学习历史上有着重要的地位,而且还是一个非常好的基准和调试工具。当我们尝试新的机器学习架构或算法时,可以先在 MNIST 数据集上进行测试,以评估其可行性和有效性。如果在 MNIST 数据集上都无法取得较好的学习效果,那么在更复杂的图像和数据集上可能也难以成功。