本文已参与「新人创作礼」活动,一起开启掘金创作之路。
0 引言
昨天我跟着学者网上的教程，用 TensorFlow 框架做了一个线性回归模型（可以参考这篇博客）。今天继续学习，用 TensorFlow 框架对 MNIST 数据集进行手写体识别。
1 准备数据
这里用到了 TensorFlow 中的 placeholder（占位符）。它和 constant 一样是计算图中的一个节点，区别在于定义时不赋值，而是在会话中运行（sess.run）时通过 feed_dict 传入具体数据。
- mnist 数据集（one_hot=True 表示标签采用 one-hot 编码）

```python
from tensorflow.examples.tutorials.mnist import input_data
mnist_data = input_data.read_data_sets("./mnist_data", one_hot=True)
```

- y_true：真实标签（label），每个样本是长度为 10 的 one-hot 向量

```python
y_true = tf.placeholder(dtype=tf.float32, shape=[None, 10], name="y_true")
```

- X：输入图片，每张 28×28 的图片展平为 784 个像素值

```python
X = tf.placeholder(dtype=tf.float32, shape=[None, 784], name="X")
```
2 构造模型
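- 参数：权重 weights（784×10）和偏置 bias（10），均用正态分布随机初始化

```python
weights = tf.Variable(initial_value=tf.random_normal(shape=[784,10]),name="weight")
bias = tf.Variable(initial_value=tf.random_normal([10]),name="bias")
```

- 模型：y_predict = X · weights + bias，输出形状为 [None, 10]

```python
y_predict = tf.matmul(X,weights) + bias
```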
3 构造损失函数
- 损失函数（loss function）：softmax + 交叉熵，y_predict 作为未归一化的 logits 传入

```python
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true,logits=y_predict))
```
4 优化损失
- 梯度下降法

```python
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(loss)
```
5 计算准确率
- 预测值和真实值进行比较：分别取每行最大值所在的下标（即预测类别和真实类别），逐个比较是否相等

```python
equal_list = tf.equal(tf.argmax(y_true, 1), tf.argmax(y_predict, 1))
```

- 求平均：把布尔值转换为 0/1 后求均值，即为准确率

```python
accuracy = tf.reduce_mean(tf.cast(equal_list, tf.float32))
```
6 初始化变量
- 初始化全部变量

```python
init = tf.global_variables_initializer()
```

- 在会话中运行初始化操作

```python
sess.run(init)
```
7 开启会话
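- 从 mnist 训练集中取出一批（batch_size 个）样本。注意：每次迭代都应重新调用 next_batch，否则模型始终只在同一批数据上训练

```python
image, label = mnist_data.train.next_batch(batch_size)
```

- 开始训练：一次 sess.run 同时执行优化步骤，并取回当前批次的 loss 和准确率

```python
_optimizer, loss_value, accuracy_value = sess.run([optimizer, loss, accuracy], feed_dict={X: image, y_true: label})
```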
8 运行效果
9 源代码
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
def fucc_connection(learning_rate=0.01, error_value=np.exp(-5), batch_size=100,
                    target_accuracy=0.9, max_steps=100000):
    """Train a single-layer fully connected (softmax regression) MNIST classifier.

    Builds ``y_predict = X @ weights + bias`` with a softmax cross-entropy
    loss and trains it with plain gradient descent on a fresh mini-batch
    every step. Training stops as soon as one training batch reaches
    ``target_accuracy``, or after ``max_steps`` steps. The loss scalar and
    weight/bias histograms are written to ``./mnist_graph`` for TensorBoard.
    A result report is printed, including accuracy on the MNIST test set.

    :param learning_rate: gradient-descent step size.
    :param error_value: only echoed in the final report; it does not affect
        training or the stopping rule.
    :param batch_size: number of training images drawn per step.
    :param target_accuracy: stop once a training batch's accuracy is at least
        this value (default 0.9).
    :param max_steps: upper bound on the number of training steps
        (default 100000).
    :return: None; progress and results are printed.
    """
    # 1 Prepare data.
    # Shapes: X([None, 784]) @ weights([784, 10]) + bias([10]) -> y_predict([None, 10])
    mnist_data = input_data.read_data_sets("./mnist_data", one_hot=True)
    X = tf.placeholder(dtype=tf.float32, shape=[None, 784], name="X")
    y_true = tf.placeholder(dtype=tf.float32, shape=[None, 10], name="y_true")

    # 2 Build the model: one linear layer producing 10 logits per image.
    with tf.variable_scope("create_model"):
        weights = tf.Variable(initial_value=tf.random_normal(shape=[784, 10]), name="weight")
        bias = tf.Variable(initial_value=tf.random_normal([10]), name="bias")
        y_predict = tf.matmul(X, weights) + bias

    # 3 Loss: softmax + cross-entropy (instead of the mean squared error used
    # for linear regression); y_predict is passed as unnormalized logits.
    with tf.variable_scope("loss_function"):
        loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_predict))

    # 4 Minimize the loss with plain gradient descent.
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(loss)

    # 5 Accuracy: fraction of samples whose arg-max prediction matches the label.
    equal_list = tf.equal(tf.argmax(y_true, 1), tf.argmax(y_predict, 1))
    accuracy = tf.reduce_mean(tf.cast(equal_list, tf.float32))

    init = tf.global_variables_initializer()

    # TensorBoard summaries, merged into one op that is fetched with each train step.
    tf.summary.scalar("loss", loss)
    tf.summary.histogram("weights", weights)
    tf.summary.histogram("bias", bias)
    merged = tf.summary.merge_all()

    with tf.Session() as sess:
        sess.run(init)
        file_writer = tf.summary.FileWriter(graph=sess.graph, logdir="./mnist_graph")
        try:
            image, label = mnist_data.train.next_batch(batch_size)
            print("loss:{}".format(sess.run(loss, feed_dict={X: image, y_true: label})))

            # Defaults so the report is well-defined even when max_steps <= 0.
            count, accuracy_value = 0, None
            for count in range(max_steps):
                # Draw a fresh mini-batch every step; reusing a single batch would
                # only fit (and measure accuracy on) those same images.
                image, label = mnist_data.train.next_batch(batch_size)
                # The summary op depends on the placeholders via `loss`, so it
                # must be fetched with the same feed_dict as the train step.
                _, loss_value, accuracy_value, summary = sess.run(
                    [optimizer, loss, accuracy, merged],
                    feed_dict={X: image, y_true: label})
                file_writer.add_summary(summary, count)
                if accuracy_value >= target_accuracy:
                    break
                print("NO.{count},loss:{value}, accuracy:{accuracy}".format(
                    value=loss_value, count=count, accuracy=accuracy_value))
            else:
                # Loop exhausted without hitting the target: still report results.
                print("target accuracy {} not reached within {} steps".format(
                    target_accuracy, max_steps))

            # Batch accuracy is noisy; also report accuracy on the held-out test set.
            test_accuracy = sess.run(accuracy, feed_dict={
                X: mnist_data.test.images, y_true: mnist_data.test.labels})

            print("-"*10 + "result" + "-"*10)
            print("learning_rate:{}".format(learning_rate))
            print("minial error value:{}".format(error_value))
            print("accuracy is:{}".format(accuracy_value))
            print("test accuracy is:{}".format(test_accuracy))
            print("total running times:{} times".format(count))
        finally:
            # Flush the graph and summaries to disk and release the event file.
            file_writer.close()
def main(argv):
    """Run one training session and print how long it took.

    :param argv: command-line arguments handed in by ``tf.app.run``; not used.
    """
    started = time.time()
    fucc_connection(learning_rate=0.1, batch_size=1000)
    elapsed = time.time() - started
    print("running time:{} s".format(round(elapsed, 2)))
if __name__ == '__main__':
    # Hand control to tf.app.run, naming the entry point explicitly.
    tf.app.run(main=main)