详解变分自编码器——VAE

5,021 阅读4分钟

这是我参与8月更文挑战的第15天,活动详情查看:8月更文挑战

详解变分自编码器——VAE

VAE全称(Variational Auto-Encoder)即变分自编码器。是一个生成模型。

了解VAE之间,我们先简单了解一下自编码器,也就是常说的Auto-Encoder

Auto-Encoder包括一个编码器(Encoder)和一个解码器(Decoder)。其结构如下:

Auto-Encoder

中间的这层code也称embedding。

VAE的目标

先假设一个隐变量Z的分布,构建一个从Z到目标数据X的模型,即构建X=g(Z)X=g(Z),使得学出来的目标数据与真实数据的概率分布相近。与GAN基本一致,GAN学的也是概率分布。

模型结构

VAE的结构图(图源自苏老师的博客,侵删)如下:

VAE结构

VAE对每一个样本XkX_k匹配一个高斯分布,隐变量Z就是从高斯分布中采样得到的。对K个样本来说,每个样本的高斯分布假设为N(μk,σk2)\mathcal N(\mu_k,\sigma_k^2),问题就在于如何拟合这些分布。

VAE构建两个神经网络来进行拟合均值与方差。即μk=f1(Xk),logσk2=f2(Xk)\mu_k=f_1(X_k),log\sigma_k^2=f_2(X_k),拟合logσk2log\sigma_k^2的原因是这样无需加激活函数。

此外,VAE让每个高斯分布尽可能地趋于标准高斯分布N(0,1)\mathcal N(0,1)。这拟合过程中的误差损失则是采用KL散度作为计算。

下面做详细推导。

原理推导

其实,VAE与同为生成模型的GMM(高斯混合模型)也有很相似,实际上VAE可看成是GMM的一个distributed representation的版本。我们知道,GMM是有限个高斯分布的隐变量 zz 的混合,而VAE可看成是无穷个隐变量 zz 的混合,注意,VAE中的 zz 可以是高斯也可以是非高斯的。只不过一般用的比较多的是高斯的。

原始样本数据 xx 的概率分布:

P(x)=ZP(x)P(xz)dzP(x)=\int_Z P(x)P(x|z)dz

我们假设 zz 服从标准高斯分布,先验分布 P(xz)P(x|z) 是高斯的,即 xzN(μ(z),σ(z))x|z \sim N(\mu(z),\sigma(z))μ(z)σ(z)\mu(z)、\sigma(z)是两个函数, 分别是zz对应的高斯分布的均值和方差(如下图),则 P(x)P(x) 就是在积分域上所有高斯分布的累加。

在这里插入图片描述

由于 P(z)P(z) 是已知的,P(xz)P(x|z) 未知,所以求解问题实际上就是求μ,σ\mu,\sigma这两个函数。我们最开始的目标是求解P(x)P(x),且我们希望P(x)P(x)越大越好,这等价于求解关于 xx 最大对数似然:

L=xlogP(x)L=\sum_x logP(x)

logP(x)logP(x) 可变换为:

logP(x)=zq(zx)logP(x)dz=zq(zx)log(P(z,x)P(zx))dz=zq(zx)log(P(z,x)q(zx)q(zx)P(zx))dz=zq(zx)log(P(z,x)q(zx))dz+zq(zx)log(q(zx)P(zx))dz=zq(zx)log(P(xz)P(z)q(zx))dz+zq(zx)log(q(zx)P(zx))dz\begin{aligned} logP(x)&=\int_z q(z|x)logP(x)dz \\ &=\int_z q(z|x)log(\dfrac{P(z,x)}{P(z|x)})dz \\ &=\int_z q(z|x)log(\dfrac{P(z,x)}{q(z|x)}\dfrac{q(z|x)}{P(z|x)})dz\\ &=\int_z q(z|x)log(\dfrac{P(z,x)}{q(z|x)})dz+ \int_z q(z|x)log(\dfrac{q(z|x)}{P(z|x)})dz\\ &=\int_z q(z|x)log(\dfrac{P(x|z)P(z)}{q(z|x)})dz + \int_z q(z|x)log(\dfrac{q(z|x)}{P(z|x)})dz \end{aligned}

到这里我们发现,第二项 zq(zx)log(q(zx)P(zx))dz\int_z q(z|x)log(\dfrac{q(z|x)}{P(z|x)})dz 其实就是 qqPP 的KL散度,即 KL(q(zx)    P(zx))KL(q(z|x)\;||\;P(z|x)),因为KL散度是大于等于0的,

所以上式进一步可写成:

logP(x)zq(zx)log(P(xz)P(z)q(zx))dzlogP(x)\geq \int_z q(z|x)log(\dfrac{P(x|z)P(z)}{q(z|x)})dz

这样我们就找到了一个下界(lower bound),也就是式子的右项,即

Lb=zq(zx)log(P(xz)P(z)q(zx))dzL_b=\int_z q(z|x)log(\dfrac{P(x|z)P(z)}{q(z|x)})dz

原式也可表示成:

logP(x)=Lb+KL(q(zx)    P(zx))logP(x)=L_b+KL(q(z|x)\;||\;P(z|x))

为了让 logP(x)logP(x) 越大,我们目的就是要最大化它的这个下界。

推到这里,可能会有个疑问:为什么要引入q(zx)q(z|x)(这里的q(zx)q(z|x)可以是任何分布)?

实际上,因为后验分布 P(zx)P(z|x) 很难求(intractable),所以才用 q(zx)q(z|x) 来逼近这个后验分布。在优化的过程中我们发现,首先 q(zx)q(z|x)logP(x)logP(x) 是完全没有关系的,logP(x)logP(x) 只跟 P(zx)P(z|x) 有关,调节 q(zx)q(z|x) 是不会影响似然也就是 logP(x)logP(x) 的。所以,当我们固定住 P(xz)P(x|z) 时,调节 q(zx)q(z|x) 最大化下界 LbL_b,KL则越小。当 q(zx)q(z|x)与不断逼近后验分布P(zx)P(z|x)时,KL散度趋于为0,logP(x)logP(x)就和 LbL_b 等价。所以最大化 logP(x)logP(x) 就等价于最大化 LbL_b

在这里插入图片描述

回顾 LbL_b,

Lb=zq(zx)log(P(xz)P(z)q(zx))dz=zq(zx)log(P(z)q(zx))dz+zq(zx)logP(xz)dz=KL(q(zx)    P(z))+zq(zx)logP(xz)dz=KL(q(zx)    P(z))+Eq(zx)[log(P(xz))]\begin{aligned} L_b&=\int_z q(z|x)log(\dfrac{P(x|z)P(z)}{q(z|x)})dz \\ &=\int_z q(z|x)log(\dfrac{P(z)}{q(z|x)})dz+\int_z q(z|x)logP(x|z)dz \\ &=-KL(q(z|x)\;||\;P(z)) + \int_z q(z|x)logP(x|z)dz \\ &=-KL(q(z|x)\;||\;P(z)) + E_{q(z|x)}[log(P(x|z))] \end{aligned}

显然,最大化 LbL_b 就是等价于最小化 KL(q(zx)    P(z))KL(q(z|x)\;||\;P(z)) 和最大化 Eq(zx)[log(P(xz))]E_{q(z|x)}[log(P(x|z))]

第一项,最小化KL散度。我们前面已假设了 P(z)P(z) 是服从标准高斯分布的,且 q(zx)q(z|x) 是服从高斯分布 N(μ,σ2)\mathcal N(\mu,\sigma^2) ,于是代入计算可得:

KL(q(zx)    P(z))=KL(N(μ,σ2)    N(0,1))=12πσ2e(xμ)22σ2(loge(xμ)22σ2/2πσ2ex22/2π)dx...化简得到=1212πσ2e(xμ)22σ2(logσ2+x2(xμ)2σ2)dx=1212πσ2e(xμ)22σ2(logσ2+x2(xμ)2σ2)dx\begin{aligned} KL(q(z|x)\;||\;P(z))=KL(\mathcal N(\mu,\sigma^2)\;||\;\mathcal N(0,1))=&\int\dfrac{1}{\sqrt{2\pi\sigma^2}}e^{\frac{-(x-\mu)^2}{2\sigma^2}} \left( log\dfrac{e^{\frac{-(x-\mu)^2}{2\sigma^2}}/\sqrt{2\pi\sigma^2}}{ e^{\frac{-x^2}{2}}/\sqrt{2\pi} } \right)dx \\&...\text{化简得到} \\=&\dfrac{1}{2}\dfrac{1}{\sqrt{2\pi\sigma^2}}\int e^{\frac{-(x-\mu)^2}{2\sigma^2}} \left(-log\sigma^2 +x^2-\dfrac{(x-\mu)^2}{\sigma^2} \right)dx \\=&\dfrac{1}{2}\int \dfrac{1}{\sqrt{2\pi\sigma^2}} e^{\frac{-(x-\mu)^2}{2\sigma^2}} \left(-log\sigma^2 +x^2-\dfrac{(x-\mu)^2}{\sigma^2} \right)dx \end{aligned}

对上式中的积分进一步求解,12πσ2e(xμ)22σ2\dfrac{1}{\sqrt{2\pi\sigma^2}}e^{\frac{-(x-\mu)^2}{2\sigma^2}}实际就是概率密度 f(x)f(x),而概率密度函数的积分就是1,所以积分第一项等于logσ2-log\sigma^2;而又因为高斯分布的二阶矩就是 E(X2)=x2f(x)dx=μ2+σ2E(X^2)=\int x^2f(x)dx=\mu^2+\sigma^2,正好对应积分第二项。又根据方差的定义可知σ=(xμ)dx\sigma=\int (x-\mu)dx,所以积分第三项为 1-1

最终化简得到的结果如下:

KL(q(zx)    P(z))=KL(N(μ,σ2)    N(0,1))=12(logσ2+μ2+σ21)KL(q(z|x)\;||\;P(z))=KL(\mathcal N(\mu,\sigma^2)\;||\;\mathcal N(0,1))=\dfrac{1}{2}(-log\sigma^2+\mu^2+\sigma^2-1)

第二项,最大化期望。也就是表明在给定 q(zx)q(z|x)(编码器输出)的情况下 P(xz)P(x|z)(解码器输出)的值尽可能高。具体来讲,第一步,利用encoder的神经网络计算出均值与方差,从中采样得到 zz,这一过程就对应式子中的 q(zx)q(z|x);第二步,利用decoder的NN计算 zz 的均值方差,让均值(或也考虑方差)越接近 xx ,则产生 xx 的几率 logP(xz)logP(x|z) 越大,对应于式子中的最大化 logP(xz)logP(x|z) 这一部分。

在这里插入图片描述

推导至此完毕。

重参数技巧

最后模型在实现的时候,有一个重参数技巧,就是我们想从高斯分布 N(μ,σ2)\mathcal N(\mu,\sigma^2) 中采样Z时,其实是相当于从 N(0,1)\mathcal N(0,1) 中采样一个 ϵ\epsilon,然后再来计算 Z=μ+ϵ×σZ=\mu+\epsilon\times\sigma。这么做的原因是,采样这个操作是不可导的,而采样的结果是可导的,这样做个参数变换,Z=μ+ϵ×σZ=\mu+\epsilon\times\sigma 这个就可以参与梯度下降,模型就可以训练了。

参考

  1. 苏剑林:变分自编码器(一):原来是这么一回事
  2. 李宏毅老师 Machine Learning (2017,秋,台湾大学) 国语