SSGAN

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
# Huber loss: quadratic for small residuals, linear beyond |residual| = delta
def huber_loss(labels, predictions, delta=1.0):
    residual = tf.abs(predictions - labels)
    condition = tf.less(residual, delta)
    small_res = 0.5 * tf.square(residual)
    large_res = delta * residual - 0.5 * tf.square(delta)
    return tf.where(condition, small_res, large_res)
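The Huber loss stays quadratic for residuals smaller than delta and grows only linearly beyond it, which keeps the extra reconstruction term used later from blowing up on outliers. A small NumPy sanity check of the piecewise behaviour (illustrative only, mirroring the TF version above):

import numpy as np

def huber_np(labels, predictions, delta=1.0):
    # same piecewise definition as huber_loss above, but in NumPy
    residual = np.abs(predictions - labels)
    small_res = 0.5 * residual ** 2
    large_res = delta * residual - 0.5 * delta ** 2
    return np.where(residual < delta, small_res, large_res)

print(huber_np(np.array([0.0, 0.0]), np.array([0.5, 3.0])))   # [0.125  2.5]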
def generator(noise_img, is_train=True):
    with tf.variable_scope("generator", reuse=(not is_train)):
        # 100 x 1 to 4 x 4 x 512
        # fully connected layer
        layer1 = tf.layers.dense(noise_img, 4*4*512)
        layer1 = tf.reshape(layer1, [-1, 4, 4, 512])
        # batch normalization
        layer1 = tf.layers.batch_normalization(layer1, training=is_train)
        layer1 = tf.nn.relu(layer1)
        # dropout
        layer1 = tf.nn.dropout(layer1, keep_prob=0.8)
        
        # 4 x 4 x 512 to 7 x 7 x 256
        layer2 = tf.layers.conv2d_transpose(layer1, 256, 4, strides=1, padding='valid')
        layer2 = tf.layers.batch_normalization(layer2, training=is_train)
        layer2 = tf.nn.relu(layer2)
        layer2 = tf.nn.dropout(layer2, keep_prob=0.8)
        
        # 7 x 7 x 256 to 14 x 14 x 128
        layer3 = tf.layers.conv2d_transpose(layer2, 128, 3, strides=2, padding='same')
        layer3 = tf.layers.batch_normalization(layer3, training=is_train)
        layer3 = tf.nn.relu(layer3)
        layer3 = tf.nn.dropout(layer3, keep_prob=0.8)
        
        
        # 14 x 14 x 128 to 28 x 28 x 1
        logits = tf.layers.conv2d_transpose(layer3, 1, 3, strides=2, padding='same')
        
        outputs = tf.tanh(logits)
        return outputs
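A quick shape check of the generator pipeline (an illustrative sketch, built in a throwaway graph and only inspecting static shapes):

with tf.Graph().as_default():
    z = tf.placeholder(tf.float32, [None, 100])
    fake = generator(z, is_train=True)
    print(fake.shape)   # (?, 28, 28, 1): 4x4x512 -> 7x7x256 -> 14x14x128 -> 28x28x1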
def discriminator(inputs_img, reuse=False, alpha=0.2):
    with tf.variable_scope("discriminator", reuse=reuse):
        # 28 x 28 x 1 to 14 x 14 x 128
        # no batch normalization on the first layer
        layer1 = tf.layers.conv2d(inputs_img, 128, 3, strides=2, padding='same')
        layer1 = tf.maximum(alpha * layer1, layer1)
        layer1 = tf.nn.dropout(layer1, keep_prob=0.8)
        
        # 14 x 14 x 128 to 7 x 7 x 256
        layer2 = tf.layers.conv2d(layer1, 256, 3, strides=2, padding='same')
        layer2 = tf.layers.batch_normalization(layer2, training=True)
        layer2 = tf.maximum(alpha * layer2, layer2)
        layer2 = tf.nn.dropout(layer2, keep_prob=0.8)
        
        # 7 x 7 x 256 to 4 x 4 x 512
        layer3 = tf.layers.conv2d(layer2, 512, 3, strides=2, padding='same')
        layer3 = tf.layers.batch_normalization(layer3, training=True)
        layer3 = tf.maximum(alpha * layer3, layer3)
        layer3 = tf.nn.dropout(layer3, keep_prob=0.8)
        
        # flatten 4 x 4 x 512 to 8192, then dense to 11 logits (10 digit classes + 1 "fake" class)
        flatten = tf.reshape(layer3, (-1, 4*4*512))
        logits = tf.layers.dense(flatten, 11)
        outputs = tf.nn.softmax(logits)
        
        return logits, outputs
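In the semi-supervised setup the discriminator doubles as a classifier: the 11 outputs are the 10 MNIST digits plus one "fake" class. A hypothetical usage sketch (the names here are illustrative, not part of the original post):

with tf.Graph().as_default():
    imgs = tf.placeholder(tf.float32, [None, 28, 28, 1])
    logits, probs = discriminator(imgs, reuse=False)
    pred_digit = tf.argmax(logits[:, :10], axis=1)   # predicted digit class
    p_fake = probs[:, -1]                            # probability the image is generated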
def get_loss(inputs_real, inputs_noise, input_label_real, input_label_fake, smooth=0.1):
    g_outputs = generator(inputs_noise, is_train=True)
    d_logits_real, d_outputs_real = discriminator(inputs_real)
    d_logits_fake, d_outputs_fake = discriminator(g_outputs, reuse=True)
    
    # compute losses
    # generator: minimize the log-probability that D assigns its samples to the "fake" class,
    # plus a small Huber reconstruction term against the real batch
    g_loss = tf.reduce_mean(tf.log(d_outputs_fake[:, -1]))
    g_loss += tf.reduce_mean(huber_loss(inputs_real, g_outputs)) * 0.0001
    #g_loss_l2 = tf.reduce_mean(tf.square(g_outputs - inputs_real))
    #g_loss = g_loss_ + g_loss_l2
    
    # discriminator: 11-way cross-entropy against the real targets (one-hot digit + 0 for "fake")
    # and the fake targets (mass concentrated on the last, "fake", class)
    d_loss_real = tf.nn.softmax_cross_entropy_with_logits_v2(logits=d_logits_real, labels=input_label_real)
    d_loss_fake = tf.nn.softmax_cross_entropy_with_logits_v2(logits=d_logits_fake, labels=input_label_fake)
    d_loss = tf.reduce_mean(d_loss_real + d_loss_fake)
    
    return g_loss, d_loss
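The two label placeholders carry 11-way targets: a real image keeps its one-hot digit label with a 0 appended for the extra "fake" class, while a generated image puts most of the mass (alpha = 0.9) on the fake class and spreads the rest evenly over the 10 digits. This is exactly what train() builds further down; a small NumPy illustration with a hypothetical example digit:

alpha = 0.9
digit = np.array([0, 0, 0, 1, 0, 0, 0, 0, 0, 0], dtype=np.float32)       # a labelled "3"
real_target = np.concatenate([digit, [0.0]])                              # one-hot digit + 0 for "fake"
fake_target = np.concatenate([(1 - alpha) * np.ones(10) / 10, [alpha]])   # ~0.01 per digit, 0.9 on "fake"
print(real_target)
print(fake_target)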
def get_optimizer(g_loss, d_loss, beta1=0.4, learning_rate=0.001):
    train_vars = tf.trainable_variables()
    
    g_vars = [var for var in train_vars if var.name.startswith("generator")]
    d_vars = [var for var in train_vars if var.name.startswith("discriminator")]
    
    # Optimizer: run the update steps under UPDATE_OPS so batch-norm moving statistics are refreshed
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        g_opt = tf.train.AdamOptimizer(learning_rate*5, beta1=beta1).minimize(g_loss, var_list=g_vars)
        d_opt = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(d_loss, var_list=d_vars)
    
    return g_opt, d_opt
def plot_images(samples):
    samples = (samples + 1) / 2
    fig, axes = plt.subplots(nrows=1, ncols=10, sharex=True, sharey=True, figsize=(10,1))
    for img, ax in zip(samples, axes):
        ax.imshow(img.reshape(28,28))
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    fig.tight_layout(pad=0)
def show_generator_output(sess, noise_image):
    samples = sess.run(generator(noise_image, False))
    #samples = sess.run(tf.reshape(samples, [-1, 28, 28, 1]))
    return samples
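Note that show_generator_output calls generator() while the session is open, so every call appends new ops to the graph. A common alternative (a hypothetical helper, not part of the original code) builds the sampling path once, before the Session, and feeds noise through a placeholder:

def build_sampler(noise_size=100):
    # hypothetical helper: build the sampling path a single time so repeated
    # sampling during training does not keep growing the graph
    sample_noise = tf.placeholder(tf.float32, [None, noise_size], name='sample_noise')
    sample_images = generator(sample_noise, is_train=False)   # reuses the trained generator weights
    return sample_noise, sample_images

# usage inside the session (sketch):
#   samples = sess.run(sample_images, feed_dict={sample_noise: np.random.uniform(-1, 1, (10, noise_size))})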
# define hyperparameters
b_size = 128
noise_size = 100
epochs = 50
n_samples = 25
learning_rate = 0.001
beta1 = 0.6
def train(noise_size, b_size, n_samples):
    
    # store losses
    losses = []
    steps = 0
    
    inputs_real = tf.placeholder(tf.float32, [None, 28, 28, 1], name='inputs_real')
    inputs_noise = tf.placeholder(tf.float32, [None, noise_size], name='inputs_noise')
    input_label_real = tf.placeholder(tf.float32, [None, 11], name='input_label_real')
    input_label_fake = tf.placeholder(tf.float32, [None, 11], name='input_label_fake')
    
#    inputs_real, inputs_noise, input_label_real, input_label_fake = get_inputs(noise_size, 32, 32, 3)
    g_loss, d_loss = get_loss(inputs_real, inputs_noise, input_label_real, input_label_fake)
    g_train_opt, d_train_opt = get_optimizer(g_loss, d_loss, beta1, learning_rate)

    saver = tf.train.Saver()
    #model_file=tf.train.latest_checkpoint('./')
    with tf.Session() as sess:
        #saver.restore(sess, model_file)
        sess.run(tf.global_variables_initializer())
        
        # iterate over epochs
        for epoch in range(epochs):
            for batch_i in range(mnist.train.num_examples//b_size):
                steps += 1
                batch_images_, batch_labels = mnist.train.next_batch(b_size)

                batch_images = batch_images_.reshape((b_size, 28, 28, 1))
                # rescale from [0, 1] to [-1, 1] to match the generator's tanh output
                batch_images = batch_images * 2 - 1
            
                
                # real target: one-hot digit label plus a 0 for the extra "fake" class
                # fake target: alpha on the "fake" class, the rest spread evenly over the 10 digits
                # (built in NumPy so the loop does not keep adding ops to the graph)
                alpha = 0.9
                real_label = np.concatenate([batch_labels, np.zeros([b_size, 1])], axis=1)
                fake_label = np.concatenate([(1 - alpha) * np.ones([b_size, 10]) / 10,
                                             alpha * np.ones([b_size, 1])], axis=1)
                
                # noise
                batch_noise = np.random.uniform(-1, 1, size=(b_size, noise_size))

                # run optimizer
                _ = sess.run(g_train_opt, feed_dict={inputs_real: batch_images,
                                                     inputs_noise: batch_noise,
                                                     input_label_real: real_label,
                                                     input_label_fake: fake_label
                                                     })
                _ = sess.run(d_train_opt, feed_dict={inputs_real: batch_images,
                                                     inputs_noise: batch_noise,
                                                     input_label_real: real_label,
                                                     input_label_fake: fake_label
                                                     })
            if epoch % 5 == 0:
                #saver.save(sess, "./mode7.ckpt")
                train_loss_d = d_loss.eval({inputs_real: batch_images,
                                            inputs_noise: batch_noise,
                                            input_label_real: real_label,
                                            input_label_fake: fake_label})
                train_loss_g = g_loss.eval({inputs_real: batch_images,
                                            inputs_noise: batch_noise,
                                            input_label_real: real_label,
                                            input_label_fake: fake_label})
                losses.append((train_loss_d, train_loss_g))
                # sample and display images
                batch_noise = np.random.uniform(-1, 1, size=(10, noise_size))
                batch_noise = tf.cast(batch_noise, tf.float32)
                samples = show_generator_output(sess, batch_noise)
                plot_images(samples)
                print("Epoch {}/{}....".format(epoch+1, epochs),
                      "Discriminator Loss: {:.4f}....".format(train_loss_d),
                      "Generator Loss: {:.4f}....".format(train_loss_g))
        saver.save(sess, "./model6.ckpt")
with tf.Graph().as_default():
    train(noise_size, b_size, n_samples)
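After training, the checkpoint saved above can be restored for sampling. A minimal sketch, assuming the generator() and plot_images() definitions above are still in scope (not part of the original post):

with tf.Graph().as_default():
    sample_noise = tf.placeholder(tf.float32, [None, noise_size])
    # is_train=True only so reuse is False and fresh variables are created in this new graph
    sample_images = generator(sample_noise, is_train=True)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, "./model6.ckpt")
        imgs = sess.run(sample_images,
                        feed_dict={sample_noise: np.random.uniform(-1, 1, size=(10, noise_size))})
        plot_images(imgs)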