WGAN

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
# Inputs: flattened 28x28 real images and 100-d noise vectors
real_img = tf.placeholder(tf.float32, [None, 784], name='real_img')
noise_img = tf.placeholder(tf.float32, [None, 100], name='noise_img')

def generator(noise_img, hidden_units, out_dim, reuse=False):
    # Single-hidden-layer generator: noise -> ReLU hidden layer -> sigmoid image
    with tf.variable_scope("generator", reuse=reuse):
        hidden1 = tf.layers.dense(noise_img, hidden_units)
        hidden1 = tf.nn.relu(hidden1)

        logits = tf.layers.dense(hidden1, out_dim)
        # Sigmoid keeps outputs in (0, 1), matching MNIST pixel values
        outputs = tf.nn.sigmoid(logits)

        return logits, outputs
def discriminator(img, hidden_units, reuse=False):
    # WGAN critic: unlike a vanilla GAN discriminator, it outputs an unbounded
    # scalar score, so there is no sigmoid on the final layer
    with tf.variable_scope("discriminator", reuse=reuse):
        hidden1 = tf.layers.dense(img, hidden_units)
        hidden1 = tf.nn.relu(hidden1)

        outputs = tf.layers.dense(hidden1, 1)

        return outputs
def plot_images(samples):
    # The generator already outputs values in (0, 1) via sigmoid, so no
    # rescaling is needed before display
    fig, axes = plt.subplots(nrows=1, ncols=25, sharex=True, sharey=True, figsize=(50, 2))
    for img, ax in zip(samples, axes):
        ax.imshow(img.reshape(28, 28), cmap='Greys_r')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    fig.tight_layout(pad=0)
img_size = 784
noise_size = 100
hidden_units = 128
learning_rate = 0.0001
# alpha (leaky-ReLU slope) and smooth (label smoothing) from the vanilla-GAN
# version of this code are unused in the WGAN setup and have been dropped
g_logits, g_outputs = generator(noise_img, hidden_units, img_size)

# Critic scores for real and generated images
d_real = discriminator(real_img, hidden_units)
d_fake = discriminator(g_outputs, hidden_units, reuse=True)

# WGAN losses: the critic maximizes E[D(real)] - E[D(fake)], an estimate of
# the Wasserstein distance; the generator maximizes E[D(fake)]
d_loss = tf.reduce_mean(d_real) - tf.reduce_mean(d_fake)
g_loss = -tf.reduce_mean(d_fake)
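These two losses come from the Wasserstein GAN objective (Arjovsky et al., 2017), in which the critic f must be 1-Lipschitz, a constraint enforced below by weight clipping:

$$\min_G \max_{\|f\|_L \le 1} \; \mathbb{E}_{x \sim p_{\mathrm{data}}}[f(x)] - \mathbb{E}_{z \sim p(z)}[f(G(z))]$$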
train_vars = tf.trainable_variables()

g_vars = [var for var in train_vars if var.name.startswith("generator")]

d_vars = [var for var in train_vars if var.name.startswith("discriminator")]

# RMSProp, as used in the original WGAN paper; the critic minimizes -d_loss,
# i.e. maximizes its Wasserstein estimate
d_train_opt = tf.train.RMSPropOptimizer(learning_rate).minimize(-d_loss, var_list=d_vars)
g_train_opt = tf.train.RMSPropOptimizer(learning_rate).minimize(g_loss, var_list=g_vars)

# Weight clipping: after every critic update, force each critic weight into
# [-0.01, 0.01] to (crudely) enforce the Lipschitz constraint
clip_d = [p.assign(tf.clip_by_value(p, -0.01, 0.01)) for p in d_vars]
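If the clipping is wired up correctly, every critic weight stays inside the clip range after each update. A hypothetical sanity check (not part of the original notebook) that could be run inside the training session:

# Hypothetical check, assuming an active session `sess`: after running clip_d,
# every critic weight should lie in [-0.01, 0.01]
w_abs_max = max(sess.run([tf.reduce_max(tf.abs(v)) for v in d_vars]))
assert w_abs_max <= 0.01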
batch_size = 32
n_sample = 25
samples = []
losses = []
# Build the sampling ops once, outside the training loop, so new generator ops
# are not added to the graph on every sampling step
_, sample_op = generator(noise_img, hidden_units, img_size, reuse=True)
# Only the generator variables are checkpointed; they are all that is needed
# to sample after training
saver = tf.train.Saver(var_list=g_vars)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for it in range(1000000):
        # Train the critic 5 times per generator step, clipping its weights
        # after every update
        for _ in range(5):
            batch_images, _ = mnist.train.next_batch(batch_size)
            batch_noise = np.random.uniform(-1, 1, size=(batch_size, noise_size))
            _ = sess.run([d_train_opt, clip_d],
                         feed_dict={real_img: batch_images, noise_img: batch_noise})
        # One generator update, reusing the last noise batch
        _ = sess.run(g_train_opt, feed_dict={noise_img: batch_noise})
        if it % 1000 == 0:
            sample_noise = np.random.uniform(-1, 1, size=(n_sample, noise_size))
            gen_samples = sess.run(sample_op, feed_dict={noise_img: sample_noise})
            samples.append(gen_samples)
            plot_images(gen_samples)
            train_loss_d = sess.run(d_loss,
                                    feed_dict={real_img: batch_images,
                                               noise_img: batch_noise})

            train_loss_g = sess.run(g_loss,
                                    feed_dict={noise_img: batch_noise})

            print("Discriminator Loss: {:.4f}...".format(train_loss_d),
                  "Generator Loss: {:.4f}".format(train_loss_g))

            losses.append((train_loss_d, train_loss_g))

            saver.save(sess, './checkpoints/generator.ckpt')
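After training, a minimal sketch (assumed usage, not part of the original notebook) for plotting the recorded losses and drawing fresh samples from the saved generator checkpoint:

# Plot the recorded losses; the critic loss tracks the Wasserstein estimate
losses_arr = np.array(losses)
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(losses_arr[:, 0], label='Critic')
ax.plot(losses_arr[:, 1], label='Generator')
ax.set_xlabel('Checkpoint (x1000 iterations)')
ax.legend()

# Restore only the generator variables (all the saver kept) and sample;
# sample_op depends only on generator variables, so no critic init is needed
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint('./checkpoints'))
    sample_noise = np.random.uniform(-1, 1, size=(n_sample, noise_size))
    new_samples = sess.run(sample_op, feed_dict={noise_img: sample_noise})
    plot_images(new_samples)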