Tensorflow minist-cnn

217 阅读3分钟
import numpy as np
import tensorflow as tf
/anaconda3/envs/py35/lib/python3.5/importlib/_bootstrap.py:222: RuntimeWarning: compiletime version 3.6 of module 'tensorflow.python.framework.fast_tensor_util' does not match runtime version 3.5
  return f(*args, **kwds)
/anaconda3/envs/py35/lib/python3.5/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot = True)
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
len(mnist.train.images), len(mnist.train.labels)
(55000, 55000)
len(mnist.test.images), len(mnist.test.labels)
(10000, 10000)
mnist.train.images[0]
array([0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.3803922 , 0.37647063, 0.3019608 ,
       0.46274513, 0.2392157 , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.3529412 , 0.5411765 , 0.9215687 ,
       0.9215687 , 0.9215687 , 0.9215687 , 0.9215687 , 0.9215687 ,
       0.9843138 , 0.9843138 , 0.9725491 , 0.9960785 , 0.9607844 ,
       0.9215687 , 0.74509805, 0.08235294, 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.54901963,
       0.9843138 , 0.9960785 , 0.9960785 , 0.9960785 , 0.9960785 ,
       0.9960785 , 0.9960785 , 0.9960785 , 0.9960785 , 0.9960785 ,
       0.9960785 , 0.9960785 , 0.9960785 , 0.9960785 , 0.9960785 ,
       0.7411765 , 0.09019608, 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.8862746 , 0.9960785 , 0.81568635,
       0.7803922 , 0.7803922 , 0.7803922 , 0.7803922 , 0.54509807,
       0.2392157 , 0.2392157 , 0.2392157 , 0.2392157 , 0.2392157 ,
       0.5019608 , 0.8705883 , 0.9960785 , 0.9960785 , 0.7411765 ,
       0.08235294, 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.14901961, 0.32156864, 0.0509804 , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.13333334,
       0.8352942 , 0.9960785 , 0.9960785 , 0.45098042, 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.32941177, 0.9960785 ,
       0.9960785 , 0.9176471 , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.32941177, 0.9960785 , 0.9960785 , 0.9176471 ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.4156863 , 0.6156863 ,
       0.9960785 , 0.9960785 , 0.95294124, 0.20000002, 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.09803922, 0.45882356, 0.8941177 , 0.8941177 ,
       0.8941177 , 0.9921569 , 0.9960785 , 0.9960785 , 0.9960785 ,
       0.9960785 , 0.94117653, 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.26666668, 0.4666667 , 0.86274517,
       0.9960785 , 0.9960785 , 0.9960785 , 0.9960785 , 0.9960785 ,
       0.9960785 , 0.9960785 , 0.9960785 , 0.9960785 , 0.5568628 ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.14509805, 0.73333335,
       0.9921569 , 0.9960785 , 0.9960785 , 0.9960785 , 0.8745099 ,
       0.8078432 , 0.8078432 , 0.29411766, 0.26666668, 0.8431373 ,
       0.9960785 , 0.9960785 , 0.45882356, 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.4431373 , 0.8588236 , 0.9960785 , 0.9490197 , 0.89019614,
       0.45098042, 0.34901962, 0.12156864, 0.        , 0.        ,
       0.        , 0.        , 0.7843138 , 0.9960785 , 0.9450981 ,
       0.16078432, 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.6627451 , 0.9960785 ,
       0.6901961 , 0.24313727, 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.18823531,
       0.9058824 , 0.9960785 , 0.9176471 , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.07058824, 0.48627454, 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.32941177, 0.9960785 , 0.9960785 ,
       0.6509804 , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.54509807, 0.9960785 , 0.9333334 , 0.22352943, 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.8235295 , 0.9803922 , 0.9960785 ,
       0.65882355, 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.9490197 , 0.9960785 , 0.93725497, 0.22352943, 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.34901962, 0.9843138 , 0.9450981 ,
       0.3372549 , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.01960784,
       0.8078432 , 0.96470594, 0.6156863 , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.01568628, 0.45882356, 0.27058825,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        ], dtype=float32)
len(mnist.train.images[0])
784
import matplotlib.pyplot as plt
%matplotlib inline
plt.imshow(mnist.train.images[1].reshape(28,28))
<matplotlib.image.AxesImage at 0x1c281284a8>

mnist.train.labels[1]
array([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])
x = tf.placeholder("float", shape=[None, 784])
y_ = tf.placeholder("float", shape=[None, 10])
x_image = tf.reshape(x, [-1, 28, 28, 1])
conv2d_1 = tf.contrib.layers.convolution2d(
    x_image,
    num_outputs=32,
    kernel_size=(5,5),
    activation_fn = tf.nn.relu,
    stride = (1,1),
    padding = 'SAME',
    trainable = True)
pool_1 = tf.nn.max_pool(conv2d_1, 
                        ksize= [1, 2,2,1],
                       strides = [1,2,2,1],
                       padding='SAME')
conv2d_1.get_shape()
TensorShape([Dimension(None), Dimension(28), Dimension(28), Dimension(32)])
pool_1.get_shape()
TensorShape([Dimension(None), Dimension(14), Dimension(14), Dimension(32)])
conv2d_2 = tf.contrib.layers.convolution2d(
    pool_1,
    num_outputs=64,
    kernel_size=(5,5),
    activation_fn = tf.nn.relu,
    stride = (1,1),
    padding = 'SAME',
    trainable = True)
pool_2 = tf.nn.max_pool(conv2d_2, 
                        ksize= [1, 2,2,1],
                       strides = [1,2,2,1],
                       padding='SAME')
conv2d_2.get_shape()
TensorShape([Dimension(None), Dimension(14), Dimension(14), Dimension(64)])
pool_2.get_shape()
TensorShape([Dimension(None), Dimension(7), Dimension(7), Dimension(64)])
pool2_flat = tf.reshape(pool_2, [-1, 7*7*64])
fc_1 = tf.contrib.layers.fully_connected(pool2_flat, 
                                        1024,
                                        activation_fn = tf.nn.relu)
keep_prob = tf.placeholder("float")
fc1_drop = tf.nn.dropout(fc_1, keep_prob)
fc_2 = tf.contrib.layers.fully_connected(fc1_drop, 
                                        10,
                                        activation_fn = tf.nn.softmax)
loss = -tf.reduce_sum(y_ * tf.log(fc_2))
train_step = tf.train.GradientDescentOptimizer(0.0001).minimize(loss)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
correct_pred = tf.equal(tf.argmax(fc_2, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
for i in range(20000):
    batch = mnist.train.next_batch(50)
    sess.run(train_step, feed_dict={x : batch[0], y_:batch[1], keep_prob:0.5})
    if i%50 == 0:
        print(sess.run(loss, feed_dict={x : batch[0], y_:batch[1], keep_prob:1}))
        print(sess.run(accuracy, feed_dict={x:mnist.test.images, y_:mnist.test.labels,keep_prob:1}))
        print("--------------")
114.755585
0.0773
--------------
111.51022
0.3186
--------------
109.173744
0.5378
--------------
98.06425
0.6548
--------------
81.93743
0.6982
--------------
46.132698
0.7407
--------------
43.139786
0.7874
--------------
31.82726
0.8418
--------------
30.005209
0.8573
--------------
29.255384
0.8752
--------------
26.02791
0.8747
--------------
18.284384
0.8856
--------------
28.066277
0.8984
--------------
24.184027
0.8982
--------------
13.036492
0.9023
--------------
33.22979
0.9084
--------------
13.805885
0.9051
--------------
13.716793
0.9159
--------------
9.310038
0.914
--------------
11.859652
0.92
--------------
14.794542
0.9246
--------------
9.77744
0.9232
--------------
6.099699
0.9258
--------------
15.665421
0.9189
--------------
10.991343
0.9273
--------------
9.812216
0.9307
--------------
7.8547688
0.9357
--------------
12.281435
0.9342
--------------
7.824145
0.9326
--------------
7.7990685
0.9379
--------------
9.309114
0.9373
--------------
11.758891
0.9378
--------------
14.695442
0.9354
--------------
12.599428
0.9403
--------------
15.817507
0.942
--------------
8.946287
0.9429
--------------
5.845645
0.9443
--------------
7.3763533
0.9458
--------------
5.3192654
0.9425
--------------
21.371431
0.946
--------------
10.627939
0.9446
--------------
9.776442
0.9485
--------------
8.267487
0.9509
--------------
10.85527
0.9495
--------------
7.1606274
0.9487
--------------
9.922831
0.9534
--------------
11.272451
0.9522
--------------
5.5952454
0.9512
--------------
3.8621898
0.9527
--------------
7.7660904
0.9532
--------------
12.440466
0.9552
--------------
4.285364
0.956
--------------
6.494634
0.9549
--------------
5.1371355
0.9559
--------------
4.0771656
0.9579
--------------
3.8677125
0.9595
--------------
10.749643
0.9578
--------------
4.4979386
0.9564
--------------
5.699411
0.9567
--------------
3.6296415
0.9576
--------------
5.879737
0.9604
--------------
15.613115
0.9591
--------------
10.129694
0.9575
--------------
2.5639577
0.9606
--------------
4.8280177
0.9612
--------------
8.217493
0.9622
--------------
7.9065475
0.9616
--------------
3.5210996
0.9647
--------------
5.08947
0.9638
--------------
3.2810946
0.9628
--------------
4.441616
0.9622
--------------
8.599348
0.9623
--------------
8.364786
0.9629
--------------
2.6642575
0.9635
--------------
7.6038284
0.9643
--------------
5.4900303
0.965
--------------
6.026716
0.9644
--------------
9.241722
0.9669
--------------
1.4805893
0.9658
--------------
2.4804783
0.9656
--------------
2.4027157
0.9664
--------------
1.2007334
0.9654
--------------
3.7446694
0.966
--------------
3.156123
0.9652
--------------
2.7786171
0.9667
--------------
1.9748433
0.9682
--------------
10.901329
0.968
--------------
4.1280313
0.968
--------------
8.042549
0.9685
--------------
4.606745
0.9642
--------------
2.513884
0.9684
--------------
1.15321
0.9686
--------------
6.9821544
0.9679
--------------
2.0118175
0.9687
--------------
4.721838
0.9683
--------------
3.367778
0.9691
--------------
2.384357
0.969
--------------
2.3209221
0.9684
--------------
4.941422
0.9697
--------------
6.2601924
0.9697
--------------
7.7655964
0.9695
--------------
2.4025197
0.9696
--------------
2.9681091
0.9692
--------------
1.9564868
0.9718
--------------
2.998015
0.9712
--------------
1.7640419
0.9715
--------------
2.1289375
0.971
--------------
4.4494224
0.9706
--------------
2.3486629
0.9713
--------------
3.0480092
0.9711
--------------
4.4215302
0.9708
--------------
1.962319
0.9722
--------------
3.0561228
0.9732
--------------
1.406903
0.9718
--------------
10.200954
0.973
--------------
6.6716814
0.9724
--------------
1.1285647
0.9732
--------------
2.3097482
0.9739
--------------
3.3002224
0.9734
--------------
13.328636
0.9739
--------------
3.0631943
0.9727
--------------
2.1633306
0.9744
--------------
2.199626
0.9746
--------------
3.5341806
0.9749
--------------
2.6506147
0.9727
--------------
2.4559627
0.9743
--------------
2.8384607
0.9711
--------------
2.4524465
0.974
--------------
1.9630153