Fully Connected GAN

A minimal GAN on MNIST, built with TensorFlow 1.x: both the generator and the discriminator are single-hidden-layer fully connected networks.

import tensorflow as tf
import numpy as np
import pickle 
import matplotlib.pyplot as plt
%matplotlib inline
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=False)
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
real_img = tf.placeholder(tf.float32, [None, 784], name='real_img')
noise_img = tf.placeholder(tf.float32, [None, 100], name='noise_img')
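These two placeholders are the graph's only external inputs: 784 = 28 × 28 flattened pixels for real MNIST images, and a 100-dimensional noise vector that the generator maps to a fake image.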

def generator(noise_img, hidden_units, out_dim, reuse=False):
    with tf.variable_scope("generator", reuse=reuse):
        # Hidden layer: fully connected + ReLU + dropout
        hidden1 = tf.layers.dense(noise_img, hidden_units)
        hidden1 = tf.nn.relu(hidden1)
        # NB: tf.layers.dropout defaults to training=False, so as written
        # this layer is a no-op; pass training=True to actually apply dropout
        hidden1 = tf.layers.dropout(hidden1, rate=0.2)

        # Output layer: tanh squashes pixels into [-1, 1],
        # the same range the real images are rescaled to below
        logits = tf.layers.dense(hidden1, out_dim)
        outputs = tf.tanh(logits)

        return logits, outputs
def discriminator(img, hidden_units, reuse=False, alpha=0.01):
    with tf.variable_scope("discriminator", reuse=reuse):
        # Hidden layer: fully connected + leaky ReLU (see note below)
        hidden1 = tf.layers.dense(img, hidden_units)
        hidden1 = tf.maximum(alpha * hidden1, hidden1)

        # Single logit scoring how "real" the input image looks
        logits = tf.layers.dense(hidden1, 1)
        outputs = tf.sigmoid(logits)

        return logits, outputs
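A note on that tf.maximum(alpha * hidden1, hidden1) line: it is a hand-rolled leaky ReLU, f(x) = max(αx, x). Unlike a plain ReLU it keeps a small slope (α = 0.01 here) for negative inputs, which helps keep gradients flowing through the discriminator so the generator has something to learn from.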
def plot_images(samples):
    # Map generator outputs from [-1, 1] back to [0, 1] for display
    samples = (samples + 1) / 2
    fig, axes = plt.subplots(nrows=1, ncols=25, sharex=True, sharey=True, figsize=(50, 2))
    for img, ax in zip(samples, axes):
        ax.imshow(img.reshape(28, 28), cmap='Greys_r')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    fig.tight_layout(pad=0)
img_size = 784          # 28 x 28 flattened
noise_size = 100        # dimensionality of the generator's input
hidden_units = 128      # width of the single hidden layer in both networks
alpha = 0.01            # leaky ReLU slope
learning_rate = 0.001
smooth = 0.1            # one-sided label smoothing factor
g_logits, g_outputs = generator(noise_img, hidden_units, img_size)

# The discriminator runs twice, on real and on generated images;
# reuse=True makes the second call share the first call's weights
d_logits_real, d_outputs_real = discriminator(real_img, hidden_units)
d_logits_fake, d_outputs_fake = discriminator(g_outputs, hidden_units, reuse=True)
# Discriminator loss: real images should be scored as real
# (with the target smoothed from 1 down to 1 - smooth = 0.9), fakes as fake
d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_real, labels=tf.ones_like(d_logits_real) * (1 - smooth)))

d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_fake, labels=tf.zeros_like(d_logits_fake)))

d_loss = tf.add(d_loss_real, d_loss_fake)

# Generator loss: the generator wants its fakes scored as real
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_fake, labels=tf.ones_like(d_logits_fake) * (1 - smooth)))
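Written out, these are the standard non-saturating GAN losses; the only twist is the one-sided label smoothing above, which replaces the hard target 1 with 0.9 so the discriminator is never pushed to be perfectly confident on real data:

\mathcal{L}_D = -\mathbb{E}_{x \sim p_\mathrm{data}}[\log D(x)] - \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]
\mathcal{L}_G = -\mathbb{E}_{z \sim p_z}[\log D(G(z))]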
train_vars = tf.trainable_variables()

# Split the trainable variables by scope so each optimizer
# only updates its own network's weights
g_vars = [var for var in train_vars if var.name.startswith("generator")]
d_vars = [var for var in train_vars if var.name.startswith("discriminator")]

d_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list=d_vars)
g_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=g_vars)

batch_size = 64
epochs = 300
n_sample = 25
samples = []
losses = []
# Only the generator's variables are saved; that is all sampling needs
saver = tf.train.Saver(var_list=g_vars)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for e in range(epochs):
        for batch_i in range(mnist.train.num_examples // batch_size):
            batch = mnist.train.next_batch(batch_size)
            batch_images = batch[0].reshape((batch_size, 784))
            # Rescale pixels from [0, 1] to [-1, 1] to match the generator's tanh output
            batch_images = batch_images * 2 - 1

            batch_noise = np.random.uniform(-1, 1, size=(batch_size, noise_size))

            # Alternate one discriminator update and one generator update
            _ = sess.run(d_train_opt, feed_dict={real_img: batch_images, noise_img: batch_noise})
            _ = sess.run(g_train_opt, feed_dict={noise_img: batch_noise})

        # Every 30 epochs, show samples from the generator;
        # reuse=True shares the weights being trained above
        if e % 30 == 0:
            sample_noise = np.random.uniform(-1, 1, size=(n_sample, noise_size))
            _, samples = sess.run(generator(noise_img, hidden_units, img_size, reuse=True),
                                  feed_dict={noise_img: sample_noise})
            plot_images(samples)

        # Record the losses on the last batch of the epoch
        train_loss_d = sess.run(d_loss,
                                feed_dict={real_img: batch_images, noise_img: batch_noise})
        train_loss_d_real = sess.run(d_loss_real,
                                     feed_dict={real_img: batch_images, noise_img: batch_noise})
        train_loss_d_fake = sess.run(d_loss_fake,
                                     feed_dict={real_img: batch_images, noise_img: batch_noise})
        train_loss_g = sess.run(g_loss, feed_dict={noise_img: batch_noise})

        print("Epoch {}/{}...".format(e + 1, epochs),
              "Discriminator Loss: {:.4f}(Real: {:.4f} + Fake: {:.4f})...".format(
                  train_loss_d, train_loss_d_real, train_loss_d_fake),
              "Generator Loss: {:.4f}".format(train_loss_g))

        losses.append((train_loss_d, train_loss_d_real, train_loss_d_fake, train_loss_g))

        # The ./checkpoints directory must already exist, or Saver.save raises
        saver.save(sess, './checkpoints/generator.ckpt')
Epoch 1/300... Discriminator Loss: 0.0231(Real: 0.0019 + Fake: 0.0212)... Generator Loss: 3.5849
Epoch 2/300... Discriminator Loss: 0.4172(Real: 0.1970 + Fake: 0.2202)... Generator Loss: 2.4558
Epoch 3/300... Discriminator Loss: 0.9482(Real: 0.5790 + Fake: 0.3692)... Generator Loss: 4.4707
Epoch 4/300... Discriminator Loss: 1.2668(Real: 0.3952 + Fake: 0.8716)... Generator Loss: 4.6742
Epoch 5/300... Discriminator Loss: 0.8096(Real: 0.2515 + Fake: 0.5580)... Generator Loss: 1.3327
Epoch 6/300... Discriminator Loss: 0.9048(Real: 0.3878 + Fake: 0.5171)... Generator Loss: 1.3625
Epoch 7/300... Discriminator Loss: 0.4552(Real: 0.1176 + Fake: 0.3376)... Generator Loss: 1.9976
Epoch 8/300... Discriminator Loss: 0.5582(Real: 0.2874 + Fake: 0.2708)... Generator Loss: 2.1339
Epoch 9/300... Discriminator Loss: 0.2488(Real: 0.1157 + Fake: 0.1331)... Generator Loss: 2.7862
Epoch 10/300... Discriminator Loss: 0.6615(Real: 0.4319 + Fake: 0.2296)... Generator Loss: 2.2782
Epoch 11/300... Discriminator Loss: 1.5718(Real: 0.5512 + Fake: 1.0205)... Generator Loss: 0.9963
Epoch 12/300... Discriminator Loss: 0.7446(Real: 0.4008 + Fake: 0.3438)... Generator Loss: 2.1626
Epoch 13/300... Discriminator Loss: 0.9990(Real: 0.4842 + Fake: 0.5148)... Generator Loss: 1.8565
Epoch 14/300... Discriminator Loss: 0.6637(Real: 0.3979 + Fake: 0.2658)... Generator Loss: 2.1881
Epoch 15/300... Discriminator Loss: 1.0462(Real: 0.6716 + Fake: 0.3746)... Generator Loss: 1.7844
Epoch 16/300... Discriminator Loss: 1.3157(Real: 0.6331 + Fake: 0.6826)... Generator Loss: 1.7272
Epoch 17/300... Discriminator Loss: 1.2351(Real: 0.9318 + Fake: 0.3033)... Generator Loss: 2.3078
Epoch 18/300... Discriminator Loss: 0.9484(Real: 0.5767 + Fake: 0.3717)... Generator Loss: 1.7968
Epoch 19/300... Discriminator Loss: 0.5014(Real: 0.2029 + Fake: 0.2985)... Generator Loss: 2.0219
Epoch 20/300... Discriminator Loss: 1.1535(Real: 0.7278 + Fake: 0.4257)... Generator Loss: 2.3795
Epoch 21/300... Discriminator Loss: 0.8173(Real: 0.4410 + Fake: 0.3763)... Generator Loss: 2.0387
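Because the Saver above stores only g_vars, the trained generator can be reloaded and sampled later without retraining. A minimal sketch, assuming the graph definitions from the cells above are still live in the current process:

# Restore the saved generator weights and draw fresh samples.
# Assumes noise_img, g_vars, generator, plot_images, hidden_units,
# img_size and noise_size from the cells above are still defined.
saver = tf.train.Saver(var_list=g_vars)
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint('./checkpoints'))
    sample_noise = np.random.uniform(-1, 1, size=(25, noise_size))
    _, gen_samples = sess.run(generator(noise_img, hidden_units, img_size, reuse=True),
                              feed_dict={noise_img: sample_noise})
    plot_images(gen_samples)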
fig, ax = plt.subplots(figsize=(20, 7))
losses = np.array(losses)
plt.plot(losses.T[0], label='Discriminator Total Loss')
plt.plot(losses.T[1], label='Discriminator Real Loss')
plt.plot(losses.T[2], label='Discriminator Fake Loss')
plt.plot(losses.T[3], label='Generator Loss')
plt.title("Training Losses")
plt.legend()
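As a rough sanity check when reading these curves: at the theoretical equilibrium the discriminator outputs D = 0.5 everywhere, so (ignoring the label smoothing) d_loss tends toward 2 ln 2 ≈ 1.39 and g_loss toward ln 2 ≈ 0.69. Both losses oscillating around those levels, rather than either one collapsing toward zero, is the healthy pattern for this kind of adversarial training.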