简介
描述了一种生成对抗网络的新训练方法。关键的想法是逐步增加生成器和鉴别器:从低分辨率开始,添加新的层,随着训练的进行,这些层可以生成越来越精细的细节。这既加快了训练速度,又大大稳定了训练,使我们能够生成前所未有质量的图像,例如 1024x1024 的 CELEBA 图像。还提出了一种简单的方法来增加生成的图像的多样性,并在无监督的 CIFAR10 中获得 8.80 的创纪录的 inception 分数。此外,我们描述了几个实现细节,这些细节对于防止生成器和鉴别器之间的不健康竞争非常重要。最后,提出了一种新的评估 GAN 结果的指标,包括图像质量和多样性。作为另一项贡献,构建了 CELEBA 数据集的更高质量版本。
渐进的训练过程
从低分辨率图像开始,然后通过向网络添加层来逐步提高分辨率,如图所示。这种增量性质允许训练过程中先发现图像整体分布结构,然后将注意力转移到越来越精细的图像细节,而不必同时学习所有分辨率的信息。
使用生成器和判别器的网络是彼此的镜像,并且总是同步增长。在整个训练过程中,两个网络中的所有现有层都是可训练的。当新的层被添加到网络中时,会将它们平滑地淡入,如图所示。这避免了对已经训练有素的较小分辨率层的突然冲击。当将生成器(G)和鉴别器(D)的分辨率加倍时,在过渡期间(b),我们将以较高分辨率的层视为残差块,其权重 从 0 线性增加到 1。这里,2× 和 0.5× 分别表示使用最近邻滤波 (将一个像素复制4份) 和平均池化将图像分辨率加倍和减半。toRGB 表示将特征向量投影到 RGB 颜色的特征层,而fromRGB 则相反;两者都使用1×1卷积。当训练鉴别器时,我们输入经过缩小以匹配网络当前分辨率的真实图像。在分辨率转换期间,我们在真实图像的两个分辨率之间进行插值,类似于生成器输出组合两个分辨率的方式。
观察到渐进式训练有几个好处。首先,较小图像的生成基本上更稳定,因为类别信息更少,模式更少。通过一点一点地提高分辨率,与建立从潜在向量到 1024x1024 个像素点的映射的最终目标相比,我们不断地提出一个相对简单的问题。在实践中,它充分稳定了训练,使我们能够使用 WGAN-GP 损失,甚至 LSGAN 损失。另一个好处是减少了训练时间。
使用小批量标准差增加多样性
简化后的解决方案既没有可学习的参数,也没有新的超参数。首先计算小批量上每个空间位置中每个特征的标准差。然后,对所有特征和空间位置的这些标准差进行平均,以获得单个值。复制该值,并将其拼接到小批次中所有空间位置,从而生成一个额外的(常数的)特征。
归一化生成器和判别器
生成器和判别器的架构
训练设置
-
每一分辨率阶段,用 800K 张真实图像训练判别器,然后过渡到下一分辨率期间,再用 800K 张真实图像训练判别器。
-
生成器和判别器的最后一层都是线性层。
-
只在生成器 Conv 3x3 层后,使用逐像素归一化。
-
在运行时使用特定于层的常数来缩放权重。
-
向判别器的末端注入跨小批量标准差。
-
用 EMA 更新生成器的权重,衰减系数 0.999。
-
在判别器加入一项小权重的损失,防止判别器输出离 0 太远。
-
Adam,学习率 0.001,b1=0, b2=0.99
Code
PixelwiseNorm
class PixelwiseNorm(torch.nn.Module):
"""
------------------------------------------------------------------------------------
Pixelwise feature vector normalization.
reference:
https://github.com/tkarras/progressive_growing_of_gans/blob/master/networks.py#L120
------------------------------------------------------------------------------------
"""
def __init__(self):
super(PixelwiseNorm, self).__init__()
@staticmethod
def forward(x: Tensor, alpha: float = 1e-8) -> Tensor:
y = x.pow(2.0).mean(dim=1, keepdim=True).add(alpha).sqrt() # [N,1,H,W]
y = x / y # normalize the input x volume
return y
MinibatchStdDev
class MinibatchStdDev(torch.nn.Module):
"""
Minibatch standard deviation layer for the discriminator
Args:
group_size: batch 内进行分组数
"""
def __init__(self, group_size: int = 4) -> None:
super(MinibatchStdDev, self).__init__()
self.group_size = group_size
def extra_repr(self) -> str:
return f"group_size={self.group_size}"
def forward(self, x: Tensor, alpha: float = 1e-8) -> Tensor:
"""
Args:
x: 判别器最后下采样的输出
alpha: 保证数值稳定的极小数
Returns: y => x appended with standard deviation constant map
"""
batch_size, channels, height, width = x.shape
if batch_size > self.group_size:
assert batch_size % self.group_size == 0, (
f"batch_size {batch_size} should be "
f"perfectly divisible by group_size {self.group_size}"
)
group_size = self.group_size
else:
group_size = batch_size
# 将 batch 内分为多个组,计算每个组的标准差
y = torch.reshape(x, [group_size, -1, channels, height, width])
# [G x M x C x H x W] 每组减去均值
y = y - y.mean(dim=0, keepdim=True)
# [M x C x H x W] 每组计算标准差
y = torch.sqrt(y.square().mean(dim=0, keepdim=False) + alpha)
# [M x 1 x 1 x 1] 特征维和像素空间取平均
y = y.mean(dim=[1, 2, 3], keepdim=True)
# [B x 1 x H x W] 复制每组的标准差到每组的像素空间
y = y.repeat(group_size, 1, height, width)
# [B x (C + 1) x H x W] 拼接到原输出的特征维上,作为常量特征
y = torch.cat([x, y], 1)
return y
EMA 更新生成器的权重
def update_average(model_tgt, model_src, beta):
"""
function to calculate the Exponential moving averages for the Generator weights
This function updates the exponential average weights based on the current training
Args:
model_tgt: target model
model_src: source model
beta: value of decay beta
Returns: None (updates the target model)
"""
with torch.no_grad():
param_dict_src = dict(model_src.named_parameters())
for p_name, p_tgt in model_tgt.named_parameters():
p_src = param_dict_src[p_name]
assert p_src is not p_tgt
p_tgt.copy_(beta * p_tgt + (1.0 - beta) * p_src)
self.gen_shadow = copy.deepcopy(self.gen)
# 初始化
update_average(self.gen_shadow, self.gen, beta=0)
# 更新
优化 D
def progressive_downsample_batch(self, real_batch, depth, alpha):
down_sample_factor = int(2 ** (self.depth - depth))
prior_downsample_factor = int(2 ** (self.depth - depth + 1))
ds_real_samples = avg_pool2d(
real_batch, kernel_size=down_sample_factor, stride=down_sample_factor
)
if depth > 2:
prior_ds_real_samples = interpolate(
avg_pool2d(
real_batch,
kernel_size=prior_downsample_factor,
stride=prior_downsample_factor,
),
scale_factor=2,
)
else:
prior_ds_real_samples = ds_real_samples
# 真实样本是 ds_real_samples 和 prior_ds_real_samples 的线性组合
real_samples = (alpha * ds_real_samples) + ((1 - alpha) * prior_ds_real_samples)
return real_samples
def _gradient_penalty(
dis: Discriminator,
real_samples: Tensor,
fake_samples: Tensor,
depth: int,
alpha: float,
reg_lambda: float = 10,
labels: Optional[Tensor] = None,
) -> Tensor:
"""
private helper for calculating the gradient penalty
Args:
dis: the discriminator used for computing the penalty
real_samples: real samples
fake_samples: fake samples
depth: current depth in the optimization
alpha: current alpha for fade-in
reg_lambda: regularisation lambda
Returns: computed gradient penalty
"""
batch_size = real_samples.shape[0]
# 生成随机混合系数
epsilon = torch.rand((batch_size, 1, 1, 1)).to(real_samples.device)
# 混合真实样本和假样本
merged = epsilon * real_samples + ((1 - epsilon) * fake_samples)
merged.requires_grad_(True)
# 前向运算
if labels is not None:
assert dis.conditional, "labels passed to an unconditional discriminator"
op = dis(merged, depth, alpha, labels)
else:
op = dis(merged, depth, alpha)
# 计算梯度
gradient = torch.autograd.grad(
outputs=op,
inputs=merged,
grad_outputs=torch.ones_like(op),
create_graph=True,
retain_graph=True,
only_inputs=True,
)[0]
# 计算惩罚项
gradient = gradient.view(gradient.shape[0], -1)
penalty = reg_lambda * ((gradient.norm(p=2, dim=1) - 1) ** 2).mean()
return penalty
def dis_loss(
self,
discriminator: Discriminator,
real_samples: Tensor,
fake_samples: Tensor,
depth: int,
alpha: float,
labels: Optional[Tensor] = None,
) -> Tensor:
if labels is not None:
assert discriminator.conditional, "labels passed to an unconditional dis"
real_scores = discriminator(real_samples, depth, alpha, labels)
fake_scores = discriminator(fake_samples, depth, alpha, labels)
else:
real_scores = discriminator(real_samples, depth, alpha)
fake_scores = discriminator(fake_samples, depth, alpha)
loss = (
torch.mean(fake_scores)
- torch.mean(real_scores)
+ (self.drift * torch.mean(real_scores ** 2))
)
# 计算 WGAN-GP (gradient penalty)
gp = self._gradient_penalty(
discriminator, real_samples, fake_samples, depth, alpha, labels=labels
)
loss += gp
return loss
real_samples = self.progressive_downsample_batch(real_batch, depth, alpha)
fake_samples = self.gen(noise, depth, alpha).detach()
dis_loss = loss.dis_loss(
self.dis, real_samples, fake_samples, depth, alpha, labels=labels
)
优化 G
def gen_loss(
self,
discriminator: Discriminator,
_: Tensor,
fake_samples: Tensor,
depth: int,
alpha: float,
labels: Optional[Tensor] = None,
) -> Tensor:
if labels is not None:
assert discriminator.conditional, "labels passed to an unconditional dis"
fake_scores = discriminator(fake_samples, depth, alpha, labels)
else:
fake_scores = discriminator(fake_samples, depth, alpha)
return -torch.mean(fake_scores)
real_samples = self.progressive_downsample_batch(real_batch, depth, alpha)
fake_samples = self.gen(noise, depth, alpha)
gen_loss = loss.gen_loss(
self.dis, real_samples, fake_samples, depth, alpha, labels=labels
)
ONE MORE THING
咪豆AI圈(Meedo)针对当前人工智能领域行业入门成本较高、碎片化信息严重、资源链接不足等痛点问题,致力于打造人工智能领域的全资源、深内容、广链接三位一体的在线科研社区平台,提供AI导航网、AI版知乎,AI知识树和AI圈子等服务,欢迎AI未来儿一起来探索(www.meedo.top/)