PyTorch模型训练第一步:Torchvision数据读取与处理

264 阅读6分钟

在深度学习的旅程中,模型训练是核心环节之一。如果将模型比作一辆汽车,那么其开发过程就像一套完整的生产流程,环环相扣、缺一不可。这些环节包括数据的读取、网络的设计、优化方法与损失函数的选择,以及一些辅助工具等。今天,我们将从数据读取开始,迈出模型训练的第一步。

PyTorch中的数据读取机制

PyTorch提供了一种非常方便的数据读取机制,即使用Dataset类与DataLoader类的组合,来得到数据迭代器。在训练或预测时,数据迭代器能够输出每一批次所需的数据,并且对数据进行相应的预处理与数据增强操作。

Dataset类

Dataset类是一个抽象类,用来表示数据集。我们可以通过继承Dataset类来自定义数据集的格式、大小和其他属性,以便后续供DataLoader类直接使用。无论使用自定义的数据集,还是官方封装好的数据集,其本质都是继承了Dataset类。在继承Dataset类时,至少需要重写以下方法:

  • __init__():构造函数,用于自定义数据读取方法以及进行数据预处理;
  • __len__():返回数据集大小;
  • __getitem__():索引数据集中的某一个数据。

下面是一个简单的例子,展示如何使用Dataset类定义一个Tensor类型的数据集:

Python复制

import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data_tensor, target_tensor):
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor
    
    def __len__(self):
        return self.data_tensor.size(0)
    
    def __getitem__(self, index):
        return self.data_tensor[index], self.target_tensor[index]

# 生成数据
data_tensor = torch.randn(10, 3)
target_tensor = torch.randint(2, (10,))  # 标签是0或1

# 将数据封装成Dataset
my_dataset = MyDataset(data_tensor, target_tensor)

# 查看数据集大小
print('Dataset size:', len(my_dataset))

# 使用索引调用数据
print('tensor_data[0]: ', my_dataset[0])

DataLoader类

在实际项目中,如果数据量很大,考虑到内存有限、I/O速度等问题,在训练过程中不可能一次性将所有数据全部加载到内存中,也不能只用一个进程去加载。因此,DataLoader类应运而生,它是一个迭代器,可以根据参数batch_size的值生成一个batch的数据,节省内存的同时,还可以实现多进程、数据打乱等处理。

DataLoader类的调用方式如下:

Python复制

from torch.utils.data import DataLoader

tensor_dataloader = DataLoader(dataset=my_dataset,  # 传入的数据集,必须参数
                               batch_size=2,        # 输出的batch大小
                               shuffle=True,        # 数据是否打乱
                               num_workers=0)       # 进程数,0表示只有主进程

# 以循环形式输出
for data, target in tensor_dataloader:
    print(data, target)

# 输出一个batch
print('One batch tensor data: ', iter(tensor_dataloader).next())

其中,DataLoader的参数说明如下:

  • datasetDataset类型,输入的数据集,必须参数;
  • batch_sizeint类型,每个batch有多少个样本;
  • shufflebool类型,在每个epoch开始时,是否对数据进行重新打乱;
  • num_workersint类型,加载数据的进程数,0意味着所有的数据都会被加载到主进程,默认为0。

Torchvision简介

PyTorch官方为我们提供了一些常用的图片数据集,如果需要读取这些数据集,无需自己实现,只需利用Torchvision即可。Torchvision是一个与PyTorch配合使用的Python包,它不仅提供了一些常用数据集,还提供了几个已经搭建好的经典网络模型,以及集成了一些图像数据处理方面的工具,主要供数据预处理阶段使用。简单来说,Torchvision库就是常用数据集+常见网络模型+常用图像处理方法

Torchvision的安装方式非常简单,可以使用condapip进行安装:

bash复制

conda install torchvision -c pytorch

或者

bash复制

pip install torchvision

此外,Torchvision中默认使用的图像加载器是PIL,因此为了确保Torchvision正常运行,我们还需要安装一个Python的第三方图像处理库——Pillow库。Pillow提供了广泛的文件格式支持,强大的图像处理能力,主要包括图像存储、图像显示、格式转换以及基本的图像处理操作等。安装命令如下:

bash复制

conda install pillow

或者

bash复制

pip install pillow

利用Torchvision读取数据

安装好Torchvision之后,我们可以利用它来读取数据。Torchvision库中的torchvision.datasets包提供了丰富的图像数据集的接口,常用的图像数据集,如MNIST、COCO等,都已封装好。

MNIST数据集简介

MNIST数据集是一个著名的手写数字数据集,因为上手简单,在深度学习领域,手写数字识别是一个很经典的学习入门样例。MNIST数据集是NIST数据集的一个子集,它包含了四个部分:

文件名称内容
训练集图片60,000张手写数字图片
训练集标签60,000个标签
测试集图片10,000张手写数字图片
测试集标签10,000个标签

MNIST数据集是ubyte格式存储,我们可以通过Torchvision将其解析并加载到内存中。

数据读取

torchvision.datasets支持的所有数据集都内置了相应的数据集接口。以MNIST为例,我们可以用如下方式调用:

Python复制

import torchvision

mnist_dataset = torchvision.datasets.MNIST(root='./data',  # 指定保存MNIST数据集的位置
                                           train=True,     # 是否加载训练集数据
                                           transform=None, # 图像预处理操作
                                           target_transform=None,  # 图像标签预处理操作
                                           download=True)  # 是否下载数据集

其中,torchvision.datasets.MNIST是一个类,对它进行实例化即可返回一个MNIST数据集对象。构造函数包含以下参数:

  • root:字符串,指定保存MNIST数据集的位置。如果downloadFalse,则从目标位置读取数据集;
  • download:布尔类型,是否下载数据集。如果为True,则自动从网上下载数据集,存储到root指定的位置。如果指定位置已存在数据集文件,则不会重复下载;
  • train:布尔类型,是否加载训练集数据。如果为True,则只加载训练数据;如果为False,则只加载测试数据集。注意:并不是所有的数据集都做了训练集和测试集的划分,这个参数并不一定是有效参数,具体需要参考官方接口说明文档;
  • transform:用于对图像进行预处理操作,如数据增强、归一化、旋转或缩放等;
  • target_transform:用于对图像标签进行预处理操作。

运行上述代码后,程序会从指定的网址下载MNIST数据集,然后进行解压缩等操作。如果你再次运行相同的代码,则不会再有下载的过程。

数据预览

完成数据读取后,我们得到的是一个封装好的mnist_dataset对象。如果想查看mnist_dataset中的具体内容,可以将其转化为列表:

Python复制

mnist_dataset_list = list(mnist_dataset)
print(mnist_dataset_list)

从运行结果中可以看出,转换后的数据集对象变成了一个元组列表,每个元组有两个元素,第一个元素是图像数据,第二个元素是图像的标签。图像数据是PIL.Image.Image类型,这种类型可以直接在Jupyter中显示出来。显示一条数据的代码如下:

Python复制

from IPython.display import display

display(mnist_dataset_list[0][0])
print("Image label is:", mnist_dataset_list[0][1])

运行结果如下图所示。可以看出,数据集mnist_dataset中的第一条数据是图片手写数字“7”,对应的标签是“7”。