去噪自编码器（Denoising Autoencoder）- MNIST - CNN

284 阅读3分钟
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
/anaconda3/envs/py35/lib/python3.5/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters
# Display the installed TensorFlow version; the rest of this notebook relies on
# the TF 1.x graph/session API (tf.placeholder, tf.layers, tf.Session).
tf.__version__
'1.10.0'
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST from (and, if missing, download into) the local MNIST_data/
# directory. Labels stay as integer class ids (one_hot=False); the code in this
# notebook only ever reads the image half of each batch.
mnist = input_data.read_data_sets(train_dir='MNIST_data', one_hot=False)
WARNING:tensorflow:From <ipython-input-3-1b917109a13c>:2: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From /anaconda3/envs/py35/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From /anaconda3/envs/py35/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-images-idx3-ubyte.gz
WARNING:tensorflow:From /anaconda3/envs/py35/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From /anaconda3/envs/py35/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
# Preview one training digit. Each row of mnist.train.images is a flattened
# 784-pixel vector, so fold it back into a 28x28 grid before plotting.
img = mnist.train.images[30]
img_2d = np.reshape(img, (28, 28))
plt.imshow(img_2d)
<matplotlib.image.AxesImage at 0x1c3acafe80>

# Training split layout: one row per image, each a flattened 28*28 = 784 pixel vector.
mnist.train.images.shape
(55000, 784)
# Graph inputs: the training loop feeds noise-corrupted images into `inputs`
# and the clean originals into `targets`; both are batches of 28x28x1 images
# (channels-last, the tf.layers default).
inputs = tf.placeholder(tf.float32, (None, 28, 28, 1), name = 'inputs')
targets = tf.placeholder(tf.float32, (None, 28, 28, 1), name = 'targets')


def _conv_pool(x, n_filters):
    """Apply a 3x3 same-padded ReLU conv, then 2x2 / stride-2 max pooling."""
    x = tf.layers.conv2d(x, n_filters, (3, 3), padding='same', activation=tf.nn.relu)
    return tf.layers.max_pooling2d(x, pool_size=(2, 2), strides=(2, 2), padding='same')


# Encoder: spatial size 28 -> 14 -> 7 -> 4 ('same' pooling rounds 7 / 2 up to 4).
conv1 = _conv_pool(inputs, 32)
conv2 = _conv_pool(conv1, 64)
conv3 = _conv_pool(conv2, 64)
conv3.get_shape()
TensorShape([Dimension(None), Dimension(4), Dimension(4), Dimension(64)])
# Decoder, step 1: a 4x4 'valid' transposed conv with stride 1 grows the 4x4
# code to 7x7 ((4 - 1) * 1 + 4 = 7), mirroring the encoder's 7x7 stage.
# ReLU added: without an activation this hidden layer is purely affine, and
# consecutive affine layers collapse into a single affine map, wasting depth.
conv4 = tf.layers.conv2d_transpose(conv3, 32, (4, 4), strides=(1, 1), padding='valid',
                                   activation=tf.nn.relu)
conv4.get_shape()
TensorShape([Dimension(None), Dimension(7), Dimension(7), Dimension(32)])
# Decoder, steps 2-3: stride-2 transposed convs double the spatial size
# 7 -> 14 -> 28.
# conv5 is a hidden layer, so it gets a ReLU; left linear it would just fold
# into the next layer as one affine map. conv6 stays linear on purpose: it
# emits the reconstructed pixel values that the squared-error loss compares
# against `targets`.
conv5 = tf.layers.conv2d_transpose(conv4, 16, (2, 2), strides=(2, 2), padding='same',
                                   activation=tf.nn.relu)
conv6 = tf.layers.conv2d_transpose(conv5, 1, (2, 2), strides=(2, 2), padding='same')
conv6.get_shape()
TensorShape([Dimension(None), Dimension(28), Dimension(28), Dimension(1)])
# Reconstruction objective: squared error summed over every pixel of every
# image in the batch (a sum, not a mean, so its magnitude grows with batch size).
residual = conv6 - targets
loss = tf.reduce_sum(tf.square(residual))
optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)

# One session for the whole notebook; initialize all layer weights up front.
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# Denoising training: every batch is corrupted with zero-mean Gaussian noise of
# standard deviation `scale`, and the network is trained to reproduce the clean
# batch. The printed loss is the average, over the epoch's batches, of each
# batch's summed squared error.
epochs = 50
batch_size = 128
scale = 0.6
n_batches = mnist.train.num_examples // batch_size

for i in range(epochs):
    epoch_loss = []
    for step in range(n_batches):
        clean = mnist.train.next_batch(batch_size)[0].reshape((-1, 28, 28, 1))
        noisy = clean + scale * np.random.normal(size=clean.shape)
        batch_cost, _ = sess.run([loss, optimizer],
                                 feed_dict={inputs: noisy, targets: clean})
        epoch_loss.append(batch_cost)
    mean_loss = sum(epoch_loss) / len(epoch_loss)
    print("Epoch: {}/{}".format(i + 1, epochs),
          "Training loss: {:.4f}".format(mean_loss))

Epoch: 1/50 Training loss: 4159.0466
Epoch: 2/50 Training loss: 2508.3305
Epoch: 3/50 Training loss: 2245.7472



---------------------------------------------------------------------------

KeyboardInterrupt                         Traceback (most recent call last)

<ipython-input-20-16b6796bb00d> in <module>()
      5         imgs = batch[0].reshape((-1, 28, 28, 1))
      6         batch_cost, _ = sess.run([loss, optimizer], 
----> 7                                  feed_dict = {inputs:imgs + scale*np.random.normal(size = imgs.shape),targets:imgs})
      8         epoch_loss.append(batch_cost)
      9     print("Epoch: {}/{}".format(i + 1, epochs),


/anaconda3/envs/py35/lib/python3.5/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    875     try:
    876       result = self._run(None, fetches, feed_dict, options_ptr,
--> 877                          run_metadata_ptr)
    878       if run_metadata:
    879         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)


/anaconda3/envs/py35/lib/python3.5/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1098     if final_fetches or final_targets or (handle and feed_dict_tensor):
   1099       results = self._do_run(handle, final_targets, final_fetches,
-> 1100                              feed_dict_tensor, options, run_metadata)
   1101     else:
   1102       results = []


/anaconda3/envs/py35/lib/python3.5/site-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1270     if handle is None:
   1271       return self._do_call(_run_fn, feeds, fetches, targets, options,
-> 1272                            run_metadata)
   1273     else:
   1274       return self._do_call(_prun_fn, handle, feeds, fetches)


/anaconda3/envs/py35/lib/python3.5/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
   1276   def _do_call(self, fn, *args):
   1277     try:
-> 1278       return fn(*args)
   1279     except errors.OpError as e:
   1280       message = compat.as_text(e.message)


/anaconda3/envs/py35/lib/python3.5/site-packages/tensorflow/python/client/session.py in _run_fn(feed_dict, fetch_list, target_list, options, run_metadata)
   1261       self._extend_graph()
   1262       return self._call_tf_sessionrun(
-> 1263           options, feed_dict, fetch_list, target_list, run_metadata)
   1264 
   1265     def _prun_fn(handle, feed_dict, fetch_list):


/anaconda3/envs/py35/lib/python3.5/site-packages/tensorflow/python/client/session.py in _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list, run_metadata)
   1348     return tf_session.TF_SessionRun_wrapper(
   1349         self._session, options, feed_dict, fetch_list, target_list,
-> 1350         run_metadata)
   1351 
   1352   def _call_tf_sessionprun(self, handle, feed_dict, fetch_list):


KeyboardInterrupt: 
# Visual check of the denoiser on five test digits: the top row shows the
# noise-corrupted inputs, and the bottom row shows the network's reconstruction
# of exactly those corrupted inputs.
fig, ax = plt.subplots(nrows = 2, ncols = 5, sharex =True, sharey = True, figsize =(20,8))
test_imgs = mnist.test.images[:5].reshape((-1, 28, 28, 1))
# Corrupt first, at the same noise level used in training, then reconstruct
# the corrupted images. The original fed the clean images here, which neither
# tests denoising nor matches what the top row displays.
test_imgs_with_noist = test_imgs + scale * np.random.normal(size=test_imgs.shape)
reconstructed = sess.run(conv6, feed_dict={inputs: test_imgs_with_noist})
# Iterate with a distinct name for each Axes so the `ax` grid from
# plt.subplots is not shadowed by a single subplot.
for images, row in zip([test_imgs_with_noist, reconstructed], ax):
    for img, cell in zip(images, row):
        cell.imshow(img.reshape(28, 28))
        cell.get_xaxis().set_visible(False)
        cell.get_yaxis().set_visible(False)
fig.tight_layout(pad = 0.1)