1. Pytorch加载数据
① Pytorch中加载数据需要Dataset、Dataloader。
- Dataset提供一种方式去获取每个数据及其对应的label,告诉我们总共有多少个数据。
- Dataloader为后面的网络提供不同的数据形式,它将一批一批数据进行一个打包。
2. 常用数据集两种形式
① 常用的第一种数据形式,文件夹的名称是它的label。
② 常用的第二种形式,lebel为文本格式,文本名称为图片名称,文本中的内容为对应的label。
PyTorch Dataset 教程
✅ 什么是 Dataset?
Dataset 是 PyTorch 提供的一个抽象类,定义了一个数据集应该具备的两个核心方法:
__len__():告诉 PyTorch 这个数据集有多少个样本。__getitem__(index):告诉 PyTorch 如何通过索引获取一个样本。
在训练模型时,PyTorch 会自动调用这两个函数来取数据。
✅ 使用Dataset类查看服务器下载的数据
from torch.utils.data import Dataset
from PIL import Image
import os
class MyData(Dataset):
def __init__(self, root_dir, label_dir):
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir, self.label_dir)# 拼接出当前类别文件夹的完整路径
# 获取该类别文件夹下所有文件名(图片名)
self.img_path = os.listdir(self.path)# 返回一个列表,列表中是该类别文件夹下所有文件名(图片名)
def __getitem__(self, idx):
img_name = self.img_path[idx] # 获取第 idx 个图片的文件名
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name) # 完整路径
img = Image.open(img_item_path) # 用 PIL 打开图片
label = self.label_dir # 标签是文件夹名,比如 'ants' 或 'bees'
return img, label
def __len__(self):
return len(self.img_path) # 这个类别下图片总数
root_dir = "hymenoptera_data/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
# 创建两个子数据集,分别表示两类图片
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)
print("len(ants_dataset):",len(ants_dataset)) # ants 图片数
print("len(bees_dataset):",len(bees_dataset)) # bees 图片数
# 两个 Dataset 相加,得到合并后的大数据集
train_dataset = ants_dataset + bees_dataset
print('len(train_dataset):',len(train_dataset)) # 总图片数
# 取第201张图片和标签
img, label = train_dataset[200]
print("label:", label)
img.show() # 用默认图片查看器显示图片
📌 每次 dataset[i]:
就会自动调用 __getitem__(i),返回第 i 个样本。
代码解释:
1.首先引入库,-
-
Dataset是 PyTorch 用于封装数据集的基类 -
Image用于打开图片。 -
os用于路径处理和文件操作。
2.然后重写Dataset类,进行初始化,getitem方法的重写,以及len方法
ps:在 Python 中,self 是一个习惯用法,用于类中的实例方法的第一个参数。self 指的是类的实例本身。当创建一个类的对象并调用其方法时,self 允许这些方法访问和修改对象的属性和方法。
简而言之,self 提供了一种引用当前对象的方法,使每个对象可以拥有自己的状态和行为,并与同类的其他实例分开。