PyTorch数据加载器的快速入门指南

63 阅读3分钟

PyTorch开箱即有强大的数据加载功能。但是,强大的能力伴随着巨大的责任,这使得PyTorch的数据加载成为一个相当高级的话题。

学习高级课题的最好方法之一就是从快乐的路径开始。然后在你发现你需要的时候再增加复杂性。让我们通过一个快速入门的例子。

什么是PyTorch DataLoader?

PyTorchDataLoader 类为你提供了一个可迭代的Dataset 。它很有用,因为它可以并行加载数据,自动洗牌和批处理单个样本,所有这些都是开箱即用。这就为你设置了一个非常简单的训练循环。

PyTorch数据集

但是要创建一个DataLoader ,你必须先有一个Dataset ,这个类负责将样本实际读入内存。当你实现一个DataLoaderDataset ,几乎所有有趣的逻辑都会在这里进行。

有两种风格的Dataset 类,地图式和可迭代式。地图风格的Datasets ,更常见,也更直接,所以我们将专注于它们。

要创建一个地图风格的Dataset 类,你需要实现两个方法:__getitem__()__len__()__len__() 方法返回数据集中的样本总数,__getitem__() 方法获取一个索引并返回该索引处的样本。

PyTorch的Dataset 对象非常灵活--它们可以返回你想要的任何类型的tensor(s)。但是有监督的训练数据集通常应该返回一个输入张量和一个标签。为了说明问题,让我们创建一个数据集,输入张量是一个3×3的矩阵,索引在对角线上。标签将是索引。

它应该看起来像这样。

dataset[3]

# Expected result
# {'x': array([[3., 0., 0.],
#         [0., 3., 0.],
#         [0., 0., 3.]]),
#  'y': 3}

记住,我们所要实现的是__getitem__()__len__()

from typing import Dict, Union

import numpy as np
import torch

class ToyDataset(torch.utils.data.Dataset):
    def __init__(self, size: int):
        self.size = size

    def __len__(self) -> int:
        return self.size

    def __getitem__(self, index: int) -> Dict[str, Union[int, np.ndarray]]:
        return dict(
            x=np.eye(3) * index,
            y=index,
        )

非常简单。我们可以实例化这个类并开始访问各个样本。

dataset = ToyDataset(10)
dataset[3]

# Expected result
# {'x': array([[3., 0., 0.],
#         [0., 3., 0.],
#         [0., 0., 3.]]),
#  'y': 3}

如果碰巧在处理图像数据,__getitem__() 可能是一个放置你的TorchVision变换的好地方。

在这一点上,一个样本是一个dict ,其中"x" 是一个矩阵,形状是(3, 3)"y" 是一个Python整数。但是我们想要的是成批的数据。"x" 应该是一个PyTorch张量,形状为(batch_size, 3, 3)"y" 应该是一个张量,形状为batch_size 。这就是DataLoader 回来的地方。

PyTorch数据加载器

为了迭代成批的样本,将你的Dataset 对象传递给一个DataLoader

torch.manual_seed(1234)

loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=3,
    shuffle=True,
    num_workers=2,
)
for batch in loader:
    print(batch["x"].shape, batch["y"])

# Expected result
# torch.Size([3, 3, 3]) tensor([2, 1, 3])
# torch.Size([3, 3, 3]) tensor([6, 7, 9])
# torch.Size([3, 3, 3]) tensor([5, 4, 8])
# torch.Size([1, 3, 3]) tensor([0])

请注意这里发生的几件事。

  • NumPy数组和Python整数都被转换为PyTorch的张量。
  • 尽管我们在ToyDataset 中获取单个样本,但DataLoader 会自动为我们批量处理这些样本,并按照我们要求的批量大小进行处理。即使单个样本是在dict结构中,这也是有效的。如果你返回图元,这也是有效的。
  • 样本是随机洗牌的。我们通过设置torch.manual_seed(1234) 来保持可重复性。
  • 样本是跨进程并行读取的。事实上,如果你在Jupyter笔记本中运行这段代码会失败。为了让它工作,你需要把它放在Python脚本中的if __name__ == "__main__": 检查之下。

还有一件事,我在这个例子中没有做,但你应该注意到。如果你需要在GPU上使用你的张量(对于非微不足道的PyTorch问题,你可能会这样做),那么你应该在DataLoader 中设置pin_memory=True 。这将通过让DataLoader 在锁页内存中分配空间来加速事情。

总结

回顾一下:自定义PyTorch数据加载器的有趣部分是你实现的Dataset 类。从那里,你可以得到很多漂亮的功能来简化你的数据循环。