Friends who don't have a girlfriend yet, you're in luck: learn CycleGAN and turn a boyfriend into a girlfriend


```python
def load_image_test(image_file):
    image = load(image_file)
    image = preprocess_image_test(image)
    return image
```

Load the male images and build the first training dataset:

```python
train_man = tf.data.Dataset.list_files('./man2woman/trainA/*.jpg')
train_man = train_man.map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_man = train_man.shuffle(BUFFER_SIZE)
train_man = train_man.batch(BATCH_SIZE, drop_remainder=True)
```

Load the female images and build the second training dataset:

```python
train_woman = tf.data.Dataset.list_files('./man2woman/trainB/*.jpg')
train_woman = train_woman.map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_woman = train_woman.shuffle(BUFFER_SIZE)
train_woman = train_woman.batch(BATCH_SIZE, drop_remainder=True)
```
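The helpers `load`, `load_image_train` and `preprocess_image_test`, along with constants such as `BUFFER_SIZE`, `BATCH_SIZE`, `IMG_HEIGHT` and `IMG_WIDTH`, are defined earlier in the article and are not shown in this section. For reference, here is a minimal sketch of what they are assumed to look like (a standard CycleGAN-style pipeline: decode, random jitter, normalize to [-1, 1]); the exact constant values are assumptions:

```python
import tensorflow as tf

# Assumed constants (values follow the common CycleGAN setup).
BUFFER_SIZE = 1000
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256

def load(image_file):
    # Read a JPEG file and decode it into a float32 tensor.
    image = tf.io.read_file(image_file)
    image = tf.image.decode_jpeg(image, channels=3)
    return tf.cast(image, tf.float32)

def normalize(image):
    # Scale pixel values to [-1, 1], matching the tanh output of the generator.
    return (image / 127.5) - 1

def random_jitter(image):
    # Resize up, randomly crop back to the target size, and randomly flip (data augmentation).
    image = tf.image.resize(image, [286, 286],
                            method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    image = tf.image.random_crop(image, size=[IMG_HEIGHT, IMG_WIDTH, 3])
    image = tf.image.random_flip_left_right(image)
    return image

def load_image_train(image_file):
    image = load(image_file)
    image = random_jitter(image)
    return normalize(image)

def preprocess_image_test(image):
    # Test images are only resized and normalized, without random jitter.
    image = tf.image.resize(image, [IMG_HEIGHT, IMG_WIDTH])
    return normalize(image)
```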

Building the model

CycleGAN uses instance normalization in place of batch normalization, but core TensorFlow does not ship an instance normalization layer, so we implement one ourselves.

```python
class InstanceNormalization(tf.keras.layers.Layer):
    """Instance Normalization Layer."""

    def __init__(self, epsilon=1e-5):
        super(InstanceNormalization, self).__init__()
        self.epsilon = epsilon

    def build(self, input_shape):
        self.scale = self.add_weight(
            name='scale',
            shape=input_shape[-1:],
            initializer=tf.random_normal_initializer(1., 0.02),
            trainable=True)
        self.offset = self.add_weight(
            name='offset',
            shape=input_shape[-1:],
            initializer='zeros',
            trainable=True)

    def call(self, x):
        # Normalize each sample over its spatial dimensions (H, W), per channel.
        mean, variance = tf.nn.moments(x, axes=[1, 2], keepdims=True)
        inv = tf.math.rsqrt(variance + self.epsilon)
        normalized = (x - mean) * inv
        return self.scale * normalized + self.offset
```
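As a quick sanity check (not part of the original article), the layer can be applied to a random tensor; since the scale is initialized near 1 and the offset at 0, the per-sample spatial statistics of the output should be close to mean 0 and variance 1:

```python
# Illustrative sanity check of the custom InstanceNormalization layer.
x = tf.random.normal([1, 8, 8, 4])
layer = InstanceNormalization()
y = layer(x)
print(y.shape)  # (1, 8, 8, 4)

mean, var = tf.nn.moments(y, axes=[1, 2])
print(mean.numpy(), var.numpy())  # mean ≈ 0, variance ≈ 1 per channel
```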

To keep the code compact, we define reusable downsampling and upsampling blocks.

Downsampling block:

```python
def downsample(filters, size, norm_type='batchnorm', apply_norm=True):
    initializer = tf.random_normal_initializer(0., 0.02)

    result = tf.keras.Sequential()
    result.add(
        tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',
                               kernel_initializer=initializer, use_bias=False))

    if apply_norm:
        if norm_type.lower() == 'batchnorm':
            result.add(tf.keras.layers.BatchNormalization())
        elif norm_type.lower() == 'instancenorm':
            result.add(InstanceNormalization())

    result.add(tf.keras.layers.LeakyReLU())

    return result
```

Upsampling block:

```python
def upsample(filters, size, norm_type='batchnorm', apply_dropout=False):
    initializer = tf.random_normal_initializer(0., 0.02)

    result = tf.keras.Sequential()
    result.add(
        tf.keras.layers.Conv2DTranspose(filters, size, strides=2,
                                        padding='same',
                                        kernel_initializer=initializer,
                                        use_bias=False))

    if norm_type.lower() == 'batchnorm':
        result.add(tf.keras.layers.BatchNormalization())
    elif norm_type.lower() == 'instancenorm':
        result.add(InstanceNormalization())

    if apply_dropout:
        result.add(tf.keras.layers.Dropout(0.5))

    result.add(tf.keras.layers.ReLU())

    return result
```
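A quick shape check (illustrative, not from the original article): the stride-2 convolution in `downsample` halves the spatial resolution, while the transposed convolution in `upsample` doubles it:

```python
# Illustrative shape check of the two building blocks.
x = tf.random.normal([1, 256, 256, 3])

down = downsample(64, 4, norm_type='instancenorm')(x)   # stride-2 conv halves H and W
print(down.shape)                                        # (1, 128, 128, 64)

up = upsample(3, 4, norm_type='instancenorm')(down)      # transposed conv doubles H and W
print(up.shape)                                          # (1, 256, 256, 3)
```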

Next, build the generator:

```python
def unet_generator(output_channels, norm_type='batchnorm'):
    down_stack = [
        downsample(64, 4, norm_type, apply_norm=False),
        downsample(128, 4, norm_type),
        downsample(256, 4, norm_type),
        downsample(512, 4, norm_type),
        downsample(512, 4, norm_type),
        downsample(512, 4, norm_type),
        downsample(512, 4, norm_type),
        downsample(512, 4, norm_type),
    ]

    up_stack = [
        upsample(512, 4, norm_type, apply_dropout=True),
        upsample(512, 4, norm_type, apply_dropout=True),
        upsample(512, 4, norm_type, apply_dropout=True),
        upsample(512, 4, norm_type),
        upsample(256, 4, norm_type),
        upsample(128, 4, norm_type),
        upsample(64, 4, norm_type),
    ]

    initializer = tf.random_normal_initializer(0., 0.02)
    last = tf.keras.layers.Conv2DTranspose(
        output_channels, 4, strides=2,
        padding='same', kernel_initializer=initializer,
        activation='tanh')  # (bs, 256, 256, 3)

    concat = tf.keras.layers.Concatenate()

    inputs = tf.keras.layers.Input(shape=[None, None, 3])
    x = inputs

    # Downsampling through the model
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)

    skips = reversed(skips[:-1])

    # Upsampling and establishing the skip connections
    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = concat([x, skip])

    x = last(x)

    return tf.keras.Model(inputs=inputs, outputs=x)
```
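As an illustration (not in the original article), feeding a 256×256 RGB image through the U-Net generator returns an image of the same size, with values in [-1, 1] thanks to the tanh output layer:

```python
# Illustrative smoke test of the generator on a random 256x256 RGB image.
gen = unet_generator(output_channels=3, norm_type='instancenorm')
fake = gen(tf.random.normal([1, 256, 256, 3]), training=False)

print(fake.shape)  # (1, 256, 256, 3)
print(float(tf.reduce_min(fake)), float(tf.reduce_max(fake)))  # values lie within [-1, 1]
```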

Build the discriminator:

```python
def discriminator(norm_type='batchnorm', target=True):
    initializer = tf.random_normal_initializer(0., 0.02)

    inp = tf.keras.layers.Input(shape=[None, None, 3], name='input_image')
    x = inp

    if target:
        tar = tf.keras.layers.Input(shape=[None, None, 3], name='target_image')
        x = tf.keras.layers.concatenate([inp, tar])  # (bs, 256, 256, channels*2)

    down1 = downsample(64, 4, norm_type, False)(x)   # (bs, 128, 128, 64)
    down2 = downsample(128, 4, norm_type)(down1)     # (bs, 64, 64, 128)
    down3 = downsample(256, 4, norm_type)(down2)     # (bs, 32, 32, 256)

    zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3)  # (bs, 34, 34, 256)
    conv = tf.keras.layers.Conv2D(
        512, 4, strides=1, kernel_initializer=initializer,
        use_bias=False)(zero_pad1)  # (bs, 31, 31, 512)

    if norm_type.lower() == 'batchnorm':
        norm1 = tf.keras.layers.BatchNormalization()(conv)
    elif norm_type.lower() == 'instancenorm':
        norm1 = InstanceNormalization()(conv)

    leaky_relu = tf.keras.layers.LeakyReLU()(norm1)

    zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu)  # (bs, 33, 33, 512)

    last = tf.keras.layers.Conv2D(
        1, 4, strides=1,
        kernel_initializer=initializer)(zero_pad2)  # (bs, 30, 30, 1)

    if target:
        return tf.keras.Model(inputs=[inp, tar], outputs=last)
    else:
        return tf.keras.Model(inputs=inp, outputs=last)
```
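For a 256×256 input, this PatchGAN-style discriminator outputs a 30×30 grid of logits, each one judging a patch of the input as real or fake. An illustrative check (not in the original article):

```python
# Illustrative smoke test: the no-target discriminator maps 256x256 images to 30x30 patch logits.
disc = discriminator(norm_type='instancenorm', target=False)
logits = disc(tf.random.normal([1, 256, 256, 3]), training=False)
print(logits.shape)  # (1, 30, 30, 1)
```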

Instantiate the generators and discriminators:

```python
generator_g = unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
generator_f = unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')

discriminator_x = discriminator(norm_type='instancenorm', target=False)
discriminator_y = discriminator(norm_type='instancenorm', target=False)
```

Define the loss functions and optimizers:

```python
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# Discriminator loss
def discriminator_loss(real, generated):
    real_loss = loss_obj(tf.ones_like(real), real)
    generated_loss = loss_obj(tf.zeros_like(generated), generated)
    total_disc_loss = real_loss + generated_loss
    return total_disc_loss * 0.5

# Generator loss
def generator_loss(generated):
    return loss_obj(tf.ones_like(generated), generated)

# Cycle-consistency loss
def calc_cycle_loss(real_image, cycled_image):
    loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))
    return LAMBDA * loss1

# Identity loss
def identity_loss(real_image, same_image):
    loss = tf.reduce_mean(tf.abs(real_image - same_image))
    return LAMBDA * 0.5 * loss
```
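`LAMBDA`, the weight of the cycle-consistency and identity terms, is defined earlier in the article and not shown in this section; the CycleGAN paper uses 10, so a reasonable assumption is:

```python
# Assumed value, following the CycleGAN paper; the article defines LAMBDA in an earlier section.
LAMBDA = 10
```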

The optimizers:

```python
generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
```

Visualizing training results

Create a generate_images function to inspect the model's output during training.

```python
def generate_images(model, test_input):
    prediction = model(test_input)

    plt.figure(figsize=(12, 12))
    display_list = [test_input[0], prediction[0]]
    title = ['Input Image', 'Predicted Image']

    for i in range(2):
        plt.subplot(1, 2, i + 1)
        plt.title(title[i])
        # Rescale the pixel values from [-1, 1] to [0, 1] for plotting.
        plt.imshow(display_list[i] * 0.5 + 0.5)
        plt.axis('off')

    # Save before showing: plt.show() closes the figure, after which savefig would write a blank image.
    plt.savefig('results/{}.png'.format(int(time.time())))
    plt.show()
```
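Since savefig writes into a `results/` folder, that directory has to exist before training starts; one way to ensure this (not shown in the original article):

```python
import os

# Create the output folder used by generate_images if it does not exist yet.
os.makedirs('results', exist_ok=True)
```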

Training steps

First, define the training step function:

```python
@tf.function
def train_step(real_x, real_y):
    # persistent=True because the tape is used several times to compute different gradients.
    with tf.GradientTape(persistent=True) as tape:
        # Generator G translates X -> Y.
        # Generator F translates Y -> X.
        fake_y = generator_g(real_x, training=True)
        cycled_x = generator_f(fake_y, training=True)

        fake_x = generator_f(real_y, training=True)
        cycled_y = generator_g(fake_x, training=True)

        # same_x and same_y are used for the identity loss.
        same_x = generator_f(real_x, training=True)
        same_y = generator_g(real_y, training=True)

        disc_real_x = discriminator_x(real_x, training=True)
        disc_real_y = discriminator_y(real_y, training=True)

        disc_fake_x = discriminator_x(fake_x, training=True)
        disc_fake_y = discriminator_y(fake_y, training=True)

        # Calculate the losses.
        gen_g_loss = generator_loss(disc_fake_y)
        gen_f_loss = generator_loss(disc_fake_x)

        total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)

        # Total generator loss = adversarial loss + cycle loss + identity loss.
        total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
        total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)

        disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
        disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)

    # Calculate the gradients for the generators and discriminators.
    generator_g_gradients = tape.gradient(total_gen_g_loss, generator_g.trainable_variables)
    generator_f_gradients = tape.gradient(total_gen_f_loss, generator_f.trainable_variables)

    discriminator_x_gradients = tape.gradient(disc_x_loss, discriminator_x.trainable_variables)
    discriminator_y_gradients = tape.gradient(disc_y_loss, discriminator_y.trainable_variables)

    # Apply the gradients with each optimizer.
    generator_g_optimizer.apply_gradients(zip(generator_g_gradients, generator_g.trainable_variables))
    generator_f_optimizer.apply_gradients(zip(generator_f_gradients, generator_f.trainable_variables))

    discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients, discriminator_x.trainable_variables))
    discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients, discriminator_y.trainable_variables))
```

Finally, train the model:

```python
for epoch in range(EPOCHS):
    start = time.time()

    n = 0
    for image_x, image_y in tf.data.Dataset.zip((train_man, train_woman)):
        train_step(image_x, image_y)
        if n % 10 == 0:
            print('.', end='')  # lightweight progress indicator
        n += 1

    # sample_man is a fixed sample batch prepared earlier in the article,
    # so progress can be tracked on the same image every epoch.
    generate_images(generator_g, sample_man)
    print('Time taken for epoch {} is {} sec'.format(epoch + 1, time.time() - start))
```
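Once training has finished, generator_g can be applied to held-out male images to produce the translated results. A minimal sketch (not in the original article; the `testA` folder name and the number of images taken are assumptions):

```python
# Hypothetical inference sketch: translate a few test images with the trained generator_g.
test_man = tf.data.Dataset.list_files('./man2woman/testA/*.jpg')
test_man = test_man.map(lambda f: preprocess_image_test(load(f))).batch(1)

for img in test_man.take(5):
    generate_images(generator_g, img)
```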
