Pytorch数据加载

30 阅读1分钟
from torch.utils.data import Dataset

是使用 PyTorch 进行自定义数据加载的核心。Dataset 类是一个抽象类,用于创建自定义数据集

1. 为什么要使用 Dataset 类?

PyTorch 的 Dataset 类提供了:

  • 标准化的数据接口:统一的数据访问方式
  • 与 DataLoader 配合:实现批量加载、打乱数据、多进程加载等功能
  • 灵活性:可以处理任何类型的数据(图像、文本、音频等)

2. 基本用法:创建自定义数据集

你需要创建一个继承自 Dataset 的子类,并实现三个重要方法:

from torch.utils.data import Dataset
import torch

class CustomDataset(Dataset):
    def __init__(self, data, labels):
        """初始化数据集
        Args:
            data: 输入数据
            labels: 对应的标签
        """
        self.data = data
        self.labels = labels
    
    def __len__(self):
        """返回数据集的大小"""
        return len(self.data)
    
    def __getitem__(self, idx):
        """根据索引获取单个样本"""
        sample = self.data[idx]
        label = self.labels[idx]
        
        # 通常需要将数据转换为 tensor
        sample = torch.tensor(sample, dtype=torch.float32)
        label = torch.tensor(label, dtype=torch.long)
        
        return sample, label

3. 实际示例:处理图像数据

from torch.utils.data import Dataset
from PIL import Image
import os
import torch

class ImageDataset(Dataset):
    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        self.image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir) 
                           if f.endswith(('.png', '.jpg', '.jpeg'))]
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        
        if self.transform:
            image = self.transform(image)
        
        # 假设文件名包含标签信息,如 "cat_001.jpg", "dog_005.jpg"
        label = 0 if 'cat' in os.path.basename(img_path) else 1
        
        return image, torch.tensor(label)

4. 与 DataLoader 配合使用

from torch.utils.data import DataLoader

# 创建数据集实例
dataset = CustomDataset(data, labels)

# 创建 DataLoader
dataloader = DataLoader(
    dataset,
    batch_size=32,      # 批量大小
    shuffle=True,       # 是否打乱数据
    num_workers=2       # 使用多进程加载数据
)

# 在训练循环中使用
for batch_data, batch_labels in dataloader:
    # 训练代码...
    pass

5. 常用的内置数据集

PyTorch 也提供了一些内置的数据集:

from torch.utils.data import Dataset
from torchvision import datasets

# 示例:加载 MNIST 数据集
mnist_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=None
)