DataLoader

290 阅读2分钟

一、DataLoader是什么

DataLoader 是 PyTorch 中的一个核心组件,用于管理和加载数据集,以便在训练和评估模型时以批次进行高效的数据输入。它提供了处理各类数据输入的便利方式,支持数据的随机化和并行载入,从而提高训练效率。

二、DataLoader 的主要功能

  1. 批处理数据(Batching)

    • DataLoader 能够将数据集中的数据按批次加载,为训练过程提供一个稳定的数据流。
    • 通过设置 batch_size 参数,可以定义每次载入的数据样本数量,这对于 GPU 加速训练非常重要。
  2. 随机化数据(Shuffling)

    • 可以通过设置 shuffle=True 来随机化数据顺序,有助于打乱数据集以减少训练过程中的过拟合。
  3. 并行数据加载(Parallel Data Loading)

    • 通过 num_workers 参数,可以指定子进程的数量以并行执行数据加载。并行加载减少了训练过程中 CPU 数据管理的瓶颈。
  4. 数据迭代器(Iterator)

    • DataLoader 提供数据迭代器接口,使得数据可以像 Python 的迭代器一样被顺序读取和处理。

三、如何使用 DataLoader

下面是一个使用 DataLoader 的基本示例,展示了如何载入一个图像数据集:

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# 定义数据转换:图片转换为张量并归一化
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# 加载数据集,以 MNIST 数据集为例
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# 初始化 DataLoader
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True, num_workers=4)

# 数据迭代示例
for batch_idx, (data, targets) in enumerate(train_loader):
    # 在这里可以对每个批次的数据进行处理
    print(data.size(), targets.size())

四、重要参数

  1. dataset:

    • 实例化的 Dataset 对象,从中加载数据。
  2. batch_size:

    • 每个批次加载的数据样本数量。
  3. shuffle:

    • 是否对数据进行随机化。
  4. num_workers:

    • 用于数据加载的子进程数量。
    • 注意:在 Windows 系统上,设 num_workers 为 0 或者使用 multiprocessing 方法来处理数据并行问题。
  5. drop_last:

    • 如果数据集样本数量不能被 batch_size 整除,设置为 True 则丢弃最后那个不完整的批数据。

五、高级功能

  • 动态采样:

    • 可以通过自定义的 Sampler 来动态调整数据采样方式,例如根据类别平衡数据等。
  • 数据集分割:

    • 可以使用 SubsetRandomSampler 将数据集分割为训练集、验证集等。

六、服务器实操

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

test_data = torchvision.datasets.CIFAR10("pytorchstu/dataset",train=False,transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data,batch_size=64,shuffle=True,num_workers=0,drop_last=False)

writer = SummaryWriter("dataloader")
step = 0
for data in test_loader:
    imgs , targets = data
    print(imgs.shape)
    print(targets)
    # 使用add_images来处理整个批次,或者循环处理每个图像
    writer.add_images("test_data", imgs, step)  
    step += 1
writer.close()

截图:

image.png

image.png

ps:shuttle为True代表每个epoch训练图片的顺序是不同的,一个epoch代表遍历完所有图片。