InfoGAN

316 阅读3分钟
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
/anaconda3/envs/py35/lib/python3.5/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters
from tensorflow.examples.tutorials.mnist import input_data
# Download (if needed) and load MNIST into ./MNIST_data with one-hot labels.
mnist = input_data.read_data_sets(train_dir='MNIST_data', one_hot=True)
WARNING:tensorflow:From <ipython-input-2-93d8da72a918>:2: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From /anaconda3/envs/py35/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From /anaconda3/envs/py35/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-images-idx3-ubyte.gz
WARNING:tensorflow:From /anaconda3/envs/py35/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-labels-idx1-ubyte.gz
WARNING:tensorflow:From /anaconda3/envs/py35/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From /anaconda3/envs/py35/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
def get_inputs(noise_dim, image_height, image_width, image_depth):
    """Create the graph's input placeholders.

    Args:
        noise_dim: length of each noise vector.
        image_height, image_width, image_depth: size of one real image.

    Returns:
        (inputs_real, inputs_noise, condition_label): float32 placeholders of
        shape [None, H, W, D], [None, noise_dim] and [None, 10].
    """
    image_shape = [None, image_height, image_width, image_depth]
    real = tf.placeholder(tf.float32, image_shape, name='inputs_real')
    noise = tf.placeholder(tf.float32, [None, noise_dim], name='inputs_noise')
    code = tf.placeholder(tf.float32, [None, 10], name='condition_label')
    return real, noise, code
def generator(noise_img, output_dim, condition_label, is_train=True, keep_prob=0.8):
    """Map a noise vector plus a 10-column code to an image in [-1, 1].

    Args:
        noise_img: noise tensor, presumably [batch, noise_dim].
        output_dim: number of output channels (1 for MNIST).
        condition_label: code tensor, presumably [batch, 10]; it is
            concatenated to the noise along axis 1.
        is_train: True while training. It selects batch-norm mode and whether
            dropout is applied. False also reuses the variables created by an
            earlier training-mode call.
        keep_prob: dropout keep probability, applied only when is_train=True.

    Returns:
        tanh output of shape [batch, 28, 28, output_dim].
    """
    def maybe_dropout(x):
        # Dropout is a training-time regulariser. At inference it is skipped,
        # consistent with batch-norm's inference mode, so samples depend only
        # on the noise and the code.
        return tf.nn.dropout(x, keep_prob=keep_prob) if is_train else x

    # NOTE: the layer creation order below fixes the auto-generated variable
    # names, so it must not change, or existing checkpoints will not restore.
    with tf.variable_scope("generator", reuse=(not is_train)):
        # Fully connected: (noise_dim + code_dim) -> 4*4*512
        noise_img_ = tf.concat([noise_img, condition_label], 1)
        layer1 = tf.layers.dense(noise_img_, 4*4*512)
        layer1 = tf.reshape(layer1, [-1, 4, 4, 512])
        layer1 = tf.layers.batch_normalization(layer1, training=is_train)
        layer1 = maybe_dropout(tf.nn.relu(layer1))

        # 4*4*512 -> 7*7*256 (kernel 4, stride 1, 'valid': 4 + 4 - 1 = 7)
        layer2 = tf.layers.conv2d_transpose(layer1, 256, 4, strides=1, padding='valid')
        layer2 = tf.layers.batch_normalization(layer2, training=is_train)
        layer2 = maybe_dropout(tf.nn.relu(layer2))

        # 7*7*256 -> 14*14*128
        layer3 = tf.layers.conv2d_transpose(layer2, 128, 3, strides=2, padding='same')
        layer3 = tf.layers.batch_normalization(layer3, training=is_train)
        layer3 = maybe_dropout(tf.nn.relu(layer3))

        # 14*14*128 -> 28*28*output_dim; tanh keeps pixel values in [-1, 1]
        logits = tf.layers.conv2d_transpose(layer3, output_dim, 3, strides=2, padding='same')
        outputs = tf.tanh(logits)
        return outputs
def discriminator(inputs_img, reuse=False, alpha=0.01):
    """Score images as real or fake.

    Applies three stride-2 conv blocks, each conv -> (batch norm) -> leaky
    ReLU -> dropout(0.8), then a single-unit dense layer. Batch norm always
    uses batch statistics (training=True), and the first block has no batch norm.

    Args:
        inputs_img: image batch. The final reshape requires a 28x28 spatial
            size (28 -> 14 -> 7 -> 4).
        reuse: reuse the variables already in the "discriminator" scope.
        alpha: leaky-ReLU negative slope.

    Returns:
        (logits, sigmoid(logits)), each of shape [batch, 1].
    """
    def leaky(x):
        return tf.maximum(alpha * x, x)

    # (filters, use batch norm): 28x28 -> 14x14x128 -> 7x7x256 -> 4x4x512
    conv_specs = ((128, False), (256, True), (512, True))

    with tf.variable_scope("discriminator", reuse=reuse):
        net = inputs_img
        for n_filters, use_bn in conv_specs:
            net = tf.layers.conv2d(net, n_filters, 3, strides=2, padding='same')
            if use_bn:
                net = tf.layers.batch_normalization(net, training=True)
            net = tf.nn.dropout(leaky(net), keep_prob=0.8)

        logits = tf.layers.dense(tf.reshape(net, (-1, 4*4*512)), 1)
        return logits, tf.sigmoid(logits)
def get_Q(g_out, reuse=False, alpha=0.01):
    """Q network: predict the 10-way code from an image.

    Has its own weights in the "Q" variable scope. Each of the three stride-2
    conv blocks is conv -> (batch norm) -> leaky ReLU -> dropout(0.8), and a
    10-unit softmax head follows.

    Args:
        g_out: image batch (the generator's output when called from
            get_loss). The final reshape requires a 28x28 spatial size.
        reuse: reuse the variables already in the "Q" scope.
        alpha: leaky-ReLU negative slope.

    Returns:
        Softmax probabilities of shape [batch, 10].
    """
    def conv_block(x, filters, batch_norm):
        # Stride-2 conv halves the spatial size; BN uses batch statistics.
        y = tf.layers.conv2d(x, filters, 3, strides=2, padding='same')
        if batch_norm:
            y = tf.layers.batch_normalization(y, training=True)
        y = tf.maximum(alpha * y, y)
        return tf.nn.dropout(y, keep_prob=0.8)

    with tf.variable_scope("Q", reuse=reuse):
        h = conv_block(g_out, 128, batch_norm=False)  # 28x28 -> 14x14x128, no BN on first layer
        h = conv_block(h, 256, batch_norm=True)       # 14x14x128 -> 7x7x256
        h = conv_block(h, 512, batch_norm=True)       # 7x7x256 -> 4x4x512
        code_logits = tf.layers.dense(tf.reshape(h, (-1, 4*4*512)), 10)
        return tf.nn.softmax(code_logits)
def get_loss(inputs_real, inputs_noise, condition_label, image_depth, smooth=0.1):
    """Build the generator, discriminator and Q losses.

    Creates the generator in training mode. The discriminator is applied to
    real and generated images, with the second call reusing its variables.
    The Q network is applied to the generated images.

    Args:
        inputs_real: batch of real images.
        inputs_noise: noise fed to the generator.
        condition_label: one-hot code fed to the generator; it is also the
            target for the Q network.
        image_depth: number of image channels produced by the generator.
        smooth: one-sided label smoothing. The discriminator's target for
            real images is ``1 - smooth`` instead of 1.

    Returns:
        (g_loss, d_loss, q_loss) scalar tensors.
    """
    g_outputs = generator(inputs_noise, image_depth, condition_label, is_train=True)
    d_logits_real, _ = discriminator(inputs_real)
    d_logits_fake, _ = discriminator(g_outputs, reuse=True)
    q_c = get_Q(g_outputs)

    # Smoothing is applied to the real *targets*. Multiplying the loss by
    # (1 - smooth) instead would only rescale the gradient.
    real_targets = tf.ones_like(d_logits_real) * (1 - smooth)
    d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logits_real, labels=real_targets))
    d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logits_fake, labels=tf.zeros_like(d_logits_fake)))
    d_loss = tf.add(d_loss_real, d_loss_fake)

    # The generator wants D to call its images real (target 1); one-sided
    # smoothing is not applied to the generator.
    g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logits_fake, labels=tf.ones_like(d_logits_fake)))

    # Cross-entropy between the one-hot code and Q's predicted distribution;
    # the 1e-8 guards against log(0).
    q_loss = tf.reduce_mean(-tf.reduce_sum(tf.log(q_c + 1e-8) * condition_label, 1))

    return g_loss, d_loss, q_loss
def get_optimizer(g_loss, d_loss, q_loss, betal=0.4, learning_rate=0.001):
    """Create one Adam training op per loss.

    Variables are selected by name prefix: "generator", "discriminator", "Q".
    The Q loss updates both the generator and the Q network. Every op depends
    on the UPDATE_OPS collection (e.g. batch-norm moving averages).

    Args:
        g_loss, d_loss, q_loss: scalar loss tensors.
        betal: Adam beta1.
        learning_rate: Adam learning rate.

    Returns:
        (g_opt, d_opt, q_opt) training ops.
    """
    all_vars = tf.trainable_variables()
    by_scope = {
        prefix: [v for v in all_vars if v.name.startswith(prefix)]
        for prefix in ("generator", "discriminator", "Q")
    }
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    def adam():
        return tf.train.AdamOptimizer(learning_rate, beta1=betal)

    # Creation order (g, d, q) determines Adam's variable names; keep it.
    with tf.control_dependencies(update_ops):
        g_opt = adam().minimize(g_loss, var_list=by_scope["generator"])
        d_opt = adam().minimize(d_loss, var_list=by_scope["discriminator"])
        q_opt = adam().minimize(q_loss, var_list=by_scope["generator"] + by_scope["Q"])
    return g_opt, d_opt, q_opt
def plot_images(samples):
    """Draw up to 25 images in a single row.

    ``samples`` is rescaled with (x + 1) / 2 (from tanh range [-1, 1] to
    [0, 1]), and each image is reshaped to 28x28 and shown in greyscale
    with the axes hidden.
    """
    rescaled = (samples + 1) / 2
    fig, axes = plt.subplots(nrows=1, ncols=25, sharex=True, sharey=True, figsize=(50, 2))
    for ax, img in zip(axes, rescaled):
        ax.imshow(img.reshape(28, 28), cmap='Greys_r')
        for axis in (ax.get_xaxis(), ax.get_yaxis()):
            axis.set_visible(False)
    fig.tight_layout(pad=0)
def show_generator_output(sess, n_images, inputs_noise, output_dim, condition_label):
    """Generate ``n_images`` samples with the trained generator.

    The codes are a fixed slice of MNIST training labels, starting at row 50,
    so previews are comparable across calls.

    Args:
        sess: session holding the trained variables.
        n_images: number of images to generate.
        inputs_noise: noise placeholder; its last dimension is the noise size.
        output_dim: number of image channels.
        condition_label: code placeholder.

    Returns:
        numpy array of samples with the trailing channel axis squeezed out.
    """
    noise_dim = inputs_noise.get_shape().as_list()[-1]

    examples_noise = np.random.uniform(-1, 1, size=[n_images, noise_dim])
    # Slice exactly n_images rows so the code batch matches the noise batch.
    # A fixed 25-row slice breaks the concat for any other n_images.
    condition_label_test = mnist.train.labels[50:50 + n_images]
    samples = sess.run(generator(inputs_noise, output_dim, condition_label, False),
                       feed_dict={inputs_noise: examples_noise,
                                  condition_label: condition_label_test})
    return np.squeeze(samples, -1)
# Training hyper-parameters. train() reads epochs, learning_rate and betal
# directly from module scope.
batch_size, noise_size = 64, 100   # images per step, length of the noise vector
epochs, n_samples = 5, 25          # passes over the training set, sample count passed to train()
learning_rate, betal = 0.001, 0.4  # Adam learning rate and beta1

def train(noise_size, data_shape, batch_size, n_samples):
    """Build the InfoGAN graph and train it on MNIST.

    Reads the module-level ``epochs``, ``betal``, ``learning_rate`` and
    ``mnist``. Every 10 steps it records the losses for the current batch,
    plots generated samples and prints the losses. After each epoch it saves
    a checkpoint to ./model6.ckpt.

    Args:
        noise_size: length of the noise vector.
        data_shape: [-1, height, width, depth] of one image batch.
        batch_size: images per training step.
        n_samples: number of generated images passed to the plot.

    Returns:
        list of (d_loss, g_loss, q_loss) tuples, one per logging step.
    """
    losses = []
    steps = 0
    _, height, width, depth = data_shape

    inputs_real, inputs_noise, condition_label = get_inputs(noise_size, height, width, depth)
    g_loss, d_loss, q_loss = get_loss(inputs_real, inputs_noise, condition_label, depth)
    g_train_opt, d_train_opt, q_train_opt = get_optimizer(g_loss, d_loss, q_loss, betal, learning_rate)
    # Build the inference-mode sampler once and feed it through the
    # placeholders. Calling generator() inside the loop would add a fresh copy
    # of the generator's ops to the graph at every preview, so the graph would
    # grow without bound.
    sample_op = generator(inputs_noise, depth, condition_label, is_train=False)

    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for e in range(epochs):
            for _ in range(mnist.train.num_examples // batch_size):
                steps += 1
                batch_images, _labels = mnist.train.next_batch(batch_size)

                # Map pixels to [-1, 1] (assuming MNIST's [0, 1] floats) to
                # match the generator's tanh output.
                batch_images = batch_images.reshape(batch_size, height, width, depth) * 2 - 1

                batch_noise = np.random.uniform(-1, 1, size=(batch_size, noise_size))
                # Categorical code: one-hot, sampled uniformly over 10 classes.
                c_labels = np.random.multinomial(1, 10 * [0.1], size=batch_size)
                feed = {inputs_real: batch_images,
                        inputs_noise: batch_noise,
                        condition_label: c_labels}

                sess.run(g_train_opt, feed_dict=feed)
                sess.run(d_train_opt, feed_dict=feed)
                sess.run(q_train_opt, feed_dict=feed)

                if steps % 10 == 0:
                    # One forward pass for all three losses.
                    train_loss_d, train_loss_g, train_loss_q = sess.run(
                        [d_loss, g_loss, q_loss], feed_dict=feed)
                    losses.append((train_loss_d, train_loss_g, train_loss_q))

                    samples = sess.run(sample_op, feed_dict={inputs_noise: batch_noise,
                                                             condition_label: c_labels})
                    plot_images(samples[:n_samples])
                    print("Epoch {}/{}....".format(e+1, epochs),
                          "Discriminator Loss: {:.4f}....".format(train_loss_d),
                          "Generator Loss: {:.4f}....".format((train_loss_g)))
            saver.save(sess, "./model6.ckpt")
    return losses
# Train inside a fresh graph so the model's variables live in their own graph
# rather than in the default graph.
with tf.Graph().as_default():
    train(noise_size=noise_size,
          data_shape=[-1, 28, 28, 1],
          batch_size=batch_size,
          n_samples=n_samples)
Epoch 1/5.... Discriminator Loss: 0.0251.... Generator Loss: 3.8968....
Epoch 1/5.... Discriminator Loss: 0.0754.... Generator Loss: 4.3442....
Epoch 1/5.... Discriminator Loss: 1.8414.... Generator Loss: 12.9266....
Epoch 1/5.... Discriminator Loss: 0.0510.... Generator Loss: 6.1210....
Epoch 1/5.... Discriminator Loss: 0.0036.... Generator Loss: 6.3787....
Epoch 1/5.... Discriminator Loss: 0.0214.... Generator Loss: 6.2205....
Epoch 1/5.... Discriminator Loss: 0.0467.... Generator Loss: 11.2260....
# Rebuild the same graph so variable names match the checkpoint, then
# restore the most recent checkpoint from the current directory.
data_shape = [-1, 28, 28, 1]
inputs_real, inputs_noise, condition_label = get_inputs(noise_size, data_shape[1], data_shape[2], data_shape[3])
g_loss, d_loss, q_loss = get_loss(inputs_real, inputs_noise, condition_label, data_shape[-1])
g_train_opt, d_train_opt, q_train_opt = get_optimizer(g_loss, d_loss, q_loss, betal, learning_rate)
saver = tf.train.Saver()
sess = tf.Session()
saver.restore(sess, tf.train.latest_checkpoint('./'))

# Generate 25 images that all use categorical code 3.
batch_size = 25
noise_size = 100
batch_noise = np.random.uniform(-1, 1, size=(batch_size, noise_size))
# The shape must be a tuple: np.zeros(batch_size, 10) treats 10 as the dtype
# and raises TypeError.
c_labels = np.zeros((batch_size, 10))
c_labels[:, 3] = 1
c_labels = tf.cast(c_labels, tf.float32)
batch_noise = tf.cast(batch_noise, tf.float32)
samples = sess.run(generator(batch_noise, 1, c_labels, False))
plot_images(samples)