在深度学习的旅程中,模型训练是核心环节之一。如果将模型比作一辆汽车,那么其开发过程就像一套完整的生产流程,环环相扣、缺一不可。这些环节包括数据的读取、网络的设计、优化方法与损失函数的选择,以及一些辅助工具等。今天,我们将从数据读取开始,迈出模型训练的第一步。
PyTorch中的数据读取机制
PyTorch提供了一种非常方便的数据读取机制,即使用Dataset类与DataLoader类的组合,来得到数据迭代器。在训练或预测时,数据迭代器能够输出每一批次所需的数据,并且对数据进行相应的预处理与数据增强操作。
Dataset类
Dataset类是一个抽象类,用来表示数据集。我们可以通过继承Dataset类来自定义数据集的格式、大小和其他属性,以便后续供DataLoader类直接使用。无论使用自定义的数据集,还是官方封装好的数据集,其本质都是继承了Dataset类。在继承Dataset类时,至少需要重写以下方法:
__init__():构造函数,用于自定义数据读取方法以及进行数据预处理;__len__():返回数据集大小;__getitem__():索引数据集中的某一个数据。
下面是一个简单的例子,展示如何使用Dataset类定义一个Tensor类型的数据集:
Python复制
import torch
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, data_tensor, target_tensor):
self.data_tensor = data_tensor
self.target_tensor = target_tensor
def __len__(self):
return self.data_tensor.size(0)
def __getitem__(self, index):
return self.data_tensor[index], self.target_tensor[index]
# 生成数据
data_tensor = torch.randn(10, 3)
target_tensor = torch.randint(2, (10,)) # 标签是0或1
# 将数据封装成Dataset
my_dataset = MyDataset(data_tensor, target_tensor)
# 查看数据集大小
print('Dataset size:', len(my_dataset))
# 使用索引调用数据
print('tensor_data[0]: ', my_dataset[0])
DataLoader类
在实际项目中,如果数据量很大,考虑到内存有限、I/O速度等问题,在训练过程中不可能一次性将所有数据全部加载到内存中,也不能只用一个进程去加载。因此,DataLoader类应运而生,它是一个迭代器,可以根据参数batch_size的值生成一个batch的数据,节省内存的同时,还可以实现多进程、数据打乱等处理。
DataLoader类的调用方式如下:
Python复制
from torch.utils.data import DataLoader
tensor_dataloader = DataLoader(dataset=my_dataset, # 传入的数据集,必须参数
batch_size=2, # 输出的batch大小
shuffle=True, # 数据是否打乱
num_workers=0) # 进程数,0表示只有主进程
# 以循环形式输出
for data, target in tensor_dataloader:
print(data, target)
# 输出一个batch
print('One batch tensor data: ', iter(tensor_dataloader).next())
其中,DataLoader的参数说明如下:
dataset:Dataset类型,输入的数据集,必须参数;batch_size:int类型,每个batch有多少个样本;shuffle:bool类型,在每个epoch开始时,是否对数据进行重新打乱;num_workers:int类型,加载数据的进程数,0意味着所有的数据都会被加载到主进程,默认为0。
Torchvision简介
PyTorch官方为我们提供了一些常用的图片数据集,如果需要读取这些数据集,无需自己实现,只需利用Torchvision即可。Torchvision是一个与PyTorch配合使用的Python包,它不仅提供了一些常用数据集,还提供了几个已经搭建好的经典网络模型,以及集成了一些图像数据处理方面的工具,主要供数据预处理阶段使用。简单来说,Torchvision库就是常用数据集+常见网络模型+常用图像处理方法。
Torchvision的安装方式非常简单,可以使用conda或pip进行安装:
bash复制
conda install torchvision -c pytorch
或者
bash复制
pip install torchvision
此外,Torchvision中默认使用的图像加载器是PIL,因此为了确保Torchvision正常运行,我们还需要安装一个Python的第三方图像处理库——Pillow库。Pillow提供了广泛的文件格式支持,强大的图像处理能力,主要包括图像存储、图像显示、格式转换以及基本的图像处理操作等。安装命令如下:
bash复制
conda install pillow
或者
bash复制
pip install pillow
利用Torchvision读取数据
安装好Torchvision之后,我们可以利用它来读取数据。Torchvision库中的torchvision.datasets包提供了丰富的图像数据集的接口,常用的图像数据集,如MNIST、COCO等,都已封装好。
MNIST数据集简介
MNIST数据集是一个著名的手写数字数据集,因为上手简单,在深度学习领域,手写数字识别是一个很经典的学习入门样例。MNIST数据集是NIST数据集的一个子集,它包含了四个部分:
| 文件名称 | 内容 |
|---|---|
| 训练集图片 | 60,000张手写数字图片 |
| 训练集标签 | 60,000个标签 |
| 测试集图片 | 10,000张手写数字图片 |
| 测试集标签 | 10,000个标签 |
MNIST数据集是ubyte格式存储,我们可以通过Torchvision将其解析并加载到内存中。
数据读取
torchvision.datasets支持的所有数据集都内置了相应的数据集接口。以MNIST为例,我们可以用如下方式调用:
Python复制
import torchvision
mnist_dataset = torchvision.datasets.MNIST(root='./data', # 指定保存MNIST数据集的位置
train=True, # 是否加载训练集数据
transform=None, # 图像预处理操作
target_transform=None, # 图像标签预处理操作
download=True) # 是否下载数据集
其中,torchvision.datasets.MNIST是一个类,对它进行实例化即可返回一个MNIST数据集对象。构造函数包含以下参数:
root:字符串,指定保存MNIST数据集的位置。如果download为False,则从目标位置读取数据集;download:布尔类型,是否下载数据集。如果为True,则自动从网上下载数据集,存储到root指定的位置。如果指定位置已存在数据集文件,则不会重复下载;train:布尔类型,是否加载训练集数据。如果为True,则只加载训练数据;如果为False,则只加载测试数据集。注意:并不是所有的数据集都做了训练集和测试集的划分,这个参数并不一定是有效参数,具体需要参考官方接口说明文档;transform:用于对图像进行预处理操作,如数据增强、归一化、旋转或缩放等;target_transform:用于对图像标签进行预处理操作。
运行上述代码后,程序会从指定的网址下载MNIST数据集,然后进行解压缩等操作。如果你再次运行相同的代码,则不会再有下载的过程。
数据预览
完成数据读取后,我们得到的是一个封装好的mnist_dataset对象。如果想查看mnist_dataset中的具体内容,可以将其转化为列表:
Python复制
mnist_dataset_list = list(mnist_dataset)
print(mnist_dataset_list)
从运行结果中可以看出,转换后的数据集对象变成了一个元组列表,每个元组有两个元素,第一个元素是图像数据,第二个元素是图像的标签。图像数据是PIL.Image.Image类型,这种类型可以直接在Jupyter中显示出来。显示一条数据的代码如下:
Python复制
from IPython.display import display
display(mnist_dataset_list[0][0])
print("Image label is:", mnist_dataset_list[0][1])
运行结果如下图所示。可以看出,数据集mnist_dataset中的第一条数据是图片手写数字“7”,对应的标签是“7”。