自定义数据分类模型基本流程(一)

616 阅读5分钟

本文已参与「新人创作礼」活动,一起开启掘金创作之路。

我们前面提到到了MNIST数据集和Cifar10数据集(还有Cifar100,只是分类的类别不一样而已),当然还有其他数据集,但是在实际项目开发中,或者科研中,我们往往并不是使用这些公开的的数据集,很大可能我们是选择使用自己的数据,那么这个就涉及到几个问题

    1. 我们应该怎么去预处理?
    1. 怎么加载这些数据集?
    1. 怎么使用自己的数据集去训练自己的模型?

1、认识自定义数据集的基础类 Dataset

import torch
import os, glob
import random, csv

from torch.utils.data import Dataset

class Pokemon(Dataset):
    def __init__(self):
        super(Pokemon, self).__init__()
        
    def __len__(self):
        pass

    def __getitem__(self, idx):
        pass

分析: 这个类是什么作用呢,在这里我们可以理解,我们把自己的数据集,通过这个类来处理后,就能够想MNIST,Cifar那种方便的使用,可能有人对上面的类很陌生,接下来我们一步步分析

  • init 构造函数,这个大家都知道,在这个函数我们一般初始化路径,读取图片,等初始化操作
  • len 这个函数是拿到数据集的长度
  • getitem 这个函数是根据索引拿到一条数据

上面的pass是不让程序报错提示,不懂的可以搜下这个关键字

1、数据加载预处理

通过下面的例子可以知道,在构造函数中,我们遍历文件夹,每个类别一个文件夹,我们通过这个规律将每个类别赋值为唯一的一个标签,

class Pokemon(Dataset):
    """
    root:根目录
    resize : 设置图片的大小
    mode : train 还是 test
    """

    def __init__(self, root, resize, mode):
        super(Pokemon, self).__init__()
        self.root = root
        self.resize = resize
        self.mode = mode

        self.name2lable = {}  # 每个类别对应的标签
        for name in sorted(os.listdir(os.path.join(root))):
            # 如果不是目录就继续
            if not os.path.isdir(os.path.join(root, name)):
                continue
            # 给每个类别赋值一个标签
            self.name2lable[name] = len(self.name2lable)
        print(self.name2lable)

    def __len__(self):
        pass

    def __getitem__(self, idx):
        pass


if __name__ == '__main__':
    db = Pokemon("pokeman", 224, "train")

image.png

2、数据读取

数据读取我们写在__init__函数当中,这里使用csv文件来保存数据,并且在这里我们把数据分成三个部分,分别是 “train”、“val”、“test”,并且根据三个部分数据使用目的进行划分,如下所示

def __init__(self, root, resize, mode):
    super(Pokemon, self).__init__()
    self.root = root
    self.resize = resize
    self.mode = mode

    self.name2lable = {}  # 每个类别对应的标签
    for name in sorted(os.listdir(os.path.join(root))):
        # 如果不是目录就继续
        if not os.path.isdir(os.path.join(root, name)):
            continue
        # 给每个类别赋值一个标签
        self.name2lable[name] = len(self.name2lable)
    self.images, self.labels = self.load_csv("image.csv")

    if mode == "train":
        self.images = self.images[:int(0.6 * len(self.images))]
        self.labels = self.labels[:int(0.6 * len(self.images))]
    elif mode == "val":
        self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
        self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
    else:
        self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
        self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]

我们可以查看保存在csv中的数据文件

image.png 可以看到,我们保存数据的格式是 图片的地址+图片的类别

3、数据处理

上面我们提到了如何使用读取数据,下面简单将数据处理的部分讲一下,如下所示,首先是将不同格式的文件读取到file_name中,也就是CSV文件当中去,然后读取数据到images, labels两个容器当中,最后将容器里面所有的数据return

def load_csv(self, file_name):
    # 如果文件不存在我们就创建这个文件,否则不用重新创建
    if not os.path.exists(os.path.join(self.root, file_name)):
        images = []
        for name in self.name2lable.keys():
            images += glob.glob(os.path.join(self.root, name, "*.png"))
            images += glob.glob(os.path.join(self.root, name, "*.jpg"))
            images += glob.glob(os.path.join(self.root, name, "*.jpeg"))
        # 1167 'pokeman\bulbasaur\00000000.png'
        random.shuffle(images)
        print(len(images), images)
        with open(os.path.join(self.root, file_name), mode="w", newline="") as f:
            writer = csv.writer(f)
            for img in images:
                name = img.split(os.sep)[-2]
                label = self.name2lable[name]
                writer.writerow([img, label])
  
    images, labels = [], []
    with open(os.path.join(self.root, file_name)) as f:
        reader = csv.reader(f)
        for row in reader:
            img, local_label = row
            local_label = int(local_label)

            images.append(img)
            labels.append(local_label)
    assert len(images) == len(labels)
    return images, labels

4、数据打包

在__len__函数中,我们返回数据的长度,在__getitem__中我们将数据的图片地址和标签都转化成tensor的格式并返回,在这里我们使用了transforms对数据进行了简单的处理,如Resize和旋转,并也转成了Tensor的格式。

def __len__(self):
    return len(self.images)

def __getitem__(self, idx):
    img, lab = self.images[idx], self.labels[idx]
    tf = transforms.Compose([
        lambda x: Image.open(x).convert("RGB"),
        transforms.Resize((int(self.resize*1.3), int(self.resize*1.3))),
        transforms.RandomRotation(15), # 旋转15度
        transforms.ToTensor()
    ])
    img = tf(img)
    label = torch.tensor(lab)
    return img, label

简单看下旋转前后的区别,当然我们改变Resize,,其大小也是会改变的

image.png

image.png

5、数据可视化展示

上面我们已经展示了旋转后的图片,那么这个在Pytroch平台上怎么展示么,其实这个是一个可视化平台,名为Visdom,如果你还没有安装这个库的话,可以使用pip install visdom来安装,安装完成后启动的方式就是一个命令 python -m visdom.server

image.png 当看到这样的情况,就说明你已经完成了,接下来就是用Visdom,这里举一个简单的例子

if __name__ == '__main__':
    import visdom

    viz = visdom.Visdom()

    db = Pokemon("pokeman", 64, "train")

    x, y = next(iter(db))
    print("sample", x.shape, y.shape, y)
    viz.image(x, win="sample_x", opts=dict(title="小动物"))

我们这里使用Visdom的展示图片的函数方法,x为数据,这个使用的方法很多,例如下面的例子,画直线,曲线等等,具体的使用,读者可以专门学习一下,内容很多,这里仅仅使用就好

vis.image(np.ones((3, 100, 100))) # 绘制一幅尺寸为3 * 100 * 100的图片,图片的像素值全部为1
vis.line(X=x,Y=y,win='sin(x)',opts=dict(showlegend=True))

那么展示到这里,数据的处理打包基本完成了,封装了一个可以读取数据的类,接下来就是使用这些数据来训练就可以可以了。

时间比较晚了,改天在下一篇文章中更新后续内容