还不会写神经网络?白教你利用深度学习利器Keras基于手写字数据集实现一个简单的卷积神经网络

372 阅读2分钟

我报名参加金石计划1期挑战——瓜分10万奖池,这是我的第5篇文章,点击查看活动详情

前言

手写数字识别是指建立一个学习模型,给定一系列手写数字图片和相应的数字标签,以自动识别新的手写数字图片的相应数字。图像识别是指使用计算机处理图像,通过模型分析和理解图像,并获得写入图像文件的数字。

handwriter.png 在人工智能领域,手写数字识别被转化为自动分类问题。将10个从0到9的数字分为10类。通过模型训练,我们可以对数字图像进行分类,并间接获得数字图像上的手写数字。手写字官方地址:yann.lecun.com/exdb/mnist/

依赖安装

从依赖可以看出,我们需要安装keras,tensorflow

import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

准备数据

从写字描述可以知道,分类10个类别,图片大小为(28,28) ,图片的下载可以使用keras自带的api keras.datasets.mnist.load_data,其次我们对数据进行归一化,因为图像的像素值范围是[0,255],然后使用to_categorical进行one-hot编码

num_classes = 10
input_shape = (28, 28, 1)

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255
# Make sure images have shape (28, 28, 1)
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
print("x_train shape:", x_train.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")

y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

建立模型

可以使用keras的Sequential类创建一个模型,如下所示

model = keras.Sequential(
    [
        keras.Input(shape=input_shape),
        layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation="softmax"),
    ]
)

model.summary()

训练模型与评估

设置训练模型的数据批次大小以及训练的轮数,同时设置优化器

batch_size = 128
epochs = 15

model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)

score = model.evaluate(x_test, y_test, verbose=0)
print("Test loss:", score[0])
print("Test accuracy:", score[1])

总结

手写字是十分经典的数据集,可以在很多案例中看到这种demo,当然如果你会这种demo的代码了,你会发现其实使用其他的数据集你会变得更加顺手。