基于TensorFlow的mnist手写体识别

118 阅读2分钟

本文已参与「新人创作礼」活动,一起开启掘金创作之路。

0 引言

昨天我做(学者网上的教程)了一个线性回归的模型,可以参考这篇博客,用的TensorFlow框架,今天我继续学习,用TensorFlow框架对mnist数据集进行手写体识别。

1 准备数据

这里用到的是TensorFlow里面的placeholder占位符,类似constant,只不过先定义但是不赋值,用起来的时候再赋值。

  • mnist数据集
    from tensorflow.examples.tutorials.mnist import input_data mnist_data = input_data.read_data_sets("./mnist_data", one_hot=True)
  • y_true
    y_true = tf.placeholder(dtype=tf.float32, shape=[None, 10], name="y_true")
  • label
    X = tf.placeholder(dtype=tf.float32, shape=[None, 784], name="X")

2 构造模型

  • 参数
    weights = tf.Variable(initial_value=tf.random_normal(shape=[784,10]),name="weight") bias = tf.Variable(initial_value=tf.random_normal([10]),name="bias")
  • 模型
    y_predict = tf.matmul(X,weights) + bias

3 构造损失函数

  • loss function loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true,logits=y_predict))

4 优化损失

  • 梯度下降法 optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(loss)

5 计算准确率

  • 预测值和真实值进行比较 equal_list = tf.equal(tf.argmax(y_true, 1), tf.argmax(y_predict, 1))
  • 求平均 accuracy = tf.reduce_mean(tf.cast(equal_list, tf.float32))

6 初始化变量

  • 初始化 init = tf.global_variables_initializer()
  • 在会话中运行 sess.run(init)

7 开启会话

  • 拉取mnist训练集 image, label = mnist_data.train.next_batch(batch_size)
  • 开始训练 _optimizer, loss_value, accuracy_value = sess.run([optimizer, loss, accuracy], feed_dict={X: image, y_true: label})

8 运行效果

running result

9 源代码

import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np

def fucc_connection(learning_rate=0.01,error_value=np.exp(-5),batch_size=100):
    '''
    这是一个通过全连接网络实现的手写字体识别demo
    :return:
    '''
    ###############################
    # 分析过程
    # 这里的过程和线性回归的过程差不多,无非就是模型和函数有点不一样
    # 1 准备数据
    # 公式:
    # X([None,784]) * weight([784,10]) + bias[10] = y_predict([10])
    # with tf.variable_scope("prepare_data"):
    mnist_data = input_data.read_data_sets("./mnist_data", one_hot=True)
    X = tf.placeholder(dtype=tf.float32, shape=[None, 784], name="X")
    y_true = tf.placeholder(dtype=tf.float32, shape=[None, 10], name="y_true")

    # 2 构造模型
    with tf.variable_scope("create_model"):
        # 参数
        weights = tf.Variable(initial_value=tf.random_normal(shape=[784,10]),name="weight")
        bias = tf.Variable(initial_value=tf.random_normal([10]),name="bias")
        # 模型
        y_predict = tf.matmul(X,weights) + bias

    # 3 构造损失函数
    # 这里使用的损失不再是均方差,这里用的softmax和交叉熵
    with tf.variable_scope("loss_function"):
        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true,logits=y_predict))
        # 什么是logits:https://blog.csdn.net/nbxzkok/article/details/84838984

    # 4 优化损失
    # 还是使用梯度下降方法进行优化损失
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(loss)

    # 5 计算准确率
    equal_list = tf.equal(tf.argmax(y_true, 1), tf.argmax(y_predict, 1))
    accuracy = tf.reduce_mean(tf.cast(equal_list, tf.float32))

    ###############
    # 初始化变量
    init = tf.global_variables_initializer()


    # 2) 收集变量
    tf.summary.scalar("loss",loss)
    tf.summary.histogram("weights", weights)
    tf.summary.histogram("bias", bias)

    # 3) 合并变量
    merged = tf.summary.merge_all()

    ## (1)定义一个模型的保存器
    saver = tf.train.Saver()

    # 开启回话
    with tf.Session() as sess:
        sess.run(init)

        # 1) 创建事件
        file_writer = tf.summary.FileWriter(graph=sess.graph, logdir="./mnist_graph")


        image, label = mnist_data.train.next_batch(batch_size)
        print("loss:{}".format(sess.run(loss, feed_dict={X: image, y_true: label})))

        # 开始训练
        # count = 0
        # while(loss.eval() > np.exp(-9)):
        #     count += 1
        #     _optimizer, loss = sess.run([optimizer,loss], feed_dict={X:image, y_true:label})
        #     print("NO.{count},loss:{value}".format(value=loss.eval(), count=count))
        error_value = error_value
        for count in range(100000):
            _optimizer, loss_value, accuracy_value = sess.run([optimizer, loss, accuracy], feed_dict={X: image, y_true: label})
            if accuracy_value < 0.9:
                print("NO.{count},loss:{value}, accuracy:{accuracy}".format(value=loss_value, count=count, accuracy=accuracy_value))
                # 4) 运行合并后的变量
                # summary = sess.run(merged)
                # file_writer.add_summary(summary,count)
                #
                # # (2) 开始保存模型
                # if count % 10 ==0 :
                #     saver.save("./temp/mnist/mnist.ckpt")

            else:
                print("-"*10 + "result" + "-"*10)
                print("learning_rate:{}".format(learning_rate))
                print("minial error value:{}".format(error_value))
                print("accuracy is:{}".format(accuracy_value))
                print("total running times:{} times".format(count))
                break

def main(argv):
    s_time = time.time()
    fucc_connection(learning_rate=0.1,batch_size=1000)
    e_time = time.time()
    print("running time:{} s".format(round((e_time - s_time),2)))


if __name__ == '__main__':
    tf.app.run()


写在最后

欢迎大家关注鄙人的公众号【麦田里的守望者zhg】,让我们一起成长,谢谢。 微信公众号个人博客