GAN 系列——BigGAN

79 阅读7分钟

简介

尽管图像生成取得了进展,但成功地从 ImageNet 等复杂数据集生成高分辨率、多样化的样本仍然是一个难以捉摸的目标。为此,论文尝试目前最大规模地训练生成对抗网络,并研究了这种规模特有的不稳定性。其发现对生成器应用正交正则化使其服从一个简单的“截断技巧”,通过减少生成器输入的方差来进行样本保真度和多样性之间的权衡。在 128×128 分辨率的 ImageNet 上训练时,我们的模型(BigGAN) 实现了 166.5 的 Inception score (IS),Fŕechet Inception Distance (FID) 为 7.4,比之前的最佳 IS 提高了 52.52,FID 为18.65。

扩大 GAN

我们探索了扩大 GAN 训练以获得更大模型和更大批次的性能优势的方法。作为基线,我们采用了 SA-GAN 架构,该架构使用 hinge 损失。我们使用类条件 BatchNorm 为 G 提供类信息和通过映射的为 D 提供类信息。在 G 中使用 Spectral Norm,将学习率减半并且每更新 G 一步更新 D 两步。采用 G 权重的移动平均值,衰减系数为 0.9999。我们使用正交初始化。计算 G 中跨设备的 BatchNorm 统计数据。

我们首先增加基线模型的批量大小。简单地将批量大小增加到 8 倍将最先进的 IS 提高了 46%。我们推测这是每批覆盖更多模式的结果,为两个网络提供了更好的梯度。这种增大批数的一个副作用是我们的模型虽然在更少的迭代中达到了更好的最终性能,但变得不稳定。然后,我们将每一层的宽度(通道数)增加 50%,大约是两个模型中的参数数量的两倍。这进一步提升 IS 21%,我们猜想这是由于模型相对于数据集的复杂性增加。简单地加倍网络的深度并没有带来性能的提高。

我们注意到 G 中条件 BatchNorm 层的类嵌入 c 包含大量权重。我们没有为每个嵌入使用单独的层,而是选择使用共享嵌入,该嵌入线性投影到每一层的增益和偏差。这将计算和内存成本降低,并将训练速度(在达到给定性能所需的迭代次数中)提高了 37%。同时,将来自噪声向量 z 的直接通过跳过连接 (skip-z) 输入到 G 的多个层,而不仅仅是初始层。这种设计背后的直觉是让 G 使用潜在空间直接影响不同分辨率和层次结构级别的特征。在 BigGAN 中,这是通过将 z 拆分为一个一个块并将每个块连接到条件向量 c 来完成的,该向量 c 被投影到 BatchNorm 增益和偏差中。在 BigGAN-deep 中,我们使用更简单的设计,将整个 z 与条件向量连接起来,而不将其拆分为块。Skip-z 提供了大约 4% 的适度性能提升,并将训练速度进一步提高了 18%。

值得注意的是,我们最好的结果来自于使用不同于训练中的潜在分布进行采样。采用使用 z ∼ N (0, I) 训练的模型并从截断正态中采样 z(其中超出范围的值被重新采样以落在该范围内)立即为 IS 和 FID 提供了提升。我们称之为截断 Trick,提高单个样本质量,但代价是整体样本多样性的减少。

该技术允许对 G 生成样本的质量和多样性之间的进行细粒度、事后权衡。我们的一些更大的模型不适合截断,在输入截断噪声时产生饱和伪影。为了解决这个问题,我们试图通过提高 G 平滑度来提高对截断噪声的可容忍性,以便 z 的完整空间映射到良好的输出样本。为此,我们转向正交正则化。这种正则化往往过于严格,因此我们探索了几种可以放松约束的变体,同时仍然可以赋予我们的模型所需的平滑度。我们发现效果最好的是从正则化中删除对角线项,旨在最小化过滤器之间的成对余弦相似度,但不限制它们的范数:

范数.png

其中 1 表示所有元素为 1 的矩阵。我们设置 β 为 10−4,这个小的附加惩罚足以提高我们的模型服从截断的可能性。

架构细节

在 BigGAN 模型中,使用 ResNet GAN架构。在 G 中使用单个共享类嵌入,并为潜在向量 z (skip-z) 跳过连接。潜在向量 z 沿其通道维度均分(在我们的例子中为 20-D),拼接到共享类嵌入,传递给相应的残差块作为条件向量。每个 block 的条件向量被线性投影,为 block 的 BatchNorm 层产生每个样本的增益和偏差。偏差投影以零为中心,增益投影以 1 为中心。

图片

图片

图片

图片

总结

  1. 使用更大的 BatchSize,跨设备的批归一化

  2. 将噪声向量分成几个部分分别拼接共享的类嵌入,经过线性映射,输入到 G 的不同层中

  3. 使用截断分布的采样策略,对 D 加入额外的正交惩罚项

  4. 使用了正交初始化

Code

权重初始化

 def init_weights(self):
    self.param_count = 0
    for module in self.modules():
      if (isinstance(module, nn.Conv2d) 
          or isinstance(module, nn.Linear) 
          or isinstance(module, nn.Embedding)):
        if self.init == 'ortho':
          init.orthogonal_(module.weight)

G BatchNorm

class bn(nn.Module):
  def __init__(self, output_size,  eps=1e-5, momentum=0.1,
                cross_replica=False, mybn=False):
    super(bn, self).__init__()
    self.output_size= output_size
    # Prepare gain and bias layers
    self.gain = P(torch.ones(output_size), requires_grad=True)
    self.bias = P(torch.zeros(output_size), requires_grad=True)
    # epsilon to avoid dividing by 0
    self.eps = eps
    # Momentum
    self.momentum = momentum
    # Use cross-replica batchnorm?
    self.cross_replica = cross_replica
    # Use my batchnorm?
    self.mybn = mybn
    
    if self.cross_replica:
      self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)    
    elif mybn:
      self.bn = myBN(output_size, self.eps, self.momentum)
     # Register buffers if neither of the above
    else:     
      self.register_buffer('stored_mean', torch.zeros(output_size))
      self.register_buffer('stored_var',  torch.ones(output_size))
    
  def forward(self, x, y=None):
    if self.cross_replica or self.mybn:
      gain = self.gain.view(1,-1,1,1)
      bias = self.bias.view(1,-1,1,1)
      return self.bn(x, gain=gain, bias=bias)
    else:
      return F.batch_norm(x, self.stored_mean, self.stored_var, self.gain,
                          self.bias, self.training, self.momentum, self.eps)

D Spectral Norm

# Spectral normalization base class 
class SN(object):
  def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
    # Number of power iterations per step
    self.num_itrs = num_itrs
    # Number of singular values
    self.num_svs = num_svs
    # Transposed?
    self.transpose = transpose
    # Epsilon value for avoiding divide-by-0
    self.eps = eps
    # Register a singular vector for each sv
    for i in range(self.num_svs):
      self.register_buffer('u%d' % i, torch.randn(1, num_outputs))
      self.register_buffer('sv%d' % i, torch.ones(1))
  
  # Singular vectors (u side)
  @property
  def u(self):
    return [getattr(self, 'u%d' % i) for i in range(self.num_svs)]

  # Singular values; 
  # note that these buffers are just for logging and are not used in training. 
  @property
  def sv(self):
   return [getattr(self, 'sv%d' % i) for i in range(self.num_svs)]
   
  # Compute the spectrally-normalized weight
  def W_(self):
    W_mat = self.weight.view(self.weight.size(0), -1)
    if self.transpose:
      W_mat = W_mat.t()
    # Apply num_itrs power iterations
    for _ in range(self.num_itrs):
      svs, us, vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps) 
    # Update the svs
    if self.training:
      with torch.no_grad(): # Make sure to do this in a no_grad() context or you'll get memory leaks!
        for i, sv in enumerate(svs):
          self.sv[i][:] = sv     
    return self.weight / svs[0]


# 2D Conv layer with spectral norm
class SNConv2d(nn.Conv2d, SN):
  def __init__(self, in_channels, out_channels, kernel_size, stride=1,
             padding=0, dilation=1, groups=1, bias=True, 
             num_svs=1, num_itrs=1, eps=1e-12):
    nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride, 
                     padding, dilation, groups, bias)
    SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)    
  def forward(self, x):
    return F.conv2d(x, self.W_(), self.bias, self.stride, 
                    self.padding, self.dilation, self.groups)


# Linear layer with spectral norm
class SNLinear(nn.Linear, SN):
  def __init__(self, in_features, out_features, bias=True,
               num_svs=1, num_itrs=1, eps=1e-12):
    nn.Linear.__init__(self, in_features, out_features, bias)
    SN.__init__(self, num_svs, num_itrs, out_features, eps=eps)
  def forward(self, x):
    return F.linear(x, self.W_(), self.bias)


# Embedding layer with spectral norm
# We use num_embeddings as the dim instead of embedding_dim here
# for convenience sake
class SNEmbedding(nn.Embedding, SN):
  def __init__(self, num_embeddings, embedding_dim, padding_idx=None, 
               max_norm=None, norm_type=2, scale_grad_by_freq=False,
               sparse=False, _weight=None,
               num_svs=1, num_itrs=1, eps=1e-12):
    nn.Embedding.__init__(self, num_embeddings, embedding_dim, padding_idx,
                          max_norm, norm_type, scale_grad_by_freq, 
                          sparse, _weight)
    SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps)
  def forward(self, x):
    return F.embedding(x, self.W_())

loss

# Hinge Loss
def loss_hinge_dis(dis_fake, dis_real):
  loss_real = torch.mean(F.relu(1. - dis_real))
  loss_fake = torch.mean(F.relu(1. + dis_fake))
  return loss_real, loss_fake

def loss_hinge_gen(dis_fake):
  loss = -torch.mean(dis_fake)
  return loss

# Default to hinge loss
generator_loss = loss_hinge_gen
discriminator_loss = loss_hinge_dis

# Optionally apply ortho reg in D
if config['D_ortho'] > 0.0:
    # Debug print to indicate we're using ortho reg in D.
    print('using modified ortho reg in D')
    utils.ortho(D, config['D_ortho'])

    D.optim.step()

    
if config['G_ortho'] > 0.0:
    print('using modified ortho reg in G') # Debug print to indicate we're using ortho reg in G
    # Don't ortho reg shared, it makes no sense. Really we should blacklist any embeddings for this
    utils.ortho(G, config['G_ortho'], 
                blacklist=[param for param in G.shared.parameters()])
    G.optim.step()

# Apply modified ortho reg to a model
# This function is an optimized version that directly computes the gradient,
# instead of computing and then differentiating the loss.
def ortho(model, strength=1e-4, blacklist=[]):
    with torch.no_grad():
        for param in model.parameters():
            # Only apply this to parameters with at least 2 axes, and not in the blacklist
            if len(param.shape) < 2 or any([param is item for item in blacklist]):
                continue
            w = param.view(param.shape[0], -1)
            grad = (2 * torch.mm(torch.mm(w, w.t())
                                 * (1. - torch.eye(w.shape[0]device=w.device)), w))
            param.grad.data += strength * grad.view(param.shape)

测试采样

  if config['sample_trunc_curves']:
    start, step, end = [float(item) for item in config['sample_trunc_curves'].split('_')]
    print('Getting truncation values for variance in range (%3.3f:%3.3f:%3.3f)...' % (start, step, end))
    for var in np.arange(start, end + step, step):     
      z_.var = var
      # Optionally comment this out if you want to run with standing stats
      # accumulated at one z variance setting
      if config['accumulate_stats']:
        utils.accumulate_standing_stats(G, z_, y_, config['n_classes'],
                                    config['num_standing_accumulations'])
      get_metrics()

参考链接:

arxiv.org/abs/1809.11…

github.com/ajbrock/Big…


AI众包项目推荐

咪豆AI圈(Meedo)针对当前人工智能领域行业入门成本较高、碎片化信息严重、资源链接不足等痛点问题,致力于打造人工智能领域的全资源、深内容、广链接三位一体的在线科研社区平台,提供AI导航网、AI版知乎,AI知识树和AI圈子等服务,欢迎AI未来儿一起来探索(www.meedo.top/)