Tensorflow2自定义图片分类加载的一些方式

1,596 阅读3分钟

本文是入门的一些东西。如果你已经熟悉了自定义数据集加载的处理可以跳过了😂。

前言

  1. 我没看书进行系统学习,只是在网上找了一些demo(自定义数据集)在简单的学习和测试所了解的一些方式。
  2. 说自定义主要是区分下官网的demo,例如入门的手写数字mnist = tf.keras.datasets.mnist这种直接把数据集给处理好的。
  3. 这里说的是对图片分类的处理,分类又分为二分类(label直接0,1就可以了)和多分类(onehot),本文说下学到的方式,话不多说下面进入正题。

数据集加载

数据集加载分为两个部分:数据集、以及对应的标签

数据集分为训练集和验证集

关于标签onehot就是一维数组属于哪个分类就是标识位上标1,其余0。标签得到可以根据文件名字或者目录名进行处理

训练模型的时候数据集划分及参数说明。

model.fit(x, y, batch_size, epochs, verbose, validation_split, validation_data, validation_freq)

1710563031764.png

callback的话可以配置主动提前终止训练的一个函数EarlyStopping

image.png

刚开始的手动处理方式

用最笨的办法老老实实读图片文件一个一个处理(用java的写法了,没用脚本的高级写法)。标签我是根据文件名字来的。

def get_train_data(filePath):
    train_file_name_list = os.listdir(filePath)
    # 返回值
    x_data = []
    y_data = []
    # 对每个图片单独处理
    for tran_name in train_file_name_list:
        tran_name_dir=os.path.join(filePath, tran_name)
        if os.path.isdir(tran_name_dir):
            file_name_dir=os.listdir(tran_name_dir)
            for selected_train_file_name in file_name_dir:
                print(selected_train_file_name)
                if selected_train_file_name.endswith('.jpg'):
                    # 获取图片对象
                    captcha_image = Image.open(os.path.join(tran_name_dir, selected_train_file_name))
                    captcha_image = captcha_image.convert('L')  # 对于简单的不用去噪,灰度反而更有利
                    captcha_image = captcha_image.resize((width, height)) # 把图片转换成28*28
                    captcha_image_np = np.array(captcha_image)
                    img_np = np.array(captcha_image_np)
                    # 把每个处理后的数据,塞进x_data,y_data
                    x_data.append(img_np)
                    y_data.append(text2vec(tran_name))
    x_data = np.array(x_data)
    y_data = np.array(y_data)
    print("训练数据:",x_data.shape)
    print("训练标签:",y_data.shape)
    return x_data, y_data

onehot标签

def text2vec(text):
    vector = np.zeros([char_set_len])
    vector[int(text)] = 1
    return vector

找VGG例子学到的from_tensor_slices

  1. 在学用vgg16的时候学到的。label也要自己写,主要下面的几个函数,
  2. glob.glob(r'\dc_2000\train*.jpg')
  3. tf.data.Dataset.from_tensor_slices注意要打算顺序
  4. label注意根据名字还是目录名字自己处理
train_image_label = [int(os.path.basename(p).split('.')[0] == 'cat') for p in train_image_path]  # 文件名字,并编码,cat为1,dog为0

找GoogleNet例子学到ImageDataGenerator,labeL根据文件名来的

直接配参数处理,用到的函数ImageDataGenerator

优点的话自动图像进行了增强,例如旋转;还有label也进行了处理

def data_process_func():
    # ---------------------------------- #
    #   训练集进行的数据增强操作
    #   1. rotation_range -> 随机旋转角度
    #   2. width_shift_range -> 随机水平平移
    #   3. width_shift_range -> 随机数值平移
    #   4. rescale -> 数据归一化
    #   5. shear_range -> 随机错切变换
    #   6. zoom_range -> 随机放大
    #   7. horizontal_flip -> 水平翻转
    #   8. brightness_range -> 亮度变化
    #   9. fill_mode -> 填充方式
    # ---------------------------------- #
    train_data = ImageDataGenerator(
        rotation_range=50,
        width_shift_range=0.1,
        height_shift_range=0.1,
        rescale=1 / 255.0,
        shear_range=10,
        zoom_range=0.1,
        horizontal_flip=True,
        brightness_range=(0.7, 1.3),
        fill_mode='nearest'
    )
   # 训练器生成器
   train_generator = train_data.flow_from_directory(
    f'{datasets}/train',
    target_size=(img_size, img_size),
    batch_size=batch_size
)