PyTorch数据加载和处理

383 阅读3分钟

持续创作,加速成长!这是我参与「掘金日新计划 · 10 月更文挑战」的第2天,点击查看活动详情

前言

深度神经网络训练的大多数情况下,我们需要使用三组数据:训练、验证和测试。我们使用训练数据集来训练模型;验证数据集用于监测模型在训练期间的性能;使用测试数据集对模型进行最终评估。测试数据集的真实值通常对模型而言通常是不可知的。我们需要至少一个训练数据集和一个验证数据集才能构建和训练模型。因此当我们只有一个数据集时,我们可以将数据集分成两组或三组。每个数据集都由输入值和目标值(真实值或标签)组成,通常使用 xX 表示输入,用 yY 表示目标,通过添加后缀 trainvaltest 来区分训练、验证和测试数据集。

在本节中,我们将学习 PyTorch 数据加载和处理数据工具,以使用 PyTorch 库来处理数据集。 

加载数据集

PyTorch torchvision 库提供了多个常用的数据集,我们首先介绍如何从 torchvision 加载 MNIST 数据集。

1. 首先,加载 MNIST 训练数据集:

from torchvision import datasets
path_data="./data"
train_data=datasets.MNIST(path_data, train=True, download=True)

2. 然后,提取训练数据集的输入数据和目标标签:

x_train, y_train=train_data.data,train_data.targets
print(x_train.shape)
print(y_train.shape)

3. 接下来,加载 MNIST 测试数据集:

val_data=datasets.MNIST(path2data, train=False, download=True)

4. 然后,提取测试数据集的输入数据和目标标签:

x_val,y_val=val_data.data, val_data.targets
print(x_val.shape)
print(y_val.shape)

5. 之后,为张量添加一个新维度:

if len(x_train.shape)==3:
    x_train=x_train.unsqueeze(1)
print(x_train.shape)
if len(x_val.shape)==3:
    x_val=x_val.unsqueeze(1)
print(x_val.shape)

6. 接下来我们展示一些示例图像,导入所需的包:

from torchvision import utils
import matplotlib.pyplot as plt
import numpy as np

7. 然后,定义一个辅助函数来将张量显示为图像:

def show(img):
    npimg = img.numpy()
    npimg_tr=np.transpose(npimg, (1,2,0))
    plt.imshow(npimg_tr,interpolation='nearest')
  1. 接下来,创建一个图像网格并显示这些图像:
x_grid=utils.make_grid(x_train[:40], nrow=8, padding=2)
print(x_grid.shape)
show(x_grid)
plt.show()

结果如下图所示:

Figure_1.png  

数据转换

图像转换(也称为数据增强)是一种用于提高模型性能的有效技术,torchvision 通过 transforms 类提供常见的图像变换。

1. 定义一个 transforms 类,以便在 MNIST 数据集上应用图像变换:

from torchvision import transforms
train_data=datasets.MNIST(path_data, train=True, download=True)
data_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(p=1),
        transforms.RandomVerticalFlip(p=1),
        transforms.ToTensor(),
])

2. 对来自 MNIST 数据集的图像应用转换:

img = train_data[0][0]
img_tr=data_transform(img)
img_tr_np=img_tr.numpy()
plt.subplot(1,2,1)
plt.imshow(img,cmap="gray")
plt.title("original")
plt.subplot(1,2,2)
plt.imshow(img_tr_np[0],cmap="gray");
plt.title("transformed")
plt.show()

结果如下图所示:

Figure_2.png

3.transformer 函数传递给数据集类:

data_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(1),
        transforms.RandomVerticalFlip(1),
        transforms.ToTensor(),
])
train_data=datasets.MNIST(path2data, train=True, download=True, transform=data_transform )

将张量封装到到数据集中

如果数据为张量形式,我们可以使用 TensorDataset 类将它们封装为 PyTorch 数据集,使训练期间更容易迭代数据。

1. 通过封装 x_trainy_train 创建一个 PyTorch 数据集:

from torch.utils.data import TensorDataset
train_ds = TensorDataset(x_train, y_train)
val_ds = TensorDataset(x_val, y_val)

创建数据加载器

为了在训练期间迭代数据,我们可以使用 DataLoader 类创建一个数据加载器。

1. 为训练和验证数据集创建数据加载器:

from torch.utils.data import DataLoader
train_dl = DataLoader(train_ds, batch_size=8)
val_dl = DataLoader(val_ds, batch_size=8)

相关链接

PyTorch张量操作详解