PyTorch创建自定义数据集

237 阅读3分钟

持续创作,加速成长!这是我参与「掘金日新计划 · 10 月更文挑战」的第6天,点击查看活动详情

前言

处理数据集的传统方法是将所有图像加载到 NumPy 数组中,当我们需要处理相对较大的数据集时,这可能会耗尽我们的计算机资源,如果计算机内存不足时,这更是不可能的,PyTorch 提供了一个强大的工具来处理大型数据集。

我们可以通过继承 PyTorch Dataset 类来创建自定义 Dataset 类,创建自定义数据集类时,需要确保定义两个基本函数:__len____getitem____len__ 函数返回数据集的长度,此函数可以被 Python len 函数调用;而 __getitem__ 函数返回指定索引处的图像。

创建自定义数据集

我们将为自定义数据集定义一个类,定义转换函数,然后使用 Dataset 类从数据集中加载图像,为数据集创建一个 PyTorch 自定义数据集。

1. 首先,加载所需的包,我们也可以使用随机种子以实现可重复性:

from PIL import Image
import torch
from torch.utils.data import Dataset
import pandas as pd
import torchvision.transforms as transforms
import os
torch.manual_seed(0)

在以上代码中,我们导入所需的包,导入 PIL 包以加载图像。使用 torch.utils.data,我们导入 Dataset 作为自定义数据集的基类;导入 pandas 包以加载 CSV 文件;此外,我们还使用 torchvision 进行数据转换。

2. 定义 histoCancerDataset 类,在 __init__ 函数中,我们接收数据集图像路径和 data_typedata_type 可以是 traintest,与所有 PyTorch 数据集一样,该类具有 __len____getitem__ 函数。__len__ 函数用于返回数据集的长度;__getitem__ 函数返回给定索引处的转换图像及其对应的标签:

class histoCancerDataset(Dataset):
    def __init__(self, data_dir, transform,data_type="train"):
        # 图像路径
        path2data=os.path.join(data_dir,data_type)
        # 获取图像列表
        filenames = os.listdir(path2data)
        # 获取完整路径
        self.full_filenames = [os.path.join(path2data, f) for f in filenames]
        # 标签文件 train_labels.csv
        csv_filename=data_type+"_labels.csv"
        path2csvLabels=os.path.join(data_dir,csv_filename)
        labels_df=pd.read_csv(path2csvLabels)
        labels_df.set_index("id", inplace=True)
        self.labels = [labels_df.loc[filename[:-4]].values[0] for filename in filenames]
        self.transform = transform

    def __len__(self):
        # 返回数据集尺寸
        return len(self.full_filenames)
        
    def __getitem__(self, idx):
        # 打开图像,应用转换,并返回标签
        image = Image.open(self.full_filenames[idx]) # PIL image
        image = self.transform(image)
        return image, self.labels[idx]

3. 接下来,定义一个将 PIL 图像转换为 PyTorch 张量的转换函数。之后,我们将对此进行扩展。此处,定义了转换函数。目前,我们只在转换函数中将 PIL 图像转换为 Pytorch 张量;

import torchvision.transforms as transforms
data_transformer = transforms.Compose([transforms.ToTensor()])

4. 然后,为 train 文件夹定义一个自定义数据集对象:

data_dir = "./data/"
histo_dataset = histoCancerDataset(data_dir, data_transformer, "train")
print(len(histo_dataset))

打印出的 histo_dataset 的长度为数据集长度。

5. 接下来,使用自定义数据集加载图像:

img,label=histo_dataset[9]
print(img.shape,torch.min(img),torch.max(img))

数据集以 (Channels, Height, Width) 格式返回图像,并且像素值被归一化为 [0.0, 1.0] 范围,这就是变换的结果。ToTensor()PIL 图像转换为 [0, 255] 范围内的 [0.0, 1.0] 范围内形状 (C x H x W)torch.FloatTensor,在 PyTorch 中处理图像时,通常使用这种格式。

相关链接

PyTorch张量操作详解

PyTorch数据加载和处理

PyTorch神经网络模型构建

PyTorch定义损失函数和优化器

PyTorch模型训练与评估