PyTorch开箱即有强大的数据加载功能。但是,强大的能力伴随着巨大的责任,这使得PyTorch的数据加载成为一个相当高级的话题。
学习高级课题的最好方法之一就是从快乐的路径开始。然后在你发现你需要的时候再增加复杂性。让我们通过一个快速入门的例子。
什么是PyTorch DataLoader?
PyTorchDataLoader
类为你提供了一个可迭代的Dataset
。它很有用,因为它可以并行加载数据,自动洗牌和批处理单个样本,所有这些都是开箱即用。这就为你设置了一个非常简单的训练循环。
PyTorch数据集
但是要创建一个DataLoader
,你必须先有一个Dataset
,这个类负责将样本实际读入内存。当你实现一个DataLoader
,Dataset
,几乎所有有趣的逻辑都会在这里进行。
有两种风格的Dataset
类,地图式和可迭代式。地图风格的Datasets
,更常见,也更直接,所以我们将专注于它们。
要创建一个地图风格的Dataset
类,你需要实现两个方法:__getitem__()
和__len__()
。__len__()
方法返回数据集中的样本总数,__getitem__()
方法获取一个索引并返回该索引处的样本。
PyTorch的Dataset
对象非常灵活--它们可以返回你想要的任何类型的tensor(s)。但是有监督的训练数据集通常应该返回一个输入张量和一个标签。为了说明问题,让我们创建一个数据集,输入张量是一个3×3的矩阵,索引在对角线上。标签将是索引。
它应该看起来像这样。
dataset[3]
# Expected result
# {'x': array([[3., 0., 0.],
# [0., 3., 0.],
# [0., 0., 3.]]),
# 'y': 3}
记住,我们所要实现的是__getitem__()
和__len__()
。
from typing import Dict, Union
import numpy as np
import torch
class ToyDataset(torch.utils.data.Dataset):
def __init__(self, size: int):
self.size = size
def __len__(self) -> int:
return self.size
def __getitem__(self, index: int) -> Dict[str, Union[int, np.ndarray]]:
return dict(
x=np.eye(3) * index,
y=index,
)
非常简单。我们可以实例化这个类并开始访问各个样本。
dataset = ToyDataset(10)
dataset[3]
# Expected result
# {'x': array([[3., 0., 0.],
# [0., 3., 0.],
# [0., 0., 3.]]),
# 'y': 3}
如果碰巧在处理图像数据,__getitem__()
可能是一个放置你的TorchVision变换的好地方。
在这一点上,一个样本是一个dict
,其中"x"
是一个矩阵,形状是(3, 3)
,"y"
是一个Python整数。但是我们想要的是成批的数据。"x"
应该是一个PyTorch张量,形状为(batch_size, 3, 3)
,"y"
应该是一个张量,形状为batch_size
。这就是DataLoader
回来的地方。
PyTorch数据加载器
为了迭代成批的样本,将你的Dataset
对象传递给一个DataLoader
。
torch.manual_seed(1234)
loader = torch.utils.data.DataLoader(
dataset,
batch_size=3,
shuffle=True,
num_workers=2,
)
for batch in loader:
print(batch["x"].shape, batch["y"])
# Expected result
# torch.Size([3, 3, 3]) tensor([2, 1, 3])
# torch.Size([3, 3, 3]) tensor([6, 7, 9])
# torch.Size([3, 3, 3]) tensor([5, 4, 8])
# torch.Size([1, 3, 3]) tensor([0])
请注意这里发生的几件事。
- NumPy数组和Python整数都被转换为PyTorch的张量。
- 尽管我们在
ToyDataset
中获取单个样本,但DataLoader
会自动为我们批量处理这些样本,并按照我们要求的批量大小进行处理。即使单个样本是在dict结构中,这也是有效的。如果你返回图元,这也是有效的。 - 样本是随机洗牌的。我们通过设置
torch.manual_seed(1234)
来保持可重复性。 - 样本是跨进程并行读取的。事实上,如果你在Jupyter笔记本中运行这段代码会失败。为了让它工作,你需要把它放在Python脚本中的
if __name__ == "__main__":
检查之下。
还有一件事,我在这个例子中没有做,但你应该注意到。如果你需要在GPU上使用你的张量(对于非微不足道的PyTorch问题,你可能会这样做),那么你应该在DataLoader
中设置pin_memory=True
。这将通过让DataLoader
在锁页内存中分配空间来加速事情。
总结
回顾一下:自定义PyTorch数据加载器的有趣部分是你实现的Dataset
类。从那里,你可以得到很多漂亮的功能来简化你的数据循环。