一、DataLoader是什么
DataLoader 是 PyTorch 中的一个核心组件,用于管理和加载数据集,以便在训练和评估模型时以批次进行高效的数据输入。它提供了处理各类数据输入的便利方式,支持数据的随机化和并行载入,从而提高训练效率。
二、DataLoader 的主要功能
-
批处理数据(Batching)
DataLoader能够将数据集中的数据按批次加载,为训练过程提供一个稳定的数据流。- 通过设置
batch_size参数,可以定义每次载入的数据样本数量,这对于 GPU 加速训练非常重要。
-
随机化数据(Shuffling)
- 可以通过设置
shuffle=True来随机化数据顺序,有助于打乱数据集以减少训练过程中的过拟合。
- 可以通过设置
-
并行数据加载(Parallel Data Loading)
- 通过
num_workers参数,可以指定子进程的数量以并行执行数据加载。并行加载减少了训练过程中 CPU 数据管理的瓶颈。
- 通过
-
数据迭代器(Iterator)
DataLoader提供数据迭代器接口,使得数据可以像 Python 的迭代器一样被顺序读取和处理。
三、如何使用 DataLoader
下面是一个使用 DataLoader 的基本示例,展示了如何载入一个图像数据集:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# 定义数据转换:图片转换为张量并归一化
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
# 加载数据集,以 MNIST 数据集为例
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# 初始化 DataLoader
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True, num_workers=4)
# 数据迭代示例
for batch_idx, (data, targets) in enumerate(train_loader):
# 在这里可以对每个批次的数据进行处理
print(data.size(), targets.size())
四、重要参数
-
dataset:
- 实例化的 Dataset 对象,从中加载数据。
-
batch_size:
- 每个批次加载的数据样本数量。
-
shuffle:
- 是否对数据进行随机化。
-
num_workers:
- 用于数据加载的子进程数量。
- 注意:在 Windows 系统上,设
num_workers为 0 或者使用multiprocessing方法来处理数据并行问题。
-
drop_last:
- 如果数据集样本数量不能被
batch_size整除,设置为 True 则丢弃最后那个不完整的批数据。
- 如果数据集样本数量不能被
五、高级功能
-
动态采样:
- 可以通过自定义的
Sampler来动态调整数据采样方式,例如根据类别平衡数据等。
- 可以通过自定义的
-
数据集分割:
- 可以使用
SubsetRandomSampler将数据集分割为训练集、验证集等。
- 可以使用
六、服务器实操
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
test_data = torchvision.datasets.CIFAR10("pytorchstu/dataset",train=False,transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data,batch_size=64,shuffle=True,num_workers=0,drop_last=False)
writer = SummaryWriter("dataloader")
step = 0
for data in test_loader:
imgs , targets = data
print(imgs.shape)
print(targets)
# 使用add_images来处理整个批次,或者循环处理每个图像
writer.add_images("test_data", imgs, step)
step += 1
writer.close()
截图:
ps:shuttle为True代表每个epoch训练图片的顺序是不同的,一个epoch代表遍历完所有图片。