利用keras.preprocessing来读取文件

212 阅读1分钟

import tensorflow as tf import tensorflow.keras as keras import matplotlib.pyplot as plt import os import PIL import pathlib def datapreprocess(data_dir,batch_size,img_height,img_width): #读取文件,查看相关的属性 data_dir=pathlib.Path(data_dir) image_count=len(list(data_dir.glob('/'))) #读取文件夹下面所有的文件 print('一共有有{}张图像'.format(image_count))

-----------------------------------------------------------------------

# 查看文件有多少类,并输出第一个类文件中的第一张图像
# 且查看第一张图像的size和通道数
name_list=[]
for item in data_dir.iterdir():
    name_list.append(item.name)
print('有如下类别{}共{}类'.format(name_list,len(name_list)))
ph1_dir=os.path.join(data_dir,name_list[0])
ph1_dir=pathlib.Path(ph1_dir)
ph1=list(ph1_dir.glob('*/'))
p=PIL.Image.open(ph1[0])
plt.imshow(p)
plt.colorbar()
plt.show()
#查看图像的shape
p=tf.io.read_file(str(list(data_dir.glob('*/*'))[0]))
p_tensor=tf.image.decode_image(p)
print(p_tensor.shape)

-----------------------------------------------------------------------

#利用keras.preprocessing来创建数据集
# 因为默认的color_mode参数是‘rgb’,所以如果是灰度图则需要将color_mode改为‘grayscle’
train_ds=keras.preprocessing.image_dataset_from_directory(
    data_dir,validation_split=0.2,subset='training',seed=111,
    image_size=[img_height,img_width],color_mode='grayscale',batch_size=batch_size)
print(train_ds)
val_ds=tf.keras.preprocessing.image_dataset_from_directory(
    data_dir, validation_split=0.2, subset='validation', seed=111,
    image_size=[img_height, img_width],color_mode='grayscale',batch_size=batch_size
)
print(val_ds)

datapreprocess('data-class',32,224,224)