TensorFlow MNIST softmax

189 阅读3分钟
import numpy as np
import tensorflow as tf
/anaconda3/envs/py35/lib/python3.5/importlib/_bootstrap.py:222: RuntimeWarning: compiletime version 3.6 of module 'tensorflow.python.framework.fast_tensor_util' does not match runtime version 3.5
  return f(*args, **kwds)
/anaconda3/envs/py35/lib/python3.5/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters
from tensorflow.examples.tutorials.mnist import input_data
# Read the MNIST data sets from the MNIST_data/ directory; one_hot=True
# requests labels as one-hot vectors instead of integer class ids.
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
# Sanity check: the training split must have one label per image.
len(mnist.train.images), len(mnist.train.labels)
(55000, 55000)
# Same check for the test split used for the final accuracy measurement.
len(mnist.test.images), len(mnist.test.labels)
(10000, 10000)
# Dump the raw pixel values of the first training image (a flat float32 array).
mnist.train.images[0]
array([0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.3803922 , 0.37647063, 0.3019608 ,
       0.46274513, 0.2392157 , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.3529412 , 0.5411765 , 0.9215687 ,
       0.9215687 , 0.9215687 , 0.9215687 , 0.9215687 , 0.9215687 ,
       0.9843138 , 0.9843138 , 0.9725491 , 0.9960785 , 0.9607844 ,
       0.9215687 , 0.74509805, 0.08235294, 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.54901963,
       0.9843138 , 0.9960785 , 0.9960785 , 0.9960785 , 0.9960785 ,
       0.9960785 , 0.9960785 , 0.9960785 , 0.9960785 , 0.9960785 ,
       0.9960785 , 0.9960785 , 0.9960785 , 0.9960785 , 0.9960785 ,
       0.7411765 , 0.09019608, 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.8862746 , 0.9960785 , 0.81568635,
       0.7803922 , 0.7803922 , 0.7803922 , 0.7803922 , 0.54509807,
       0.2392157 , 0.2392157 , 0.2392157 , 0.2392157 , 0.2392157 ,
       0.5019608 , 0.8705883 , 0.9960785 , 0.9960785 , 0.7411765 ,
       0.08235294, 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.14901961, 0.32156864, 0.0509804 , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.13333334,
       0.8352942 , 0.9960785 , 0.9960785 , 0.45098042, 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.32941177, 0.9960785 ,
       0.9960785 , 0.9176471 , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.32941177, 0.9960785 , 0.9960785 , 0.9176471 ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.4156863 , 0.6156863 ,
       0.9960785 , 0.9960785 , 0.95294124, 0.20000002, 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.09803922, 0.45882356, 0.8941177 , 0.8941177 ,
       0.8941177 , 0.9921569 , 0.9960785 , 0.9960785 , 0.9960785 ,
       0.9960785 , 0.94117653, 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.26666668, 0.4666667 , 0.86274517,
       0.9960785 , 0.9960785 , 0.9960785 , 0.9960785 , 0.9960785 ,
       0.9960785 , 0.9960785 , 0.9960785 , 0.9960785 , 0.5568628 ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.14509805, 0.73333335,
       0.9921569 , 0.9960785 , 0.9960785 , 0.9960785 , 0.8745099 ,
       0.8078432 , 0.8078432 , 0.29411766, 0.26666668, 0.8431373 ,
       0.9960785 , 0.9960785 , 0.45882356, 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.4431373 , 0.8588236 , 0.9960785 , 0.9490197 , 0.89019614,
       0.45098042, 0.34901962, 0.12156864, 0.        , 0.        ,
       0.        , 0.        , 0.7843138 , 0.9960785 , 0.9450981 ,
       0.16078432, 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.6627451 , 0.9960785 ,
       0.6901961 , 0.24313727, 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.18823531,
       0.9058824 , 0.9960785 , 0.9176471 , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.07058824, 0.48627454, 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.32941177, 0.9960785 , 0.9960785 ,
       0.6509804 , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.54509807, 0.9960785 , 0.9333334 , 0.22352943, 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.8235295 , 0.9803922 , 0.9960785 ,
       0.65882355, 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.9490197 , 0.9960785 , 0.93725497, 0.22352943, 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.34901962, 0.9843138 , 0.9450981 ,
       0.3372549 , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.01960784,
       0.8078432 , 0.96470594, 0.6156863 , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.01568628, 0.45882356, 0.27058825,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        ], dtype=float32)
# Each image is stored flattened; its length is the input width of the model.
len(mnist.train.images[0])
784
import matplotlib.pyplot as plt
%matplotlib inline
# Reshape the flat pixel vector of training image 1 back to 28x28 to display it.
plt.imshow(mnist.train.images[1].reshape(28,28))
<matplotlib.image.AxesImage at 0x1c27a1a550>

# One-hot label of the image plotted above: the index of the 1 is the digit.
mnist.train.labels[1]
array([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])
# Softmax regression: one linear layer mapping the 784 pixel values of an
# image to 10 class scores, trained with cross-entropy against one-hot labels.
x = tf.placeholder(tf.float32, shape=[None, 784])  # batch of flattened images
y = tf.placeholder(tf.float32, shape=[None, 10])   # batch of one-hot labels
weight = tf.Variable(tf.truncated_normal([784, 10]))
bias = tf.Variable(tf.truncated_normal([10]))
combine_input = tf.matmul(x, weight) + bias  # unnormalized class scores (logits)
pred = tf.nn.softmax(combine_input)          # class probabilities, used for evaluation
# Compute the cross-entropy directly from the logits. The hand-written form
# -sum(y * log(softmax(logits))) yields log(0) = -inf, and then NaN loss or
# gradients, as soon as a probability underflows to zero. The fused op is
# mathematically the same quantity but numerically stable. The per-example
# losses are summed (not averaged) so the loss keeps its original scale.
loss = tf.reduce_sum(
    tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=combine_input))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
# Train with mini-batch gradient descent: 1100 steps of 50 examples each.
# The session is left open on purpose because later cells reuse `sess`.
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for i in range(1100):
    batch = mnist.train.next_batch(50)
    feed = {x: batch[0], y: batch[1]}
    sess.run(train_step, feed_dict=feed)
    if i % 50 != 0:
        continue
    # Every 50th step, report the loss on the batch that was just trained on
    # (measured after the update, so it is a training-batch loss).
    print(sess.run(loss, feed_dict=feed))
329.698
114.416695
38.314323
20.213478
48.926674
26.53627
28.653086
43.464195
16.75724
39.731388
12.251608
32.379055
24.371075
18.137915
8.972845
24.207663
29.931976
7.5475473
10.576719
28.017235
14.364228
11.022556
# Evaluate on the test split: a prediction is correct when the most probable
# class matches the position of the 1 in the one-hot label.
predicted_class = tf.argmax(pred, 1)
true_class = tf.argmax(y, 1)
correct_pred = tf.equal(predicted_class, true_class)
# Casting the booleans to 0.0/1.0 makes their mean the fraction correct.
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
test_feed = {x: mnist.test.images, y: mnist.test.labels}
acc = sess.run(accuracy, feed_dict=test_feed)
print(acc)
0.8813