Python 3 + TensorFlow 机器学习：石头剪刀布图像分类最佳实践

387 次阅读 · 阅读约 1 分钟

环境

  • Anaconda Jupyter
  • TensorFlow 2.9
  • Python 3.10

导入包

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import RMSprop

数据集

数据集下载地址: www.kaggle.com/datasets/dr…

# Both splits are resized to 300x300 and labelled with class_mode='categorical'
# (one-hot labels); the class of each image is taken from its sub-directory.
IMAGE_SIZE = (300, 300)
LABEL_MODE = 'categorical'

# Training pipeline: scale pixels into [0, 1] and apply random augmentation
# (rotation, width/height shift, shear, zoom, horizontal flip).
augmentation = dict(
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
)
train_datagen = ImageDataGenerator(rescale=1. / 255, **augmentation)

# Validation pipeline: the same [0, 1] rescaling, but no augmentation.
validation_datagen = ImageDataGenerator(rescale=1. / 255)

train_dir = 'C:\\Users\\jiangyihan\\Documents\\rps\\'
validation_dir = 'C:\\Users\\jiangyihan\\Documents\\rps-test-set\\rps-test-set\\'

train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=IMAGE_SIZE,
    class_mode=LABEL_MODE,
)
validation_generator = validation_datagen.flow_from_directory(
    validation_dir,
    target_size=IMAGE_SIZE,
    class_mode=LABEL_MODE,
)

模型

新建模型

 model=tf.keras.models.Sequential()
 model.add(tf.keras.layers.Conv2D(64, (3,3), activation='relu', input_shape=(300,300,3)))
 model.add(tf.keras.layers.MaxPooling2D(2,2))

 model.add(tf.keras.layers.Conv2D(64, (3,3), activation='relu'))
 model.add(tf.keras.layers.MaxPooling2D(2,2))

 model.add(tf.keras.layers.Conv2D(128, (3,3), activation='relu'))
 model.add(tf.keras.layers.MaxPooling2D(2,2))

 model.add(tf.keras.layers.Conv2D(128, (3,3), activation='relu'))
 model.add(tf.keras.layers.MaxPooling2D(2,2))

 model.add(tf.keras.layers.Flatten())
 model.add(tf.keras.layers.Dropout(0.5))
 model.add(tf.keras.layers.Dense(512, activation='relu'))
 model.add(tf.keras.layers.Dense(3, activation='softmax'))
 ## 输出查看模型结构
 model.summary()

模型编译与优化器设置

# Configure training: RMSprop optimizer, categorical cross-entropy (the
# generators yield one-hot labels via class_mode='categorical'), and accuracy
# reported under the key 'acc'.
model.compile(
    loss='categorical_crossentropy',
    optimizer='rmsprop',
    metrics=['acc'],
)

训练

# Save a checkpoint each time the monitored training accuracy ('acc') improves.
# The file name embeds the epoch and the accuracy to two decimals,
# e.g. weights-improvement-07-0.93.h5.
filePath = 'weights-improvement-{epoch:02d}-{acc:.2f}.h5'
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=filePath,
    monitor='acc',
    mode='max',           # higher accuracy is better
    save_best_only=True,  # only write a file when 'acc' improves
    verbose=1,
)
# NOTE(review): monitoring training accuracy tends to favour over-fitted epochs;
# 'val_acc' is usually the better criterion when validation_data is supplied.

# Start training: 25 epochs, evaluating on the validation set after each epoch.
history = model.fit(
    train_generator,
    epochs=25,
    verbose=1,
    validation_data=validation_generator,
    callbacks=[checkpoint],
)

预测

import numpy as np

# Example test images, one per class:
# C:\\Users\\xxx\\Documents\\rps-test-set\\rps-test-set\\paper\\testpaper01-01.png  (paper)
# C:\\Users\\xxx\\Documents\\rps-test-set\\rps-test-set\\rock\\testrock01-25.png
# C:\\Users\\xxx\\Documents\\rps-test-set\\rps-test-set\\scissors\\testscissors01-25.png

path = "C:\\Users\\xxx\\Documents\\rps-test-set\\rps-test-set\\rock\\testrock01-25.png"
img = tf.keras.utils.load_img(path, target_size=(300, 300))
x = tf.keras.utils.img_to_array(img)
# Apply the same 1/255 rescaling the ImageDataGenerators used in training;
# raw 0-255 pixels would be far outside the range the network was trained on.
x = x / 255.0
# Add a leading batch dimension so the model receives a batch of one image.
images = np.expand_dims(x, axis=0)

classes = model.predict(images, batch_size=10)
print(classes)

# Map the highest-probability column back to its class folder name; column
# order follows train_generator.class_indices (assigned by flow_from_directory).
index_to_label = {index: label for label, index in train_generator.class_indices.items()}
print(index_to_label[int(np.argmax(classes[0]))])