import tensorflow as tf
import numpy as np
import pickle
import matplotlib.pyplot as plt
%matplotlib inline
/anaconda3/envs/py35/lib/python3.5/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
from ._conv import register_converters as _register_converters
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST from ./MNIST_data (downloading/extracting the archives when needed),
# with labels encoded as one-hot vectors.
# NOTE: read_data_sets is deprecated in this TensorFlow version and emits
# deprecation warnings on load; it still works for this notebook.
mnist = input_data.read_data_sets(train_dir='MNIST_data', one_hot=True)
WARNING:tensorflow:From <ipython-input-2-93d8da72a918>:2: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From /anaconda3/envs/py35/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From /anaconda3/envs/py35/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-images-idx3-ubyte.gz
WARNING:tensorflow:From /anaconda3/envs/py35/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-labels-idx1-ubyte.gz
WARNING:tensorflow:From /anaconda3/envs/py35/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From /anaconda3/envs/py35/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
# Real images: one flattened 28x28 (=784 values) MNIST image per row.
real_img = tf.placeholder(dtype=tf.float32, shape=(None, 784), name='real_img')
# Generator input: one 100-dimensional noise vector per row.
noise_img = tf.placeholder(dtype=tf.float32, shape=(None, 100), name='noise_img')
def generator(noise_img, hidden_units, out_dim, reuse=False):
    """Build the generator: noise -> dense + ReLU -> dense -> sigmoid.

    Args:
        noise_img: input noise tensor (the `noise_img` placeholder in this notebook).
        hidden_units: width of the single hidden dense layer.
        out_dim: width of the output layer (784 = 28*28 pixels in this notebook).
        reuse: when True, reuse the existing variables in the "generator" scope
            instead of creating new ones.

    Returns:
        A (logits, outputs) pair: the pre-activation output and its sigmoid,
        so `outputs` lies in (0, 1).
    """
    with tf.variable_scope("generator", reuse=reuse):
        hidden = tf.nn.relu(tf.layers.dense(noise_img, hidden_units))
        logits = tf.layers.dense(hidden, out_dim)
        return logits, tf.nn.sigmoid(logits)
def discriminator(img, hidden_units, reuse=False):
    """Build the critic: image -> dense + ReLU -> single linear score.

    The score is left unbounded (no sigmoid). This notebook uses it as a WGAN
    critic, with a mean-difference loss and weight clipping.

    Args:
        img: batch of flattened images (real images or generator outputs).
        hidden_units: width of the single hidden dense layer.
        reuse: when True, reuse the existing variables in the "discriminator"
            scope so real and fake inputs share one set of weights.

    Returns:
        Raw critic scores, one per input row.
    """
    with tf.variable_scope("discriminator", reuse=reuse):
        features = tf.nn.relu(tf.layers.dense(img, hidden_units))
        return tf.layers.dense(features, 1)
def plot_images(samples):
    """Show a row of generated MNIST digits, one subplot per sample.

    Args:
        samples: array-like of flattened images, one per row. Each row must hold
            784 values so it can be reshaped to 28x28. Pixel values are expected
            in [0, 1], the range of the generator's sigmoid output.

    Does nothing when `samples` is empty.
    """
    samples = np.asarray(samples)
    n_images = len(samples)
    if n_images == 0:
        return
    # squeeze=False keeps `axes` 2-D even for a single image, so axes[0] is
    # always an iterable row of Axes.
    fig, axes = plt.subplots(nrows=1, ncols=n_images, sharex=True, sharey=True,
                             figsize=(2 * n_images, 2), squeeze=False)
    for img, ax in zip(samples, axes[0]):
        # No (x + 1) / 2 rescale: the data is already in [0, 1]. Pinning
        # vmin/vmax keeps brightness comparable across calls instead of
        # re-normalising every image to its own min/max.
        ax.imshow(img.reshape(28, 28), cmap='Greys_r', vmin=0, vmax=1)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    fig.tight_layout(pad=0)
img_size = 784          # flattened 28x28 MNIST image
noise_size = 100        # dimension of the generator's input noise vector
hidden_units = 128      # hidden-layer width for both networks
# NOTE(review): `alpha` and `smooth` are not referenced in this section (both
# nets use plain ReLU, and WGAN has no label smoothing). They are presumably
# leftovers from a vanilla-GAN version; kept in case other cells use them.
alpha = 0.01
learning_rate = 0.0001
smooth = 0.1
clip_value = 0.01       # critic weights are clipped to [-clip_value, clip_value]

g_logits, g_outputs = generator(noise_img, hidden_units, img_size)
d_real = discriminator(real_img, hidden_units)
# Same critic weights applied to generated images.
d_fake = discriminator(g_outputs, hidden_units, reuse=True)

# Critic objective: mean score on real minus mean score on fake. The critic
# maximises it, hence minimize(-d_loss) below. The generator tries to raise
# the critic's score on its own samples.
d_loss = tf.reduce_mean(d_real) - tf.reduce_mean(d_fake)
g_loss = -tf.reduce_mean(d_fake)

# Split trainable variables by scope. Match on "<scope>/" so a sibling scope
# such as "generator_1" can never be mistaken for the real one.
train_vars = tf.trainable_variables()
g_vars = [var for var in train_vars if var.name.startswith("generator/")]
d_vars = [var for var in train_vars if var.name.startswith("discriminator/")]

d_train_opt = tf.train.RMSPropOptimizer(learning_rate).minimize(-d_loss, var_list=d_vars)
g_train_opt = tf.train.RMSPropOptimizer(learning_rate).minimize(g_loss, var_list=g_vars)

# Weight-clipping ops for the critic. They must be run after each critic update.
clip_d = [p.assign(tf.clip_by_value(p, -clip_value, clip_value)) for p in d_vars]
import os

batch_size = 32
n_sample = 25          # images generated and plotted at each checkpoint
n_critic = 5           # critic updates per generator update
n_iterations = 1000000
sample_every = 1000    # iterations between sample/loss/checkpoint reports
checkpoint_dir = './checkpoints'
samples = []           # one array of generated images per checkpoint
losses = []            # one (critic loss, generator loss) pair per checkpoint
saver = tf.train.Saver(var_list=g_vars)

# Saver.save refuses to write into a directory that does not exist.
os.makedirs(checkpoint_dir, exist_ok=True)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for it in range(n_iterations):
        # Critic: several updates per generator step, each followed by clipping.
        for _ in range(n_critic):
            batch_images = mnist.train.next_batch(batch_size)[0]
            batch_noise = np.random.uniform(-1, 1, size=(batch_size, noise_size))
            sess.run(d_train_opt, feed_dict={real_img: batch_images, noise_img: batch_noise})
            # Separate run: fetching [d_train_opt, clip_d] together gives TF no
            # ordering between the optimizer update and the clip assignments.
            sess.run(clip_d)

        # Generator: one update on a fresh noise batch.
        batch_noise = np.random.uniform(-1, 1, size=(batch_size, noise_size))
        sess.run(g_train_opt, feed_dict={noise_img: batch_noise})

        if it % sample_every == 0:
            sample_noise = np.random.uniform(-1, 1, size=(n_sample, noise_size))
            # Evaluate the existing g_outputs tensor. Calling generator() here
            # would add another copy of the generator ops to the graph each time.
            gen_samples = sess.run(g_outputs, feed_dict={noise_img: sample_noise})
            samples.append(gen_samples)
            plot_images(gen_samples)
            # With `%matplotlib inline`, figures otherwise render only when the
            # cell finishes, accumulating in memory until then.
            plt.show()

            train_loss_d, train_loss_g = sess.run(
                [d_loss, g_loss],
                feed_dict={real_img: batch_images, noise_img: batch_noise})
            print("Discriminator Loss: {:.4f}...".format(train_loss_d),
                  "Generator Loss: {:.4f}".format(train_loss_g))
            losses.append((train_loss_d, train_loss_g))
            saver.save(sess, os.path.join(checkpoint_dir, 'generator.ckpt'))