本文是入门的一些东西。如果你已经熟悉了自定义数据集加载的处理可以跳过了😂。
前言
- 我没看书进行系统学习,只是在网上找了一些demo(自定义数据集)在简单的学习和测试所了解的一些方式。
- 说自定义主要是区分下官网的demo,例如入门的手写数字mnist = tf.keras.datasets.mnist这种直接把数据集给处理好的。
- 这里说的是对图片分类的处理,分类又分为二分类(label直接0,1就可以了)和多分类(onehot),本文说下学到的方式,话不多说下面进入正题。
数据集加载
数据集加载分为两个部分:数据集、以及对应的标签
数据集分为训练集和验证集
关于标签onehot就是一维数组属于哪个分类就是标识位上标1,其余0。标签得到可以根据文件名字或者目录名进行处理
训练模型的时候数据集划分及参数说明。
model.fit(x, y, batch_size, epochs, verbose, validation_split, validation_data, validation_freq)
callback的话可以配置主动提前终止训练的一个函数EarlyStopping
刚开始的手动处理方式
用最笨的办法老老实实读图片文件一个一个处理(用java的写法了,没用脚本的高级写法)。标签我是根据文件名字来的。
def get_train_data(filePath):
train_file_name_list = os.listdir(filePath)
# 返回值
x_data = []
y_data = []
# 对每个图片单独处理
for tran_name in train_file_name_list:
tran_name_dir=os.path.join(filePath, tran_name)
if os.path.isdir(tran_name_dir):
file_name_dir=os.listdir(tran_name_dir)
for selected_train_file_name in file_name_dir:
print(selected_train_file_name)
if selected_train_file_name.endswith('.jpg'):
# 获取图片对象
captcha_image = Image.open(os.path.join(tran_name_dir, selected_train_file_name))
captcha_image = captcha_image.convert('L') # 对于简单的不用去噪,灰度反而更有利
captcha_image = captcha_image.resize((width, height)) # 把图片转换成28*28
captcha_image_np = np.array(captcha_image)
img_np = np.array(captcha_image_np)
# 把每个处理后的数据,塞进x_data,y_data
x_data.append(img_np)
y_data.append(text2vec(tran_name))
x_data = np.array(x_data)
y_data = np.array(y_data)
print("训练数据:",x_data.shape)
print("训练标签:",y_data.shape)
return x_data, y_data
onehot标签
def text2vec(text):
vector = np.zeros([char_set_len])
vector[int(text)] = 1
return vector
找VGG例子学到的from_tensor_slices
- 在学用vgg16的时候学到的。label也要自己写,主要下面的几个函数,
- glob.glob(r'\dc_2000\train*.jpg')
- tf.data.Dataset.from_tensor_slices注意要打算顺序
- label注意根据名字还是目录名字自己处理
train_image_label = [int(os.path.basename(p).split('.')[0] == 'cat') for p in train_image_path] # 文件名字,并编码,cat为1,dog为0
找GoogleNet例子学到ImageDataGenerator,labeL根据文件名来的
直接配参数处理,用到的函数ImageDataGenerator
优点的话自动图像进行了增强,例如旋转;还有label也进行了处理
def data_process_func():
# ---------------------------------- #
# 训练集进行的数据增强操作
# 1. rotation_range -> 随机旋转角度
# 2. width_shift_range -> 随机水平平移
# 3. width_shift_range -> 随机数值平移
# 4. rescale -> 数据归一化
# 5. shear_range -> 随机错切变换
# 6. zoom_range -> 随机放大
# 7. horizontal_flip -> 水平翻转
# 8. brightness_range -> 亮度变化
# 9. fill_mode -> 填充方式
# ---------------------------------- #
train_data = ImageDataGenerator(
rotation_range=50,
width_shift_range=0.1,
height_shift_range=0.1,
rescale=1 / 255.0,
shear_range=10,
zoom_range=0.1,
horizontal_flip=True,
brightness_range=(0.7, 1.3),
fill_mode='nearest'
)
# 训练器生成器
train_generator = train_data.flow_from_directory(
f'{datasets}/train',
target_size=(img_size, img_size),
batch_size=batch_size
)