本文已参与「新人创作礼」活动,一起开启掘金创作之路。
我们前面提到到了MNIST数据集和Cifar10数据集(还有Cifar100,只是分类的类别不一样而已),当然还有其他数据集,但是在实际项目开发中,或者科研中,我们往往并不是使用这些公开的的数据集,很大可能我们是选择使用自己的数据,那么这个就涉及到几个问题
-
- 我们应该怎么去预处理?
-
- 怎么加载这些数据集?
-
- 怎么使用自己的数据集去训练自己的模型?
1、认识自定义数据集的基础类 Dataset
import torch
import os, glob
import random, csv
from torch.utils.data import Dataset
class Pokemon(Dataset):
def __init__(self):
super(Pokemon, self).__init__()
def __len__(self):
pass
def __getitem__(self, idx):
pass
分析: 这个类是什么作用呢,在这里我们可以理解,我们把自己的数据集,通过这个类来处理后,就能够想MNIST,Cifar那种方便的使用,可能有人对上面的类很陌生,接下来我们一步步分析
- init 构造函数,这个大家都知道,在这个函数我们一般初始化路径,读取图片,等初始化操作
- len 这个函数是拿到数据集的长度
- getitem 这个函数是根据索引拿到一条数据
上面的pass是不让程序报错提示,不懂的可以搜下这个关键字
1、数据加载预处理
通过下面的例子可以知道,在构造函数中,我们遍历文件夹,每个类别一个文件夹,我们通过这个规律将每个类别赋值为唯一的一个标签,
class Pokemon(Dataset):
"""
root:根目录
resize : 设置图片的大小
mode : train 还是 test
"""
def __init__(self, root, resize, mode):
super(Pokemon, self).__init__()
self.root = root
self.resize = resize
self.mode = mode
self.name2lable = {} # 每个类别对应的标签
for name in sorted(os.listdir(os.path.join(root))):
# 如果不是目录就继续
if not os.path.isdir(os.path.join(root, name)):
continue
# 给每个类别赋值一个标签
self.name2lable[name] = len(self.name2lable)
print(self.name2lable)
def __len__(self):
pass
def __getitem__(self, idx):
pass
if __name__ == '__main__':
db = Pokemon("pokeman", 224, "train")
2、数据读取
数据读取我们写在__init__函数当中,这里使用csv文件来保存数据,并且在这里我们把数据分成三个部分,分别是 “train”、“val”、“test”,并且根据三个部分数据使用目的进行划分,如下所示
def __init__(self, root, resize, mode):
super(Pokemon, self).__init__()
self.root = root
self.resize = resize
self.mode = mode
self.name2lable = {} # 每个类别对应的标签
for name in sorted(os.listdir(os.path.join(root))):
# 如果不是目录就继续
if not os.path.isdir(os.path.join(root, name)):
continue
# 给每个类别赋值一个标签
self.name2lable[name] = len(self.name2lable)
self.images, self.labels = self.load_csv("image.csv")
if mode == "train":
self.images = self.images[:int(0.6 * len(self.images))]
self.labels = self.labels[:int(0.6 * len(self.images))]
elif mode == "val":
self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
else:
self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
我们可以查看保存在csv中的数据文件
可以看到,我们保存数据的格式是 图片的地址+图片的类别
3、数据处理
上面我们提到了如何使用读取数据,下面简单将数据处理的部分讲一下,如下所示,首先是将不同格式的文件读取到file_name中,也就是CSV文件当中去,然后读取数据到images, labels两个容器当中,最后将容器里面所有的数据return
def load_csv(self, file_name):
# 如果文件不存在我们就创建这个文件,否则不用重新创建
if not os.path.exists(os.path.join(self.root, file_name)):
images = []
for name in self.name2lable.keys():
images += glob.glob(os.path.join(self.root, name, "*.png"))
images += glob.glob(os.path.join(self.root, name, "*.jpg"))
images += glob.glob(os.path.join(self.root, name, "*.jpeg"))
# 1167 'pokeman\bulbasaur\00000000.png'
random.shuffle(images)
print(len(images), images)
with open(os.path.join(self.root, file_name), mode="w", newline="") as f:
writer = csv.writer(f)
for img in images:
name = img.split(os.sep)[-2]
label = self.name2lable[name]
writer.writerow([img, label])
images, labels = [], []
with open(os.path.join(self.root, file_name)) as f:
reader = csv.reader(f)
for row in reader:
img, local_label = row
local_label = int(local_label)
images.append(img)
labels.append(local_label)
assert len(images) == len(labels)
return images, labels
4、数据打包
在__len__函数中,我们返回数据的长度,在__getitem__中我们将数据的图片地址和标签都转化成tensor的格式并返回,在这里我们使用了transforms对数据进行了简单的处理,如Resize和旋转,并也转成了Tensor的格式。
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
img, lab = self.images[idx], self.labels[idx]
tf = transforms.Compose([
lambda x: Image.open(x).convert("RGB"),
transforms.Resize((int(self.resize*1.3), int(self.resize*1.3))),
transforms.RandomRotation(15), # 旋转15度
transforms.ToTensor()
])
img = tf(img)
label = torch.tensor(lab)
return img, label
简单看下旋转前后的区别,当然我们改变Resize,,其大小也是会改变的
5、数据可视化展示
上面我们已经展示了旋转后的图片,那么这个在Pytroch平台上怎么展示么,其实这个是一个可视化平台,名为Visdom,如果你还没有安装这个库的话,可以使用pip install visdom来安装,安装完成后启动的方式就是一个命令 python -m visdom.server
当看到这样的情况,就说明你已经完成了,接下来就是用Visdom,这里举一个简单的例子
if __name__ == '__main__':
import visdom
viz = visdom.Visdom()
db = Pokemon("pokeman", 64, "train")
x, y = next(iter(db))
print("sample", x.shape, y.shape, y)
viz.image(x, win="sample_x", opts=dict(title="小动物"))
我们这里使用Visdom的展示图片的函数方法,x为数据,这个使用的方法很多,例如下面的例子,画直线,曲线等等,具体的使用,读者可以专门学习一下,内容很多,这里仅仅使用就好
vis.image(np.ones((3, 100, 100))) # 绘制一幅尺寸为3 * 100 * 100的图片,图片的像素值全部为1
vis.line(X=x,Y=y,win='sin(x)',opts=dict(showlegend=True))
那么展示到这里,数据的处理打包基本完成了,封装了一个可以读取数据的类,接下来就是使用这些数据来训练就可以可以了。
时间比较晚了,改天在下一篇文章中更新后续内容