本文已参与「新人创作礼」活动,一起开启掘金创作之路。
一般的深度学习训练模型的搭建框架过程为,导入数据-建立模型-训练与测试/迁移学习,在这篇笔记中,我主要记录了自定义一个自己的数据集过程与迁移学习的方法。对于其中涉及的到的训练过程与测试过程在其他的笔记中已有提到。
对于之前用到的MNIST数据集与Cifar10数据集的导入,其实我们都只是利用了pytorch提供的函数,分别是torchvision.datasets.MNIST与torchvision.datasets.CIFAR10两个函数帮助我们实现了样本数据的导入。但是,当我们需要训练我们自己的数据集时,具体的datasets操作函数便需要我们来编写。
对于我们设计自定义数据集类时,具体有三个步骤:
- 继承torch.utils.data中的Dataset类
- 编写 __ len __ ()函数
- 编写 __ getitem __ ()函数
源码中的Dataset如下:
class Dataset(Generic[T_co]):
r"""An abstract class representing a :class:`Dataset`.
All datasets that represent a map from keys to data samples should subclass
it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
data sample for a given key. Subclasses could also optionally overwrite
:meth:`__len__`, which is expected to return the size of the dataset by many
:class:`~torch.utils.data.Sampler` implementations and the default options
of :class:`~torch.utils.data.DataLoader`.
.. note::
:class:`~torch.utils.data.DataLoader` by default constructs a index
sampler that yields integral indices. To make it work with a map-style
dataset with non-integral indices/keys, a custom sampler must be provided.
"""
def __getitem__(self, index) -> T_co:
raise NotImplementedError
def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
return ConcatDataset([self, other])
# No `def __len__(self)` default?
# See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
# in pytorch/torch/utils/data/sampler.py
所以一个最基本的数据集类模型的原始为
class Pokemon(dataset):
# 定义初始化函数
def __init__(self):
pass
# 定义返回样本数量函数:实现返回具体样本的数目
def __len__(self):
pass
# 定义返回具体样本的函数:实现读取一个具体的样本
def __getitem__(self, item):
pass
完善后的参考代码:
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader # 注意是Dataset而不是dataset
from torchvision import transforms
from PIL import Image
from matplotlib import pyplot as plt
from torchvision.utils import save_image
from visdom import Visdom
import os, glob, random, csv, time
class Pokemon(Dataset):
def __init__(self, root, resize, mode):
super(Pokemon, self).__init__()
self.root = root
self.resize = resize
self.image = []
self.label = []
# 创建一个字典存储类别与标签
self.name2label = {}
for name in os.listdir(root):
# 判断文件名是否为目录
if not os.path.isdir(os.path.join(root, name)):
continue
# 关键字的取值为当前的关键字个数
self.name2label[name] = len(self.name2label.keys())
# print(self.name2label.keys())
# dict_keys(['bulbasaur', 'charmander', 'mewtwo', 'pikachu'])
# print(self.name2label)
# {'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4}
# 导入图像数据
# self.load_csv('images.csv')
self.image, self.label = self.load_csv('images.csv')
# 设置train-val-test比例
# nums: 700
if mode == 'train':
self.image = self.image[:int(0.6 * len(self.image))]
self.label = self.label[:int(0.6 * len(self.label))]
# nums: 233
elif mode == 'val':
self.image = self.image[int(0.6 * len(self.image)):int(0.8 * len(self.image))]
self.label = self.label[int(0.6 * len(self.label)):int(0.8 * len(self.label))]
# nums: 234
elif mode == 'test':
self.image = self.image[int(0.8 * len(self.image)):]
self.label = self.label[int(0.8 * len(self.label)):]
else:
print("Error! 'Mode' has no such mode choice!")
def __len__(self):
return len(self.image)
def __getitem__(self, item):
# item = self.__len__()
# print(" item: ", item)
image = self.image[item]
label = self.label[item]
# print("image: ",image,"label: ",label)
# 对图像进行预处理
transform = transforms.Compose([
# 转换为RGB图像
lambda x: Image.open(x).convert('RGB'),
# 重新确定尺寸
transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),
# 旋转角度
transforms.RandomRotation(15),
# 中心裁剪
transforms.CenterCrop(self.resize),
# 转换为Tensor格式
transforms.ToTensor(),
# 使数据分布在0附近
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
image = transform(image)
# label是int形式,转换为tensor格式
label = torch.tensor(label)
return image, label
# 导入csv样本数据
def load_csv(self, csv_file):
# 当没有csv数据文件时创建文件, 将数据集信息保存在一个csv_file文件中
if not os.path.exists(os.path.join(self.root, csv_file)):
# 用来存储图像路径信息
image = []
# 现查找数据集文件中的png,jpg,jpeg格式的全部图像,路径全部保存在image中
for name in self.name2label.keys():
# glob 模块用于查找符合特定规则的文件路径名
image += glob.glob(os.path.join(self.root, name, '*.png'))
image += glob.glob(os.path.join(self.root, name, '*.jpg'))
image += glob.glob(os.path.join(self.root, name, '*.jpeg'))
# 'E:\\学习\\机器学习\\数据集\\pokemon\\bulbasaur\\00000000.png'
# print(image, len(image))
# 随机打乱图像
random.shuffle(image)
# 截取绝对路径下的图像名字
# name = next(iter(image))
# name = name.split('\\')[-2]
# print(name) # charmander
# 读写打开文件, 注意newline=''是为了不让存储的时候回车两行
with open(csv_file, mode='w', newline='') as f:
# 创建 csv 对象
writer = csv.writer(f)
for img in image:
# split: 对路径进行分割,以列表形式返回
# os.sep: 当前操作系统所使用的路径分隔符 windows->'\' linux 和 unix->'/'
# ['E:/学习/机器学习/数据集/pokemon', 'pikachu', '00000179.jpg']
# [-2]既提取了文件夹名字: 'pikachu'
name = img.split(os.sep)[-2]
label = self.name2label[name]
# 写入一行或多行数据
# 形式: E:\学习\机器学习\数据集\pokemon\charmander\00000185.png,1
writer.writerow([img, label])
# print('writen into csv file:', csv_file)
# 打开csv文件读取信息
with open(csv_file) as f:
# 创建两个list存储图像名字与标签
image = []
label = []
# #创建 csv 对象,它是一个包含所有数据的列表,每一行为名字与标签,eg: charmander,1
reader = csv.reader(f)
# 循环赋值各行内容
for row in reader:
# 导入数据, 若没有设置newline=''会报错,因为回车了两行
image.append(row[0])
label.append(int(row[1]))
# print(len(image), len(label))
if len(image) == len(label):
return image, label
else:
print("Error! len(image) != len(label) !")
def denormalize(self, x_hat):
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
# x_hat = (x-mean)/std
# x = x_hat*std = mean
# x: [c, h, w]
# mean: [3] => [3, 1, 1]
mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
# print(mean.shape, std.shape)
x = x_hat * std + mean
return x
def plot_image(img):
fig = plt.figure()
for i in range(6):
plt.subplot(2, 3, i + 1)
plt.tight_layout()
plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')
plt.xticks([])
plt.yticks([])
plt.show()
root = 'E:\学习\机器学习\数据集\pokemon'
viz = Visdom()
train_data = Pokemon(root=root, resize=64, mode='train')
# print(train_data.__len__())
# image, label = next(iter(train_data))
# print(image.shape, label)
# 利用DataLoader加载数据集
data = DataLoader(train_data, batch_size=64, shuffle=True)
# 测试
for epochodx, (image, label) in enumerate(data):
# plot_image(train_data.denormalize(image))
# time.sleep(5)
# 保存图像在本地
save_image(image, os.path.join('sample', 'image-{}.png'.format(epochodx + 1)), nrow=8, normalize=True)
# 可视化操作
# viz.images(image, nrow=8, win='batch', opts=dict(title='batch'))
viz.images(train_data.denormalize(image), nrow=8, win='batch', opts=dict(title='batch'))
time.sleep(5)
在visdom中输出的结果是有点奇怪的
但是保存在本地的图像是可以正常显示的
大图
原因:因为我们本身对数据集进行的transforms操作中,包含了Normalize的操作,数据分布变成了在0附近的一个分布,这就代表有些数值是小于0的,而visdom本身只能显示大于0以上的像素,所以会出现这种情况,现在只需要将图像进行Denormalize操作即可:
def denormalize(self, x_hat):
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
# x_hat = (x-mean)/std
# x = x_hat*std = mean
# x: [c, h, w]
# mean: [3] => [3, 1, 1]
mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
# print(mean.shape, std.shape)
x = x_hat * std + mean
return x
# 可视化操作作以下更改
# viz.images(image, nrow=8, win='batch', opts=dict(title='batch'))
viz.images(train_data.denormalize(image), nrow=8, win='batch', opts=dict(title='batch'))
可以看见,图像变得正常许多
到此,我们实现了自定义数据集的加载操作,其中image.csv文件中数据存储的格式为图像的对应路径与便签,如图所示: