Autoencoder - MNIST - Fully Connected

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
tf.__version__
'1.10.0'
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=False)
WARNING:tensorflow: read_data_sets, maybe_download, extract_images, extract_labels, and DataSet.__init__ (from tensorflow.contrib.learn) are deprecated and will be removed in a future version; the suggested replacements are tf.data and official/mnist/dataset.py from tensorflow/models.
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
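Since read_data_sets is deprecated, here is a minimal sketch of loading the same data through tf.keras.datasets instead (an aside only; the rest of this notebook still uses the mnist object above, and the reshape and scaling below mirror the layout of mnist.train.images):
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# Flatten each 28x28 image to 784 values and scale pixels to [0, 1].
x_train = x_train.reshape(-1, 784).astype(np.float32) / 255.0
x_test = x_test.reshape(-1, 784).astype(np.float32) / 255.0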
# Sanity check: render one flattened training image as a 28x28 picture.
img = mnist.train.images[30]
plt.imshow(img.reshape((28, 28)))
<matplotlib.image.AxesImage at 0x1c3816ba20>

mnist.train.images.shape
(55000, 784)
hidden_units = 64
input_units = mnist.train.images.shape[1]  # 784 = 28 * 28 flattened pixels

# Noisy images are fed through `inputs`; the clean originals are the `targets`.
inputs = tf.placeholder(tf.float32, (None, input_units), name='inputs')
targets = tf.placeholder(tf.float32, (None, input_units), name='targets')

# Encoder compresses 784 -> 64; decoder expands 64 -> 784 logits.
hidden = tf.layers.dense(inputs, hidden_units, activation=tf.nn.relu)
logits = tf.layers.dense(hidden, input_units, activation=None)
outputs = tf.sigmoid(logits, name='outputs')

# Summed squared reconstruction error over the batch.
loss = tf.reduce_sum(tf.square(outputs - targets))
optimizer = tf.train.AdamOptimizer(0.01).minimize(loss)
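A squared-error loss works here, but since the pixels lie in [0, 1], a per-pixel sigmoid cross-entropy on the logits is a common alternative; a sketch (alt_loss is a hypothetical name, not used in the training run below):
# Assumption: treat each pixel as an independent Bernoulli target.
xent = tf.nn.sigmoid_cross_entropy_with_logits(labels=targets, logits=logits)
alt_loss = tf.reduce_mean(tf.reduce_sum(xent, axis=1))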
sess = tf.Session()
sess.run(tf.global_variables_initializer())
epochs = 100
batch_size = 128
scale = 0.6  # standard deviation of the Gaussian corruption noise

for i in range(epochs):
    epoch_loss = []
    for x in range(mnist.train.num_examples // batch_size):
        batch = mnist.train.next_batch(batch_size)
        # Denoising setup: corrupt the inputs with Gaussian noise,
        # but reconstruct the clean images.
        batch_cost, _ = sess.run([loss, optimizer],
                                 feed_dict={inputs: batch[0] + scale * np.random.normal(size=batch[0].shape),
                                            targets: batch[0]})
        epoch_loss.append(batch_cost)
    print("Epoch: {}/{}".format(i + 1, epochs),
          "Training loss: {:.4f}".format(sum(epoch_loss) / len(epoch_loss)))
Epoch: 1/100 Training loss: 4122.6152
Epoch: 2/100 Training loss: 3175.5261
Epoch: 3/100 Training loss: 2949.5105
Epoch: 4/100 Training loss: 2785.7994
Epoch: 5/100 Training loss: 2674.1762
Epoch: 6/100 Training loss: 2587.7063
Epoch: 7/100 Training loss: 2494.3316
Epoch: 8/100 Training loss: 2412.4359
Epoch: 9/100 Training loss: 2370.8783
Epoch: 10/100 Training loss: 2347.2669
Epoch: 11/100 Training loss: 2323.5621
Epoch: 12/100 Training loss: 2314.9729
Epoch: 13/100 Training loss: 2303.4539
Epoch: 14/100 Training loss: 2298.0133
Epoch: 15/100 Training loss: 2292.6705
Epoch: 16/100 Training loss: 2274.9722
Epoch: 17/100 Training loss: 2270.7350
Epoch: 18/100 Training loss: 2259.5915
Epoch: 19/100 Training loss: 2253.6038
Epoch: 20/100 Training loss: 2247.7472
Epoch: 21/100 Training loss: 2248.0725
Epoch: 22/100 Training loss: 2240.8025
Epoch: 23/100 Training loss: 2238.4714
Epoch: 24/100 Training loss: 2236.3363
Epoch: 25/100 Training loss: 2233.4593
Epoch: 26/100 Training loss: 2227.6890
Epoch: 27/100 Training loss: 2233.3891
Epoch: 28/100 Training loss: 2228.8210
Epoch: 29/100 Training loss: 2232.7463
Epoch: 30/100 Training loss: 2224.1135
Epoch: 31/100 Training loss: 2220.9270
Epoch: 32/100 Training loss: 2227.0465
Epoch: 33/100 Training loss: 2222.6370
Epoch: 34/100 Training loss: 2213.9390
Epoch: 35/100 Training loss: 2223.2059
Epoch: 36/100 Training loss: 2219.7526
Epoch: 37/100 Training loss: 2216.6488
Epoch: 38/100 Training loss: 2218.4137
Epoch: 39/100 Training loss: 2215.9420
Epoch: 40/100 Training loss: 2218.2955
Epoch: 41/100 Training loss: 2216.0744
Epoch: 42/100 Training loss: 2215.5212
Epoch: 43/100 Training loss: 2214.1895
Epoch: 44/100 Training loss: 2213.8964
Epoch: 45/100 Training loss: 2214.0576
Epoch: 46/100 Training loss: 2214.3030
Epoch: 47/100 Training loss: 2216.1043
Epoch: 48/100 Training loss: 2215.5620
Epoch: 49/100 Training loss: 2218.0237
Epoch: 50/100 Training loss: 2205.8529
Epoch: 51/100 Training loss: 2211.3229
Epoch: 52/100 Training loss: 2211.8879
Epoch: 53/100 Training loss: 2211.6150
Epoch: 54/100 Training loss: 2212.2388
Epoch: 55/100 Training loss: 2214.1972
Epoch: 56/100 Training loss: 2209.8900
Epoch: 57/100 Training loss: 2206.5484
Epoch: 58/100 Training loss: 2210.9500
Epoch: 59/100 Training loss: 2213.4333
Epoch: 60/100 Training loss: 2209.0800
Epoch: 61/100 Training loss: 2206.9485
Epoch: 62/100 Training loss: 2214.8776
Epoch: 63/100 Training loss: 2207.4935
Epoch: 64/100 Training loss: 2206.9996
Epoch: 65/100 Training loss: 2211.6217
Epoch: 66/100 Training loss: 2214.4954
Epoch: 67/100 Training loss: 2208.1563
Epoch: 68/100 Training loss: 2209.4227
Epoch: 69/100 Training loss: 2210.2614
Epoch: 70/100 Training loss: 2206.8853
Epoch: 71/100 Training loss: 2211.4665
Epoch: 72/100 Training loss: 2213.6993
Epoch: 73/100 Training loss: 2208.0797
Epoch: 74/100 Training loss: 2207.9999
Epoch: 75/100 Training loss: 2211.5181
Epoch: 76/100 Training loss: 2205.4149
Epoch: 77/100 Training loss: 2211.1890
Epoch: 78/100 Training loss: 2201.1221
Epoch: 79/100 Training loss: 2209.5938
Epoch: 80/100 Training loss: 2206.1002
Epoch: 81/100 Training loss: 2205.2419
Epoch: 82/100 Training loss: 2206.9217
Epoch: 83/100 Training loss: 2209.1512
Epoch: 84/100 Training loss: 2204.2429
Epoch: 85/100 Training loss: 2205.7824
Epoch: 86/100 Training loss: 2208.6107
Epoch: 87/100 Training loss: 2205.6961
Epoch: 88/100 Training loss: 2204.8022
Epoch: 89/100 Training loss: 2206.0642
Epoch: 90/100 Training loss: 2204.2363
Epoch: 91/100 Training loss: 2201.9248
Epoch: 92/100 Training loss: 2204.6733
Epoch: 93/100 Training loss: 2206.9835
Epoch: 94/100 Training loss: 2204.4014
Epoch: 95/100 Training loss: 2202.2799
Epoch: 96/100 Training loss: 2209.2330
Epoch: 97/100 Training loss: 2206.6406
Epoch: 98/100 Training loss: 2201.3019
Epoch: 99/100 Training loss: 2203.5554
Epoch: 100/100 Training loss: 2201.1172
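One detail worth noting: the Gaussian noise pushes pixel values outside [0, 1], while the sigmoid output can never leave that range. A common refinement (not applied in the run above) is to clip the corrupted batch before feeding it:
# Assumption: clamp noisy pixels back into the valid [0, 1] range.
noisy_batch = np.clip(batch[0] + scale * np.random.normal(size=batch[0].shape), 0.0, 1.0)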
fig, axes = plt.subplots(nrows=2, ncols=5, sharex=True, sharey=True, figsize=(20, 8))
test_imgs = mnist.test.images[:5]
test_imgs_with_noise = test_imgs + scale * np.random.normal(size=test_imgs.shape)
reconstructed, compressed = sess.run([outputs, hidden],
                                     feed_dict={inputs: test_imgs_with_noise})
# Top row: noisy test inputs; bottom row: their reconstructions.
for images, row in zip([test_imgs_with_noise, reconstructed], axes):
    for img, ax in zip(images, row):
        ax.imshow(img.reshape(28, 28))
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
fig.tight_layout(pad=0.1)

# Visualize each 64-unit compressed code as an 8x8 grid.
fig, axes = plt.subplots(nrows=1, ncols=5, sharex=True, sharey=True, figsize=(20, 4))
for img, ax in zip(compressed, axes):
    ax.imshow(img.reshape(8, 8))
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
fig.tight_layout(pad=0)
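To reuse the trained weights without repeating the 100 training epochs, a minimal sketch with tf.train.Saver (the checkpoint path is a hypothetical choice):
# Assumption: persist the session's variables to a local checkpoint file.
saver = tf.train.Saver()
saver.save(sess, './denoising_autoencoder.ckpt')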