环境
- Anaconda Jupyter
- TensorFlow 2.9
- Python 3.10
导入包
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import RMSprop
数据集
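数据集下载地址: www.kaggle.com/datasets/dr…（原文链接已被截断，请补全后再访问）。解压后，训练集位于 rps/ 目录，测试集位于 rps-test-set/rps-test-set/ 目录，每个目录下按类别分为 paper、rock、scissors 三个子文件夹。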
# Training pipeline: rescale pixels to [0, 1] and apply random
# augmentation (rotation, shifts, shear, zoom, horizontal flip).
train_datagen = ImageDataGenerator(
    rescale=1. / 255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
)
# Validation pipeline: rescale only, so evaluation sees unaugmented images.
validation_datagen = ImageDataGenerator(rescale=1. / 255)


def _flow_images(datagen, directory):
    """Stream 300x300 images with one-hot (categorical) labels from *directory*.

    Keras infers one class per sub-folder of *directory*.
    """
    return datagen.flow_from_directory(
        directory,
        target_size=(300, 300),
        class_mode='categorical',
    )


train_generator = _flow_images(
    train_datagen,
    'C:\\Users\\jiangyihan\\Documents\\rps\\',
)
validation_generator = _flow_images(
    validation_datagen,
    'C:\\Users\\jiangyihan\\Documents\\rps-test-set\\rps-test-set\\',
)
模型
新建模型
# Four Conv2D (3x3, ReLU) + 2x2 max-pooling stages over 300x300x3 inputs,
# then flatten, dropout(0.5), a 512-unit ReLU layer, and a 3-way softmax
# output (one probability per class).
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu', input_shape=(300, 300, 3)),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dense(3, activation='softmax'),
])
# Print the model structure (layers, output shapes, parameter counts).
model.summary()
模型编译（指定优化器、损失函数与评估指标）
# RMSprop with default settings (equivalent to the 'rmsprop' string);
# categorical cross-entropy matches the one-hot labels and softmax output.
model.compile(
    optimizer=RMSprop(),
    loss='categorical_crossentropy',
    metrics=['acc'],
)
训练
# Checkpointing: save the model whenever the monitored accuracy reaches a
# new best. The file name embeds the epoch and accuracy, e.g.
# weights-improvement-07-0.93.h5 (".2f" = two decimals; the previous
# "2f" only set a field width of 2).
filePath = 'weights-improvement-{epoch:02d}-{acc:.2f}.h5'
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=filePath,
    # NOTE(review): this tracks *training* accuracy; 'val_acc' would pick
    # the best model on held-out data instead — confirm which is intended.
    monitor='acc',
    verbose=1,
    save_best_only=True,
    mode='max',
)  # the misspelled, redundant `varbose=1` kwarg was dropped
# Train for 25 epochs, evaluating on the validation generator each epoch.
history = model.fit(
    train_generator,
    epochs=25,
    verbose=1,
    validation_data=validation_generator,
    callbacks=[checkpoint],
)
预测
import numpy as np

# Sample test images, one per class:
# C:\\Users\\xxx\\Documents\\rps-test-set\\rps-test-set\\paper\\testpaper01-01.png      (paper)
# C:\\Users\\xxx\\Documents\\rps-test-set\\rps-test-set\\rock\\testrock01-25.png        (rock)
# C:\\Users\\xxx\\Documents\\rps-test-set\\rps-test-set\\scissors\\testscissors01-25.png (scissors)
path = "C:\\Users\\xxx\\Documents\\rps-test-set\\rps-test-set\\rock\\testrock01-25.png"
img = tf.keras.utils.load_img(path, target_size=(300, 300))
x = tf.keras.utils.img_to_array(img)
# Apply the same 1/255 rescaling the generators used during training.
# Without it the model receives pixel values 255x larger than anything it
# was trained on.
x = x * (1. / 255)
# Add a leading batch axis; vstack lets several such arrays be batched together.
x = np.expand_dims(x, axis=0)
images = np.vstack([x])
classes = model.predict(images, batch_size=10)
print(classes)
# Map the highest-probability output index back to its class folder name.
index_to_label = {index: label for label, index in train_generator.class_indices.items()}
print(index_to_label[int(np.argmax(classes[0]))])