GANs故障模式的识别和监控方法

286 阅读17分钟

生成式对抗网络是两个子网络的组合,它们在训练时相互竞争,以产生真实的数据。生成器网络生成看起来真实的人工数据,而鉴别器网络则识别数据是人工的还是真实的。

虽然GANs是强大的模型,但它们可能是相当难训练的。我们同时训练Generator和Discriminator,但要牺牲彼此的利益。这是一个动态系统,只要一个模型的参数被更新,优化问题的性质就会发生变化,正因为如此,达到收敛会很困难。

训练也会导致GANs无法对完整的分布进行建模,这也被称为模式崩溃

在这篇文章中。

  • 我们将看到如何训练一个稳定的GAN模型
  • 然后将在训练过程中玩一玩,了解模式失败的可能原因。

在过去的几年里,我一直在训练GAN,我观察到,GAN中通常的失败模式是模式崩溃收敛失败,我们将在这篇文章中讨论。

训练一个稳定的GAN网络

为了理解失败(在训练GAN中)是如何发生的,让我们首先训练一个稳定的GAN网络。我们将使用MNIST数据集,我们的目标是使用发生器网络从随机噪声中生成人工手写数字。

生成器将把随机噪声作为输入,输出是大小为28×28的假手写数字。鉴别器将从生成器和地面真相中获取28×28的图像输入,并尝试对它们进行正确分类。

我采用了0.0002的学习率,adam优化器,以及0.5作为adam优化器的动力。

让我们来看看我们的稳定GAN网络的代码。首先,让我们进行必要的导入。

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.optim as optim
import torchvision.datasets as datasets
import numpy as np
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from tqdm import tqdm

import neptune.new as neptune
from neptune.new.types import File

注意,我们将在这个练习中使用PyTorch来训练我们的模型,并使用neptune.ai的仪表盘来进行实验跟踪。这里有一个链接,可以看到我所有的实验。我在colab中运行脚本,Neptune让我非常容易跟踪所有的实验。

在这种情况下,适当的实验跟踪真的很重要,因为损失图和中间图像对识别是否存在故障模式有很大的帮助。另外,你也可以使用matplotlib、sacredTensorBoard等,这取决于你的用例和舒适度。

我们首先初始化Neptune的运行,一旦你在Neptune仪表板上创建了一个项目,你可以得到项目路径和API令牌

run = neptune.init(
   project="project name",
   api_token="You API token",
)

我们保持批次大小为1024,我们将运行100个epochs。潜伏维度被初始化以生成随机数据用于生成器输入。而样本量将用于在每个历时中推断出64个图像,这样我们就可以在每个历时后直观地看到图像的质量。 k是我们打算运行判别器的步骤数。

batch_size = 1024
epochs = 100
sample_size = 64
latent_dim = 128
k = 1
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

transform = transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.5,), (0.5,)),
           ])

现在,我们下载MNIST数据并创建Dataloader对象。

train_data = datasets.MNIST(
   root='../input/data',
   train=True,
   download=True,
   transform=transform
)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

最后,我们定义一些用于训练的超参数,并使用run对象将其传递给Neptune仪表板。

params = {"learning_rate": 0.0002,
         "optimizer": "Adam",
         "optimizer_betas": (0.5, 0.999),
         "latent_dim": latent_dim}

run["parameters"] = params

这就是我们定义生成器和判别器网络的地方。

生成器网络

  • 生成器模型将潜伏空间作为输入,它是一个随机的噪声。
  • 在第一层,我们将潜伏空间(维数为128)改为128个通道的特征空间,每个通道的高度和宽度为7×7。
  • 接下来是两个去卷积层,增加我们特征空间的高度和宽度。
  • 接着是一个带有tanh激活的卷积层,生成一个通道和28×28高度和宽度的图像。
class Generator(nn.Module):
   def __init__(self, latent_space):
       super(Generator, self).__init__()
       self.latent_space = latent_space
       self.fcn = nn.Sequential(
           nn.Linear(in_features=self.latent_space, out_features=128*7*7),
           nn.LeakyReLU(0.2),
       )

       self.deconv = nn.Sequential(
           nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
           nn.LeakyReLU(0.2),

           nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
           nn.LeakyReLU(0.2),

           nn.Conv2d(in_channels=128, out_channels=1, kernel_size=(3, 3), padding=(1, 1)),
           nn.Tanh()
       )

   def forward(self, x):
       x = self.fcn(x)
       x = x.view(-1, 128, 7, 7)
       x = self.deconv(x)
       return x

鉴别器网络

  • 我们的鉴别器网络由两个卷积层组成,从来自生成器的图像和真实图像中生成特征。
  • 然后是一个分类器层,它对图像进行分类,以确定判别器所预测的图像是真的还是假的。
class Discriminator(nn.Module):
   def __init__(self):
       super(Discriminator, self).__init__()
       self.conv = nn.Sequential(
           nn.Conv2d(in_channels=1, out_channels=64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
           nn.LeakyReLU(0.2),

           nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
           nn.LeakyReLU(0.2)
       )
       self.classifier = nn.Sequential(
           nn.Linear(in_features=3136, out_features=1),
           nn.Sigmoid()
       )

   def forward(self, x):
       x = self.conv(x)
       x = x.view(x.size(0), -1)
       x = self.classifier(x)
       return x

现在我们对生成器和鉴别器网络进行初始化,并对优化器和损失函数进行优化。

我们还有一些辅助函数,用于为假图像和真图像创建标签(其中大小为批次大小),并为生成器输入创建噪声函数。

generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)

optim_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optim_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

criterion = nn.BCELoss()
def label_real(size):
   labels = torch.ones(size, 1)
   return labels.to(device)


def label_fake(size):
   labels = torch.zeros(size, 1)
   return labels.to(device)


def create_noise(sample_size, latent_dim):
   return torch.randn(sample_size, latent_dim).to(device)

生成器训练函数

现在我们要训练生成器。

  • 生成器接收随机噪声并给出假图像。
  • 这些假图像然后被送到鉴别器,现在我们将真实标签和鉴别器预测的假图像之间的损失降到最低。
  • 从这个函数中我们将观察到发生器的损失。
def train_generator(optimizer, data_fake):
   b_size = data_fake.size(0)
   real_label = label_real(b_size)
   optimizer.zero_grad()
   output = discriminator(data_fake)
   loss = criterion(output, real_label)
   loss.backward()
   optimizer.step()
   return loss

鉴别器训练函数

我们创建一个函数train_discriminator。

  • 正如我们所知,这个网络在训练时接受来自地面真实(即真实图像)和生成器网络(即假图像)的输入。
  • 我们一个接一个地传递假图像和真图像,计算损失并进行反推。我们将观察两个判别器的损失;真实图像的损失(loss_real)和假图像的损失(loss_fake)。
def train_discriminator(optimizer, data_real, data_fake):
   b_size = data_real.size(0)
   real_label = label_real(b_size)
   fake_label = label_fake(b_size)
   optimizer.zero_grad()
   output_real = discriminator(data_real)
   loss_real = criterion(output_real, real_label)
   output_fake = discriminator(data_fake)
   loss_fake = criterion(output_fake, fake_label)
   loss_real.backward()
   loss_fake.backward()
   optimizer.step()
   return loss_real, loss_fake

GAN模型训练

现在我们有了所有的函数,让我们训练我们的模型,看看观察结果,以确定训练是否稳定。

  • 第一行中的噪声将被用来推断每个历时后的中间图像。我们保持噪音不变,这样我们就可以比较不同历时的图像。
  • 现在,对于每个历时,我们训练判别器k次(在这种情况下是一次,因为k=1),对于每个时间的发生器进行训练。
  • 所有的损失都被记录下来,并被发送到Neptune仪表盘上进行绘制。我们不需要把它们附在一个列表中,使用Neptune dashboard,我们可以即时绘制损失图。它还会将每一步的损失记录在一个.csv文件中。
  • 我已经用[点]上传功能将每个纪元后生成的图像保存在Neptune元数据中。
noise = create_noise(sample_size, latent_dim)
generator.train()
discriminator.train()

for epoch in range(epochs):
   loss_g = 0.0
   loss_d_real = 0.0
   loss_d_fake = 0.0

   # training
   for bi, data in tqdm(enumerate(train_loader), total=int(len(train_data) / train_loader.batch_size)):
       image, _ = data
       image = image.to(device)
       b_size = len(image)
       for step in range(k):
           data_fake = generator(create_noise(b_size, latent_dim)).detach()
           data_real = image
           loss_d_fake_real = train_discriminator(optim_d, data_real, data_fake)
           loss_d_real += loss_d_fake_real[0]
           loss_d_fake += loss_d_fake_real[1]
       data_fake = generator(create_noise(b_size, latent_dim))
       loss_g += train_generator(optim_g, data_fake)

   # inference and observations
   generated_img = generator(noise).cpu().detach()
   generated_img = make_grid(generated_img)
   generated_img = np.moveaxis(generated_img.numpy(), 0, -1)
   run[f'generated_img/{epoch}'].upload(File.as_image(generated_img))
   epoch_loss_g = loss_g / bi
   epoch_loss_d_real = loss_d_real/bi
   epoch_loss_d_fake = loss_d_fake/bi
   run["train/loss_generator"].log(epoch_loss_g)
   run["train/loss_discriminator_real"].log(epoch_loss_d_real)
   run["train/loss_discriminator_fake"].log(epoch_loss_d_fake)

   print(f"Epoch {epoch} of {epochs}")
   print(f"Generator loss: {epoch_loss_g:.8f}, Discriminator loss fake: {epoch_loss_d_fake:.8f}, Discriminator loss real: {epoch_loss_d_real:.8f}")

让我们来看看中间的图像。

第10纪元

Digits generated from a Stable GAN at 10th Epoch

图1 - 第10个纪元时从稳定的GAN中生成的数字 | 来源。来源:作者

这些是在第10纪元生成的64个数字。

第100纪元

Digits generated from a stable GAN at the 100th epoch

图2 - 在第100个纪元时从稳定的GAN中生成的数字|来源:作者。作者

这些是在第100个纪元使用相同的噪声生成的。这些看起来比第10个纪元的图像要好得多,在这里我们实际上可以识别不同的数字。我们可以训练更多的历时,或者调整超参数以获得更好的图像质量。

损失图

你可以很容易地进入Neptune仪表板中的 "添加新仪表板",将不同的损失图合并成一个。

loss_stable_gans

图3 - 损失图,三条线表示生成器、鉴别器上的假图像和鉴别器上的真图像的损失

在图3中,你可以看到损失在第40个纪元后趋于稳定。真实和虚假图像的鉴别器损失保持在0.6左右,而生成器的损失则在0.8左右。上图是稳定训练的预期图。我们可以将其视为基线,并尝试改变k(鉴别器的训练步骤),增加历时的数量等。

现在我们已经建立了一个稳定的GAN模型,让我们来看看失败模式。

GAN失败模式

在过去的几年中,我们看到GAN的应用迅速增加,无论是提高图像的分辨率、条件生成,还是生成类似真实的合成数据。

训练失败是此类应用的一个难题。

如何识别GAN的失败模式?我们如何知道是否有失败的模式。

  • 生成器最好能产生各种数据。如果它产生的是单一种类或类似的输出集合,那就有模式崩溃了。
  • 当一组视觉上不好的数据被生成时,这可能是一个收敛失败的案例。

什么原因导致GAN中的模式崩溃?失败模式的原因。

  • 无法找到网络的收敛性。
  • 生成器可以找到某种类型的数据,可以轻易地骗过判别器。它会在目标实现的假设下,一次又一次地生成相同的数据。整个系统可以过度优化到该单一类型的输出。

识别模式崩溃和其他故障模式的问题是,我们不能依靠定性分析(如手动查看数据)。如果有大量的数据或者问题真的很复杂(我们不会总是产生数字),这种方法就会失败。

评估失败模式

在本节中,我们将尝试了解如何识别是否存在模式崩溃或收敛失败。我们将看到三种评估方法。其中一种我们已经在上一节中讨论过了。

观察中间图像

让我们看一些例子,其中,从中间图像可以评价模式塌陷和收敛。在图4中,我们看到质量非常差的图像,在图5中,我们可以看到生成的同一组图像。

Output from one of the unstable training

图4 - 这些是其中一个不稳定训练的输出。这是在上述相同的训练代码和稍加调整的超参数上,但即使在300个历时之后,你也可以看到我们的图像是多么糟糕--这是一个收敛失败的例子|来源:中国新闻网。作者

This is another example, you can see the same kind of images generated indicating Mode Collapse

图5 - 这是另一个例子,你可以看到生成的图像也是这样的,表明模式崩溃|来源:作者

图4是一个收敛失败的例子,而图5则显示模式崩溃。你可以通过手动查看图像来了解你的模型是如何表现的。但是当问题的复杂性很高或者训练数据太大,你可能无法识别模式崩溃。

让我们来看看一些更好的方法。

通过观察损失图

通过观察损失图,我们可以知道很多事情的发生。例如,在图3中,你可以注意到损失在某一点后达到饱和,显示了预期的行为。现在让我们看一下图6中的这个损失图,我减少了潜伏维度,所以行为是不稳定的。

gan_loss_latentSpace2

图6 - 潜伏维度减少时的损失图 |来源

我们可以在图6中看到,生成器的损失在1和1.2左右震荡。虽然假图像和真图像的鉴别器损失也在0.6左右徘徊,但损失比我们在稳定版中注意到的要多一些。

我建议,即使图形有很高的方差,也是可以的。你可以增加epochs的数量,并等待一些更多的时间让它变得稳定,最重要的是继续检查生成的中间图像。

如果生成器和鉴别器的损失图在最初的历时中下降到零,那么这也是一个问题。这意味着生成器已经找到了一组非常容易让鉴别器识别的假图像。

统计学上有差异的仓位数(NDB得分)

与上述两种定性的方法不同,NDB得分是一种定量的方法。因此,NDB得分可以识别是否有模式崩溃,而不是翻看图像和损失图,而错过了什么或没有做出正确的解释。

让我们来了解NDB评分是如何工作的。

  • 我们有两个集子,一个训练集(在此基础上训练模型)和一个测试集(在训练完成后由随机噪声发生器生成的假图像)。
  • 现在用K-means聚类法将训练集分成K个群组。这些将是我们的K个不同的仓。
  • 现在,根据测试数据点与K个聚类中心点之间的欧氏距离,将测试数据分配到这K个仓中。
  • 现在对每个仓的训练样本和测试样本进行两样测试,并计算出Z-score。如果Z-score小于阈值(本文中使用了0.05),将该仓标记为统计学上的不同。
  • 计算统计学上不同的仓的数量,并将其除以K。
  • 收到的数值将介于0和1之间。

统计学上不同的仓数多意味着,即数值接近于1,意味着模式塌陷程度高,意味着模型不好。然而,NDB分数接近0意味着较少或没有模式塌陷。

NDB评估方法来自于《GANs和GMMs》一文。

(a)Top Left - Image from Training dataset (b)Bottom Left - Image from Test dataset and the overlap is shown (c)Bar Graph showing bins for train and test set

图7 - (a) 左上角 - 训练数据集的图像 (b) 左下角 - 测试数据集的图像,重叠部分显示出来 (c) 条形图显示训练和测试集的分类。

一个非常好的计算NDB的代码可以在Kevin Shen实验笔记本中找到。

解决故障模式

现在我们已经了解了如何识别GANs训练中的问题,我们将看看一些解决方案和经验法则来解决这些问题。其中一些将是基本的超参数调整。如果你想多走一步来稳定你的GANs,我们会讨论一些算法。

成本函数

有论文说,没有损失函数是优越的。我建议你从更容易的损失函数开始,比如我们使用的二元交叉熵,然后从那里开始提升。

现在,这并不是强迫你在某些GAN架构中使用某些损失函数。但是在撰写这些论文时进行了大量的研究,其中有很多仍在进行中。因此,使用图8中的这些损失函数将是很好的做法,这可能有助于你防止模式崩溃和收敛。

Architecture of GANs and corresponding loss functions used in papers

图8 - GANs的结构和论文中使用的相应损失函数 |来源

在不同的损失函数上进行实验,注意你的损失函数可能因为超参数的错误调整而失败,比如让优化器过于激进,或者学习率过大。我们将在后面详细讨论这些问题。

潜伏空间

潜伏空间是对发生器的输入(随机噪声)进行采样的地方。现在,如果你限制潜伏空间,它将产生更多相同类型的输出,从图9可以看出。你也可以看一下图6中相应的损失图。

Subplot at 100th epoch, when latent space is 2

图9 - 潜伏空间为2时,第100个历时的子图|来源:中国新闻网作者

在图9中你能看到这么多类似的8和7吗?因此,模式崩溃了。

gan_loss_latentSpace1

图10 - 这里我把潜伏空间定为1,运行了200个历时。
我们可以看到发电机的损耗不断增加,所有的损耗都在震荡。

Subplot corresponding Fig. 10, where the latent space is 1. These digits are generated on the 200th epoch.

图11 - 与图10相对应的子图,其中潜伏空间为1。
这些数字是在第200个历时上产生的。来源:作者

请注意,在训练GAN网络时,提供足够数量的潜伏空间是至关重要的,这样生成器就可以创造出各种特征。

学习率

我在训练GAN时观察到的最常见的问题之一是高学习率。它导致了模式崩溃或不收敛。保持低的学习率真的很重要,低至0.0002甚至更低。

gan_loss_lr_0.2

图12 - 学习率为0.2时的损失值 |来源

Generated Images on 100th epoch, with a learning rate of 0.2

图13 - 在第100个历时中生成的图像,学习率为0.2 | 来源:中国新闻网。作者

从图12的损失图中我们可以清楚地看到,判别器将所有的图像都识别为真实的。这就是为什么假图像的损失很高而真图像的损失为零。现在,生成器假设它生成的所有图像都能骗过鉴别器。这里的问题是,由于如此高的学习率,鉴别器没有得到哪怕一点点的训练。

批量规模越大,学习率的值就越高,但总是要尽量保持安全的一面。

优化器

一个积极的修改器对训练GAN来说是个坏消息。它导致无法找到生成器损失和鉴别器损失之间的平衡点,从而导致收敛失败。

gan_loss_adam_betas_0.9_0.999

图14 - 亚当优化器默认值下的损失图(贝塔斯为0.9和0.999)|来源

在亚当优化器中,betas是用于计算梯度的运行平均值及其平方的超参数。我们最初(在稳定的训练中)对β1使用的是0.5的值。将其改为0.9(默认值)可以提高优化器的积极性。

在图14中,判别器表现良好。由于生成器的损失在增加,我们可以知道它产生了如此糟糕的图像,以至于鉴别器真的很容易将它们归类为假的。损失图并没有达到平衡。

特征匹配

特征匹配提出了一个新的目标函数,其中我们不直接使用鉴别器的输出。生成器被训练成这样,生成器的输出有望与鉴别器的中间特征上的真实图像的值相匹配。

f(x) is the feature vector extracted at the intermediate layer of the discriminator

图15 - f(x)是在鉴别器的中间层提取的特征向量

对于真实图像和虚假图像,特征向量(图15中的f(x))是在中间层上分批计算的,并测量这些特征向量的平均值的L2距离。

将生成的数据与真实数据的统计数据相匹配是更有意义的。如果优化器在寻找最佳数据生成时变得过于贪婪,并且从未达到收敛,那么特征匹配就会有帮助。

历史平均法

我们保持之前t个模型的参数(θ)的运行平均值。现在我们对模型进行惩罚,用以前的参数给成本函数增加一个L2成本。

这里,θ[i]是第i次运行的参数值。

在处理非凸的目标函数时,历史平均法可以帮助收敛模型。

结论

  • 我们现在明白了在训练GAN时实验跟踪的重要性。
  • 了解损失图并仔细观察产生的中间数据是很重要的。
  • 像学习率、优化器参数、潜伏空间等超参数如果调整不当,会毁掉你的模型。
  • 随着过去几年GAN模型的增加,越来越多的研究进入了稳定GAN的训练。还有很多技术对特定的使用情况有好处。