变分自编码器(VAE)详解与实现(采用TensorFlow2实现)

7,513 阅读7分钟

「这是我参与11月更文挑战的第9天,活动详情查看:2021最后一次更文挑战

VAE介绍

变分自编码器 (Variational Auto-Encoders, VAE) 属于生成模型家族。VAE 的生成器能够利用连续潜在空间的矢量产生有意义的输出。通过潜在矢量探索解码器输出的可能属性。

在 GAN 中,重点在于如何得出近似输入分布的模型。 VAE尝试对可解耦的连续潜在空间中的输入分布进行建模。

在 VAE 中,重点在于潜编码的变分推理。因此,VAE 为潜在变量的学习和有效贝叶斯推理提供了合适的框架。

在结构上,VAE 与自编码器相似。它也由编码器(也称为识别或推理模型)和解码器(也称为生成模型)组成。 VAE 和自编码器都试图在学习潜矢量的同时重建输入数据。但是,与自编码器不同,VAE 的潜在空间是连续的,并且解码器本身被用作生成模型。

VAE原理

在生成模型中,使用神经网络来逼近输入的真实分布:

xPθ(x)(1)x \sim P_θ(x) \qquad(1)

其中,θθ 表示模型参数。

在机器学习中,为了执行特定的推理,希望找到输入 xx 和潜变量 zz 之间的联合分布 Pθ(x,z)P_θ(x,z)。潜变量是对可从输入中观察到的某些属性进行编码。如在人脸数据中,这些可能是面部表情,发型,头发颜色,性别等。

Pθ(x,z)P_θ(x,z) 实际上是输入数据及其属性的分布。Pθ(x)P_θ(x) 可以从边缘分布计算:

Pθ(x)=Pθ(x,z)dz(2)P_θ(x)=\int P_θ(x,z)dz \qquad(2)

换句话说,考虑所有可能的属性,最终得到描述输入的分布。在人脸数据中,利用包含面部表情,发型,头发颜色和性别在内的特征,可以恢复描述人脸数据的分布。

问题在于该方程式没有解析形式或有效的估计量。因此,通过神经网络进行优化是不可行的。

使用贝叶斯定理,可以找到方程式(2)的替代表达式:

Pθ(x)=Pθ(xz)P(z)dz(3)P_θ(x)=\int P_θ(x|z)P(z)dz \qquad(3)

其中,P(z)P(z)zz 的先验分布。它不以任何观察为条件。如果 zz 是离散的并且 Pθ(xz)P_θ(x|z) 是高斯分布,则Pθ(x)P_θ(x)是高斯分布的混合。如果zz是连续的,则高斯分布 Pθ(x)P_θ(x) 无法预估。

在实践中,如果尝试在没有合适的损失函数的情况下建立近似 Pθ(xz)P_θ(x|z) 的神经网络,它将忽略zz并得出平凡解,Pθ(xz)=Pθ(x)P_θ(x|z)=P_θ(x)。因此,公式(3)不能提供 Pθ(x)P_θ(x) 的良好估计。公式(2)也可以表示为:

Pθ(x)=Pθ(zx)P(x)dz(4)P_θ(x)=\int P_θ(z|x)P(x)dz \qquad(4)

但是,Pθ(zx)P_θ(z|x) 也难以求解。 VAE的目标是找到一个可估计的分布,该分布近似估计Pθ(zx)P_θ(z|x),即在给定输入 xx 的情况下对潜在编码 zz 的条件分布的估计。

变分推理

为了使 Pθ(zx)P_θ(z|x) 易于处理,VAE 引入了变分推断模型(编码器):

Qϕ(zx)Pθ(zx)(5)Q_\phi (z|x) \approx P_θ(z|x) \qquad(5)

Qϕ(zx)Q_\phi (z|x) 可很好地估计 Pθ(zx)P_θ(z|x)。它既可以参数化又易于处理。 可以通过深度神经网络优化参数 φφ 来近似 Qϕ(zx)Q_\phi (z|x)。 通常,将 Qϕ(zx)Q_\phi (z|x) 选择为多元高斯分布:

Qϕ(zx)=N(z;μ(x),diag(σ(x)2))(6)Q_\phi (z|x)=\mathcal N(z;\mu(x),diag(\sigma(x)^2)) \qquad(6)

均值 μ(x)\mu(x) 和标准差 σ(x)\sigma (x) 均由编码器神经网络使用输入数据计算得出。对角矩阵表示zz中的元素间是相互独立的。

VAE核心方程

推理模型 Qϕ(zx)Q_\phi (z|x) 从输入 xx 生成潜矢量 zzQϕ(zx)Q_\phi (z|x) 类似于自编码器模型中的编码器。另一方面,Pθ(xz)P_θ(x|z) 从潜码z重建输入。Pθ(xz)P_θ(x|z) 的作用类似于自编码器模型中的解码器。要估算 Pθ(x)P_θ(x),必须确定其与 Qϕ(zx)Q_\phi (z|x)Pθ(xz)P_θ(x|z) 的关系。

如果 Qϕ(zx)Q_\phi (z|x)Pθ(zx)P_θ(z|x) 的估计值,则 Kullback-Leibler(KL)散度确定这两个条件密度之间的距离:

DKL(Qϕ(zx)Pθ(zx))=EzQ[logQϕ(zx)logPθ(zx)](7)D_{KL}(Q_\phi (z|x) \| P_θ(z|x)) = \mathbb E_{z\sim Q}[logQ_\phi (z|x)-logP_θ(z|x)] \qquad (7)

使用贝叶斯定理:

Pθ(zx)=Pθ(xz)Pθ(z)Pθ(x)(8)P_θ(z|x)=\frac{P_θ(x|z)P_θ(z)}{P_θ(x)} \qquad(8)

通过公式(8)改写公式(7),同时由于 logPθ(x)logP_θ(x) 不依赖于 zQz\sim Q

DKL(Qϕ(zx)Pθ(zx))=EzQ[logQϕ(zx)logPθ(xz)logPθ(z)]+logPθ(x)(9)D_{KL}(Q_\phi (z|x) \| P_θ(z|x)) = \mathbb E_{z\sim Q}[logQ_\phi (z|x)-logP_θ(x|z)-logP_θ(z)] + logP_θ(x)\qquad (9)

重排上式并由:

EzQ[logQϕ(zx)logPθ(z)]=DKL(Qϕ(zx)Pθ(z))(10)\mathbb E_{z\sim Q}[logQ_\phi (z|x)-logP_θ(z)] = D_{KL}(Q_\phi (z|x) \| P_θ(z)) \qquad (10)

得到:

logPθ(x)DKL(Qϕ(zx)Pθ(zx))=EzQ[logPθ(xz)]DKL(Qϕ(zx)Pθ(z))(11)logP_θ(x)-D_{KL}(Q_\phi (z|x) \| P_θ(z|x)) = \mathbb E_{z\sim Q}[logP_θ(x|z)] - D_{KL}(Q_\phi (z|x) \| P_θ(z))\qquad (11)

上式是 VAE 的核心。左侧项 Pθ(x)P_θ(x),它最大化地减少了 Qϕ(zx)Q_\phi (z|x) 与真实 Pθ(zx)P_θ(z|x) 之间距离的差距。对数不会改变最大值(或最小值)的位置。给定一个可以很好地估计 Pθ(zx)P_θ(z|x) 的推断模型, DKL(Qϕ(zx)Pθ(zx))D_{KL}(Q_\phi (z|x) \| P_θ(z|x)) 约为零。

右边的第一项 Pθ(zx))P_θ(z|x)) 类似于解码器,该解码器从推理模型中提取样本以重建输入。

第二项是 Qϕ(zx)Q_\phi (z|x)Pθ(z)P_θ(z) 间的 KL 距离。公式的左侧也称为变化下界(evidence lower bound, ELBO)。由于 KL 始终为正,因此 ELBO 是 logPθ(x)logP_θ(x) 的下限。通过优化神经网络的参数 φφθθ 来最大化 ELBO 意味着:

  1. DKL(Qϕ(zx)Pθ(zx))0D_{KL}(Q_\phi (z|x) \| P_θ(z|x))\to 0 或在 zz 中对属性 xx 进行编码的推理模型得到优化。
  2. 右侧的 logPθ(xz)logP_θ(x|z) 最大化,或者从潜在矢量 zz 重构 xx 时,解码器模型得到优化。

优化方式

公式的右侧具有有关 VAE 损失函数的两个重要信息。解码器项 EzQ[logPθ(xz)]\mathbb E_{z\sim Q}[logP_θ(x|z)] 表示生成器从推理模型的输出中获取 zz 个样本以重构输入。最大化该项意味着将重建损失 LR\mathcal L_R 最小化。如果图像(数据)分布假定为高斯分布,则可以使用 MSE。

如果每个像素(数据)都被认为是伯努利分布,那么损失函数就是一个二元交叉熵。

第二项 DKL(Qϕ(zx)Pθ(z))- D_{KL}(Q_\phi (z|x) \| P_θ(z)),由于 QϕQ_\phi 是高斯分布。通常 Pθ(z)=P(z)=N(0,1)P_θ(z)=P(z)=\mathcal N(0,1),也是均值为 00 且标准偏差等于 1.0 的高斯分布。KL 项可以简化为:

DKL(Qϕ(zx)Pθ(z))=12j=0J(1+log(σj)2(μj)2(σj)2)(12)- D_{KL}(Q_\phi (z|x) \| P_θ(z))=\frac{1}{2} \sum_{j=0}^J (1+log(\sigma_j)^2-(\mu_j)^2-(\sigma_j)^2)\qquad(12)

其中 JJzz 的维数。和都是通过推理模型计算得到的关于 xx 的函数。要最大化 DKL-D_{KL} :则σj1\sigma_j \to 1μj0\mu_j \to 0P(z)=N(0,1)P(z)=\mathcal N(0,1) 的选择是由于各向同性单位高斯分布的性质,可以给定适当的函数将其变形为任意分布。

根据公式(12),KL 损失 LKL\mathcal L_{KL}DKLD_{KL}。 综上,VAE 损失函数定义为:

LVAE=LR+LKL(13)\mathcal L_{VAE}=\mathcal L_R + \mathcal L_{KL}\qquad (13)

给定编码器和解码器模型的情况下,在构建和训练VAE之前,还有一个问题需要解决。

重参数化技巧(Reparameterization trick)

下图左侧显示了 VAE 网络。编码器获取输入 xx,并估计潜矢量 zz 的多元高斯分布的均值 μμ 和标准差 σσ。 解码器从潜矢量 zz 采样,以将输入重构为 xx

VAE

但是反向传播梯度不会通过随机采样块。虽然可以为神经网络提供随机输入,但梯度不可能穿过随机层。 解决此问题的方法是将“采样”过程作为输入,如图右侧所示。 采样计算为:

Sample=μ+εσ(14)Sample=\mu + εσ\qquad(14)

如果 εεσσ 以矢量形式表示,则 εσεσ 是逐元素乘法。 使用公式(14),令采样好像直接来自于潜空间。 这项技术被称为重参数化技巧。 之后在输入端进行采样,可以使用熟悉的优化算法(例如SGD,Adam或RMSProp)来训练VAE网络。

VAE实现

为了便于可视化潜在编码,将 zz 的维度设置为 2。编码器仅是两层 MLP,第二层生成均值和对数方差。对数方差的使用是为了简化 KL 损失和重参数化技巧的计算。编码器的第三个输出是使用重参数化技巧进行的 zz 采样。在采样函数中,e0.5logσ2=σ2=σe^{0.5log\sigma^2}=\sqrt{\sigma^2}=\sigma,因为 σ>0σ> 0 是高斯分布的标准偏差。

解码器也是两层 MLP,它对zz的样本进行采样以近似输入。VAE 网络只是将编码器和解码器连接在一起。损失函数是重建损失和KL损失之和。使用Adam优化器。

重参数技巧

#reparameterization trick
#z = z_mean + sqrt(var) * eps
def sampling(args):
    """Reparameterization trick by sampling
    Reparameterization trick by sampling fr an isotropic unit Gaussian.
    #Arguments:
        args (tensor): mean and log of variance of Q(z|x)
    #Returns:
        z (tensor): sampled latent vector
    """
    z_mean,z_log_var = args
    batch = keras.backend.shape(z_mean)[0]
    dim = keras.backend.shape(z_mean)[1]

    epsilon = keras.backend.random_normal(shape=(batch,dim))
    return z_mean + keras.backend.exp(0.5 * z_log_var) * epsilon

加载数据与超参数

# MNIST 数据集
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

image_size = x_train.shape[1]
original_dim = image_size * image_size
x_train = np.reshape(x_train, [-1, original_dim])
x_test = np.reshape(x_test, [-1, original_dim])
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

#超参数
input_shape = (original_dim,)
intermediate_dim = 512
batch_size = 128
latent_dim = 2
epochs = 50

VAE模型

#VAE model
#encoder
inputs = keras.layers.Input(shape=input_shape,name='encoder_input')
x = keras.layers.Dense(intermediate_dim,activation='relu')(inputs)
z_mean = keras.layers.Dense(latent_dim,name='z_mean')(x)
z_log_var = keras.layers.Dense(latent_dim,name='z_log_var')(x)

z = keras.layers.Lambda(sampling,output_shape=(latent_dim,),name='z')([z_mean,z_log_var])

encoder = keras.Model(inputs,[z_mean,z_log_var,z],name='encoder')
encoder.summary()
keras.utils.plot_model(encoder,to_file='vae_mlp_encoder.png',show_shapes=True)

#decoder
latent_inputs = keras.layers.Input(shape=(latent_dim,),name='z_sampling')
x = keras.layers.Dense(intermediate_dim,activation='relu')(latent_inputs)
outputs = keras.layers.Dense(original_dim,activation='sigmoid')(x)
decoder = keras.Model(latent_inputs,outputs,name='decoder')
decoder.summary()
keras.utils.plot_model(decoder,to_file='vae_mlp_decoder.png',show_shapes=True)

outputs = decoder(encoder(inputs)[2])
vae = keras.Model(inputs,outputs,name='vae_mpl')

模型训练

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    help_ = "Load tf model trained weights"
    parser.add_argument("-w", "--weights", help=help_)
    help_ = "Use binary cross entropy instead of mse (default)"
    parser.add_argument("--bce", help=help_, action='store_true')
    args = parser.parse_args()
    models = (encoder, decoder)
    data = (x_test, y_test)
    
    #VAE loss = mse_loss or xent_loss + kl_loss
    if args.bce:
        reconstruction_loss = keras.losses.binary_crossentropy(inputs,outputs)
    else:
        reconstruction_loss = keras.losses.mse(inputs,outputs)
    
    reconstruction_loss *= original_dim
    kl_loss = 1 + z_log_var - keras.backend.square(z_mean) - keras.backend.exp(z_log_var)
    kl_loss = keras.backend.sum(kl_loss,axis=-1)
    kl_loss *= -0.5
    vae_loss = keras.backend.mean(reconstruction_loss + kl_loss)
    vae.add_loss(vae_loss)
    vae.compile(optimizer='adam')
    vae.summary()
    keras.utils.plot_model(vae,to_file='vae_mlp.png',show_shapes=True)
    save_dir = 'vae_mlp_weights'
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    if args.weights:
        filepath = os.path.join(save_dir,args.weights)
        vae = vae.load_weights(filepath)
    else:
        #train
        vae.fit(x_train,
                epochs=epochs,
                batch_size=batch_size,
                validation_data=(x_test,None))
        filepath = os.path.join(save_dir,'vae_mlp.mnist.tf')
        vae.save_weights(filepath)
    plot_results(models,data,batch_size=batch_size,model_name='vae_mlp')

测试经过训练的解码器

在训练了 VAE 网络之后,可以丢弃推理模型。为了生成新的有意义的输出,从用于生成 εε 的高斯分布中抽取样本:

解码器

效果展示

潜矢量可视化

潜矢量可视化

图片生成

图片生成