持续创作,加速成长!这是我参与「掘金日新计划 · 10 月更文挑战」的第12天,点击查看活动详情
生成对抗网络(GAN)与前面介绍过的resnet和unet类似,也是一种网络模型,它主要由两部分构成,分别是生成器和判别器。
首先阐述一下什么是生成器,其结构就是一个图片生成模型,它根据随机的噪声来生成一组图片。
然后是判别器,判别器根据生成器生成的图片进行一个度量,也就是看一下与ground truth的差异,也就是打假,如果两幅图比较后输出为1,表明这就是真实图片,否则就给出一个小于1的概率,认为这是生成的图片与真实图片的相似度。
然后,对于整个网络的训练,我们使用下面的损失函数进行求解,仅需要找到损失函数的最小值,即此时,生成器生成的图片,判别器已经无法区分两者的真假。
关于对于损失函数的相关推导,本博客就不详细的叙述了,可以在GAN提出的论文中进行阅读,在论文中作者也给出了详细的说明。 下面通过一组pytorch代码来了解生成器和判别器的代码
首先是生成器
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
# 定义了一个网络块
def block(in_feat, out_feat, normalize=True):
# 线性层
layers = [nn.Linear(in_feat, out_feat)]
# 归一化
if normalize:
layers.append(nn.BatchNorm1d(out_feat, 0.8))
# leakrelu层
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers
# 构造了一个容器,用于存放每一层网络
self.model = nn.Sequential(
*block(opt.latent_dim, 128, normalize=False),
*block(128, 256),
*block(256, 512),
*block(512, 1024),
nn.Linear(1024, int(np.prod(img_shape))),
nn.Tanh()
)
# 网络的传播
def forward(self, z):
img = self.model(z)
img = img.view(img.shape[0], *img_shape)
return img
下面是判别器层
def discriminator_model(nn.Module):
def __init__(self):
super(discriminator_model, self).__init__()
self.model = nn.Sequential(
Conv2D(64, (5, 5), padding='same', input_shape=(64, 64, 3))
Activation('tanh')
MaxPooling2D(pool_size=(2, 2))
Conv2D(128, (5, 5))
Activation('tanh')
MaxPooling2D(pool_size=(2, 2))
Flatten())
Dense(1024))
Activation('tanh')
Dense(1))
Activation('sigmoid')
)
def forward(self, z):
img = self.model(z)
img = img.view(img.shape[0], *img_shape)
return img
return model
对于生成器和判别器的代码结构都相对清晰,构建过程没有复杂的结构和方法,只需要在使用时了解其结构和原理即可。本文讲解的主要为最原始的gan结构,目前随着各个GAN版本的出现和修改,各种改进后的版本层出不穷,需要的也可以根据提出者发表的论文进行修改和复现,以适应自己的任务。