from torch.utils.data import Dataset
是使用 PyTorch 进行自定义数据加载的核心。Dataset 类是一个抽象类,用于创建自定义数据集
1. 为什么要使用 Dataset 类?
PyTorch 的 Dataset 类提供了:
- 标准化的数据接口:统一的数据访问方式
- 与 DataLoader 配合:实现批量加载、打乱数据、多进程加载等功能
- 灵活性:可以处理任何类型的数据(图像、文本、音频等)
2. 基本用法:创建自定义数据集
你需要创建一个继承自 Dataset 的子类,并实现三个重要方法:
from torch.utils.data import Dataset
import torch
class CustomDataset(Dataset):
def __init__(self, data, labels):
"""初始化数据集
Args:
data: 输入数据
labels: 对应的标签
"""
self.data = data
self.labels = labels
def __len__(self):
"""返回数据集的大小"""
return len(self.data)
def __getitem__(self, idx):
"""根据索引获取单个样本"""
sample = self.data[idx]
label = self.labels[idx]
# 通常需要将数据转换为 tensor
sample = torch.tensor(sample, dtype=torch.float32)
label = torch.tensor(label, dtype=torch.long)
return sample, label
3. 实际示例:处理图像数据
from torch.utils.data import Dataset
from PIL import Image
import os
import torch
class ImageDataset(Dataset):
def __init__(self, image_dir, transform=None):
self.image_dir = image_dir
self.transform = transform
self.image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir)
if f.endswith(('.png', '.jpg', '.jpeg'))]
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
img_path = self.image_paths[idx]
image = Image.open(img_path).convert('RGB')
if self.transform:
image = self.transform(image)
# 假设文件名包含标签信息,如 "cat_001.jpg", "dog_005.jpg"
label = 0 if 'cat' in os.path.basename(img_path) else 1
return image, torch.tensor(label)
4. 与 DataLoader 配合使用
from torch.utils.data import DataLoader
# 创建数据集实例
dataset = CustomDataset(data, labels)
# 创建 DataLoader
dataloader = DataLoader(
dataset,
batch_size=32, # 批量大小
shuffle=True, # 是否打乱数据
num_workers=2 # 使用多进程加载数据
)
# 在训练循环中使用
for batch_data, batch_labels in dataloader:
# 训练代码...
pass
5. 常用的内置数据集
PyTorch 也提供了一些内置的数据集:
from torch.utils.data import Dataset
from torchvision import datasets
# 示例:加载 MNIST 数据集
mnist_dataset = datasets.MNIST(
root='./data',
train=True,
download=True,
transform=None
)