GAN 系列——PROGAN

226 阅读6分钟

简介

描述了一种生成对抗网络的新训练方法。关键的想法是逐步增加生成器和鉴别器:从低分辨率开始,添加新的层,随着训练的进行,这些层可以生成越来越精细的细节。这既加快了训练速度,又大大稳定了训练,使我们能够生成前所未有质量的图像,例如 1024x1024 的 CELEBA 图像。还提出了一种简单的方法来增加生成的图像的多样性,并在无监督的 CIFAR10 中获得 8.80 的创纪录的  inception 分数。此外,我们描述了几个实现细节,这些细节对于防止生成器和鉴别器之间的不健康竞争非常重要。最后,提出了一种新的评估 GAN 结果的指标,包括图像质量和多样性。作为另一项贡献,构建了 CELEBA 数据集的更高质量版本。

渐进的训练过程

从低分辨率图像开始,然后通过向网络添加层来逐步提高分辨率,如图所示。这种增量性质允许训练过程中先发现图像整体分布结构,然后将注意力转移到越来越精细的图像细节,而不必同时学习所有分辨率的信息。

图片

使用生成器和判别器的网络是彼此的镜像,并且总是同步增长。在整个训练过程中,两个网络中的所有现有层都是可训练的。当新的层被添加到网络中时,会将它们平滑地淡入,如图所示。这避免了对已经训练有素的较小分辨率层的突然冲击。当将生成器(G)和鉴别器(D)的分辨率加倍时,在过渡期间(b),我们将以较高分辨率的层视为残差块,其权重 从 0 线性增加到 1。这里,2× 和 0.5× 分别表示使用最近邻滤波 (将一个像素复制4份) 和平均池化将图像分辨率加倍和减半。toRGB 表示将特征向量投影到 RGB 颜色的特征层,而fromRGB 则相反;两者都使用1×1卷积。当训练鉴别器时,我们输入经过缩小以匹配网络当前分辨率的真实图像。在分辨率转换期间,我们在真实图像的两个分辨率之间进行插值,类似于生成器输出组合两个分辨率的方式。

图片

观察到渐进式训练有几个好处。首先,较小图像的生成基本上更稳定,因为类别信息更少,模式更少。通过一点一点地提高分辨率,与建立从潜在向量到 1024x1024 个像素点的映射的最终目标相比,我们不断地提出一个相对简单的问题。在实践中,它充分稳定了训练,使我们能够使用 WGAN-GP 损失,甚至 LSGAN 损失。另一个好处是减少了训练时间。

使用小批量标准差增加多样性

简化后的解决方案既没有可学习的参数,也没有新的超参数。首先计算小批量上每个空间位置中每个特征的标准差。然后,对所有特征和空间位置的这些标准差进行平均,以获得单个值。复制该值,并将其拼接到小批次中所有空间位置,从而生成一个额外的(常数的)特征。

归一化生成器和判别器

图片

生成器和判别器的架构

图片

训练设置

  1. 每一分辨率阶段,用 800K 张真实图像训练判别器,然后过渡到下一分辨率期间,再用 800K 张真实图像训练判别器。

  2. 生成器和判别器的最后一层都是线性层。

  3. 只在生成器 Conv 3x3 层后,使用逐像素归一化。

  4. 在运行时使用特定于层的常数来缩放权重。

  5. 向判别器的末端注入跨小批量标准差。

  6. 用 EMA 更新生成器的权重,衰减系数 0.999。

  7. 在判别器加入一项小权重的损失,防止判别器输出离 0 太远。

  8. Adam,学习率 0.001,b1=0, b2=0.99

Code

PixelwiseNorm

class PixelwiseNorm(torch.nn.Module):
    """
    ------------------------------------------------------------------------------------
    Pixelwise feature vector normalization.
    reference:
    https://github.com/tkarras/progressive_growing_of_gans/blob/master/networks.py#L120
    ------------------------------------------------------------------------------------
    """
    def __init__(self):
        super(PixelwiseNorm, self).__init__()

    @staticmethod
    def forward(x: Tensor, alpha: float = 1e-8) -> Tensor:
        y = x.pow(2.0).mean(dim=1, keepdim=True).add(alpha).sqrt()  # [N,1,H,W]
        y = x / y  # normalize the input x volume
        return y

MinibatchStdDev

class MinibatchStdDev(torch.nn.Module):
    """
    Minibatch standard deviation layer for the discriminator
    Args:
        group_size: batch 内进行分组数
    """

    def __init__(self, group_size: int = 4) -> None:
        super(MinibatchStdDev, self).__init__()
        self.group_size = group_size

    def extra_repr(self) -> str:
        return f"group_size={self.group_size}"

    def forward(self, x: Tensor, alpha: float = 1e-8) -> Tensor:
        """
        Args:
            x: 判别器最后下采样的输出
            alpha: 保证数值稳定的极小数
        Returns: y => x appended with standard deviation constant map
        """
        batch_size, channels, height, width = x.shape
        if batch_size > self.group_size:
            assert batch_size % self.group_size == 0, (
                f"batch_size {batch_size} should be "
                f"perfectly divisible by group_size {self.group_size}"
            )
            group_size = self.group_size
        else:
            group_size = batch_size

        # 将 batch 内分为多个组,计算每个组的标准差
        y = torch.reshape(x, [group_size, -1, channels, height, width])

        # [G x M x C x H x W] 每组减去均值
        y = y - y.mean(dim=0, keepdim=True)

        # [M x C x H x W] 每组计算标准差
        y = torch.sqrt(y.square().mean(dim=0, keepdim=False) + alpha)

        # [M x 1 x 1 x 1]  特征维和像素空间取平均
        y = y.mean(dim=[123], keepdim=True)

        # [B x 1 x H x W]  复制每组的标准差到每组的像素空间
        y = y.repeat(group_size, 1, height, width)

        # [B x (C + 1) x H x W]  拼接到原输出的特征维上,作为常量特征
        y = torch.cat([x, y], 1)

        return y


EMA 更新生成器的权重

def update_average(model_tgt, model_src, beta):
    """
    function to calculate the Exponential moving averages for the Generator weights
    This function updates the exponential average weights based on the current training
    Args:
        model_tgt: target model
        model_src: source model
        beta: value of decay beta
    Returns: None (updates the target model)
    """

    with torch.no_grad():
        param_dict_src = dict(model_src.named_parameters())

        for p_name, p_tgt in model_tgt.named_parameters():
            p_src = param_dict_src[p_name]
            assert p_src is not p_tgt
            p_tgt.copy_(beta * p_tgt + (1.0 - beta) * p_src)

self.gen_shadow = copy.deepcopy(self.gen)
# 初始化
update_average(self.gen_shadow, self.gen, beta=0)

# 更新

优化 D

def progressive_downsample_batch(self, real_batch, depth, alpha):
        down_sample_factor = int(2 ** (self.depth - depth))
        prior_downsample_factor = int(2 ** (self.depth - depth + 1))

        ds_real_samples = avg_pool2d(
            real_batch, kernel_size=down_sample_factor, stride=down_sample_factor
        )

        if depth > 2:
            prior_ds_real_samples = interpolate(
                avg_pool2d(
                    real_batch,
                    kernel_size=prior_downsample_factor,
                    stride=prior_downsample_factor,
                ),
                scale_factor=2,
            )
        else:
            prior_ds_real_samples = ds_real_samples

        # 真实样本是 ds_real_samples 和 prior_ds_real_samples 的线性组合
        real_samples = (alpha * ds_real_samples) + ((1 - alpha) * prior_ds_real_samples)

        return real_samples

def _gradient_penalty(
        dis: Discriminator,
        real_samples: Tensor,
        fake_samples: Tensor,
        depth: int,
        alpha: float,
        reg_lambda: float = 10,
        labels: Optional[Tensor] = None,
    ) -> Tensor:
        """
        private helper for calculating the gradient penalty
        Args:
            dis: the discriminator used for computing the penalty
            real_samples: real samples
            fake_samples: fake samples
            depth: current depth in the optimization
            alpha: current alpha for fade-in
            reg_lambda: regularisation lambda
        Returns: computed gradient penalty
        """
        batch_size = real_samples.shape[0]

        # 生成随机混合系数
        epsilon = torch.rand((batch_size, 111)).to(real_samples.device)

        # 混合真实样本和假样本
        merged = epsilon * real_samples + ((1 - epsilon) * fake_samples)
        merged.requires_grad_(True)

        # 前向运算
        if labels is not None:
            assert dis.conditional, "labels passed to an unconditional discriminator"
            op = dis(merged, depth, alpha, labels)
        else:
            op = dis(merged, depth, alpha)

        # 计算梯度
        gradient = torch.autograd.grad(
            outputs=op,
            inputs=merged,
            grad_outputs=torch.ones_like(op),
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]

        # 计算惩罚项
        gradient = gradient.view(gradient.shape[0], -1)
        penalty = reg_lambda * ((gradient.norm(p=2, dim=1) - 1) ** 2).mean()
        
        return penalty
    
def dis_loss(
        self,
        discriminator: Discriminator,
        real_samples: Tensor,
        fake_samples: Tensor,
        depth: int,
        alpha: float,
        labels: Optional[Tensor] = None,
    ) -> Tensor:
        if labels is not None:
            assert discriminator.conditional, "labels passed to an unconditional dis"
            real_scores = discriminator(real_samples, depth, alpha, labels)
            fake_scores = discriminator(fake_samples, depth, alpha, labels)
        else:
            real_scores = discriminator(real_samples, depth, alpha)
            fake_scores = discriminator(fake_samples, depth, alpha)
        loss = (
            torch.mean(fake_scores)
            - torch.mean(real_scores)
            + (self.drift * torch.mean(real_scores ** 2))
        )

        # 计算 WGAN-GP (gradient penalty)
        gp = self._gradient_penalty(
            discriminator, real_samples, fake_samples, depth, alpha, labels=labels
        )
        loss += gp

        return loss
    
real_samples = self.progressive_downsample_batch(real_batch, depth, alpha)
fake_samples = self.gen(noise, depth, alpha).detach()
dis_loss = loss.dis_loss(
            self.dis, real_samples, fake_samples, depth, alpha, labels=labels
        )

优化 G

def gen_loss(
        self,
        discriminator: Discriminator,
        _: Tensor,
        fake_samples: Tensor,
        depthint,
        alphafloat,
        labels: Optional[Tensor] = None,
    ) -> Tensor:
        if labels is not None:
            assert discriminator.conditional, "labels passed to an unconditional dis"
            fake_scores = discriminator(fake_samples, depth, alpha, labels)
        else:
            fake_scores = discriminator(fake_samples, depth, alpha)
        return -torch.mean(fake_scores)

real_samples = self.progressive_downsample_batch(real_batch, depth, alpha)
fake_samples = self.gen(noise, depth, alpha)

gen_loss = loss.gen_loss(
self.dis, real_samples, fake_samples, depth, alpha, labels=labels
)

参考链接:github.com/akanimax/pr…

arxiv.org/abs/1710.10…


ONE MORE THING

咪豆AI圈(Meedo)针对当前人工智能领域行业入门成本较高、碎片化信息严重、资源链接不足等痛点问题,致力于打造人工智能领域的全资源、深内容、广链接三位一体的在线科研社区平台,提供AI导航网、AI版知乎,AI知识树和AI圈子等服务,欢迎AI未来儿一起来探索(www.meedo.top/)