这是我参与8月更文挑战的第15天,活动详情查看:8月更文挑战
详解变分自编码器——VAE
VAE全称(Variational Auto-Encoder)即变分自编码器。是一个生成模型。
了解VAE之间,我们先简单了解一下自编码器,也就是常说的Auto-Encoder 。
Auto-Encoder 包括一个编码器(Encoder)和一个解码器(Decoder)。其结构如下:
中间的这层code也称embedding。
VAE的目标
先假设一个隐变量Z的分布,构建一个从Z到目标数据X的模型,即构建X = g ( Z ) X=g(Z) X = g ( Z ) ,使得学出来的目标数据与真实数据的概率分布相近。与GAN基本一致,GAN学的也是概率分布。
模型结构
VAE的结构图(图源自苏老师的博客,侵删)如下:
VAE对每一个样本X k X_k X k 匹配一个高斯分布,隐变量Z就是从高斯分布中采样得到的。对K个样本来说,每个样本的高斯分布假设为N ( μ k , σ k 2 ) \mathcal N(\mu_k,\sigma_k^2) N ( μ k , σ k 2 ) ,问题就在于如何拟合这些分布。
VAE构建两个神经网络来进行拟合均值与方差。即μ k = f 1 ( X k ) , l o g σ k 2 = f 2 ( X k ) \mu_k=f_1(X_k),log\sigma_k^2=f_2(X_k) μ k = f 1 ( X k ) , l o g σ k 2 = f 2 ( X k ) ,拟合l o g σ k 2 log\sigma_k^2 l o g σ k 2 的原因是这样无需加激活函数。
此外,VAE让每个高斯分布尽可能地趋于标准高斯分布N ( 0 , 1 ) \mathcal N(0,1) N ( 0 , 1 ) 。这拟合过程中的误差损失则是采用KL散度作为计算。
下面做详细推导。
原理推导
其实,VAE与同为生成模型的GMM(高斯混合模型)也有很相似,实际上VAE可看成是GMM的一个distributed representation的版本。我们知道,GMM是有限个高斯分布的隐变量 z z z 的混合,而VAE可看成是无穷个隐变量 z z z 的混合,注意,VAE中的 z z z 可以是高斯也可以是非高斯的。只不过一般用的比较多的是高斯的。
原始样本数据 x x x 的概率分布:
P ( x ) = ∫ Z P ( x ) P ( x ∣ z ) d z P(x)=\int_Z P(x)P(x|z)dz P ( x ) = ∫ Z P ( x ) P ( x ∣ z ) d z
我们假设 z z z 服从标准高斯分布,先验分布 P ( x ∣ z ) P(x|z) P ( x ∣ z ) 是高斯的,即 x ∣ z ∼ N ( μ ( z ) , σ ( z ) ) x|z \sim N(\mu(z),\sigma(z)) x ∣ z ∼ N ( μ ( z ) , σ ( z )) 。μ ( z ) 、 σ ( z ) \mu(z)、\sigma(z) μ ( z ) 、 σ ( z ) 是两个函数, 分别是z z z 对应的高斯分布的均值和方差(如下图),则 P ( x ) P(x) P ( x ) 就是在积分域上所有高斯分布的累加。
由于 P ( z ) P(z) P ( z ) 是已知的,P ( x ∣ z ) P(x|z) P ( x ∣ z ) 未知,所以求解问题实际上就是求μ , σ \mu,\sigma μ , σ 这两个函数。我们最开始的目标是求解P ( x ) P(x) P ( x ) ,且我们希望P ( x ) P(x) P ( x ) 越大越好,这等价于求解关于 x x x 最大对数似然:
L = ∑ x l o g P ( x ) L=\sum_x logP(x) L = x ∑ l o g P ( x )
而 l o g P ( x ) logP(x) l o g P ( x ) 可变换为:
l o g P ( x ) = ∫ z q ( z ∣ x ) l o g P ( x ) d z = ∫ z q ( z ∣ x ) l o g ( P ( z , x ) P ( z ∣ x ) ) d z = ∫ z q ( z ∣ x ) l o g ( P ( z , x ) q ( z ∣ x ) q ( z ∣ x ) P ( z ∣ x ) ) d z = ∫ z q ( z ∣ x ) l o g ( P ( z , x ) q ( z ∣ x ) ) d z + ∫ z q ( z ∣ x ) l o g ( q ( z ∣ x ) P ( z ∣ x ) ) d z = ∫ z q ( z ∣ x ) l o g ( P ( x ∣ z ) P ( z ) q ( z ∣ x ) ) d z + ∫ z q ( z ∣ x ) l o g ( q ( z ∣ x ) P ( z ∣ x ) ) d z \begin{aligned}
logP(x)&=\int_z q(z|x)logP(x)dz \\
&=\int_z q(z|x)log(\dfrac{P(z,x)}{P(z|x)})dz \\
&=\int_z q(z|x)log(\dfrac{P(z,x)}{q(z|x)}\dfrac{q(z|x)}{P(z|x)})dz\\
&=\int_z q(z|x)log(\dfrac{P(z,x)}{q(z|x)})dz+ \int_z q(z|x)log(\dfrac{q(z|x)}{P(z|x)})dz\\
&=\int_z q(z|x)log(\dfrac{P(x|z)P(z)}{q(z|x)})dz + \int_z q(z|x)log(\dfrac{q(z|x)}{P(z|x)})dz
\end{aligned} l o g P ( x ) = ∫ z q ( z ∣ x ) l o g P ( x ) d z = ∫ z q ( z ∣ x ) l o g ( P ( z ∣ x ) P ( z , x ) ) d z = ∫ z q ( z ∣ x ) l o g ( q ( z ∣ x ) P ( z , x ) P ( z ∣ x ) q ( z ∣ x ) ) d z = ∫ z q ( z ∣ x ) l o g ( q ( z ∣ x ) P ( z , x ) ) d z + ∫ z q ( z ∣ x ) l o g ( P ( z ∣ x ) q ( z ∣ x ) ) d z = ∫ z q ( z ∣ x ) l o g ( q ( z ∣ x ) P ( x ∣ z ) P ( z ) ) d z + ∫ z q ( z ∣ x ) l o g ( P ( z ∣ x ) q ( z ∣ x ) ) d z
到这里我们发现,第二项 ∫ z q ( z ∣ x ) l o g ( q ( z ∣ x ) P ( z ∣ x ) ) d z \int_z q(z|x)log(\dfrac{q(z|x)}{P(z|x)})dz ∫ z q ( z ∣ x ) l o g ( P ( z ∣ x ) q ( z ∣ x ) ) d z 其实就是 q q q 和P P P 的KL散度,即 K L ( q ( z ∣ x ) ∣ ∣ P ( z ∣ x ) ) KL(q(z|x)\;||\;P(z|x)) K L ( q ( z ∣ x ) ∣∣ P ( z ∣ x )) ,因为KL散度是大于等于0的,
所以上式进一步可写成:
l o g P ( x ) ≥ ∫ z q ( z ∣ x ) l o g ( P ( x ∣ z ) P ( z ) q ( z ∣ x ) ) d z logP(x)\geq \int_z q(z|x)log(\dfrac{P(x|z)P(z)}{q(z|x)})dz l o g P ( x ) ≥ ∫ z q ( z ∣ x ) l o g ( q ( z ∣ x ) P ( x ∣ z ) P ( z ) ) d z
这样我们就找到了一个下界(lower bound),也就是式子的右项,即
L b = ∫ z q ( z ∣ x ) l o g ( P ( x ∣ z ) P ( z ) q ( z ∣ x ) ) d z L_b=\int_z q(z|x)log(\dfrac{P(x|z)P(z)}{q(z|x)})dz L b = ∫ z q ( z ∣ x ) l o g ( q ( z ∣ x ) P ( x ∣ z ) P ( z ) ) d z
原式也可表示成:
l o g P ( x ) = L b + K L ( q ( z ∣ x ) ∣ ∣ P ( z ∣ x ) ) logP(x)=L_b+KL(q(z|x)\;||\;P(z|x)) l o g P ( x ) = L b + K L ( q ( z ∣ x ) ∣∣ P ( z ∣ x ))
为了让 l o g P ( x ) logP(x) l o g P ( x ) 越大,我们目的就是要最大化它的这个下界。
推到这里,可能会有个疑问:为什么要引入q ( z ∣ x ) q(z|x) q ( z ∣ x ) (这里的q ( z ∣ x ) q(z|x) q ( z ∣ x ) 可以是任何分布)?
实际上,因为后验分布 P ( z ∣ x ) P(z|x) P ( z ∣ x ) 很难求(intractable),所以才用 q ( z ∣ x ) q(z|x) q ( z ∣ x ) 来逼近这个后验分布。在优化的过程中我们发现,首先 q ( z ∣ x ) q(z|x) q ( z ∣ x ) 跟 l o g P ( x ) logP(x) l o g P ( x ) 是完全没有关系的,l o g P ( x ) logP(x) l o g P ( x ) 只跟 P ( z ∣ x ) P(z|x) P ( z ∣ x ) 有关,调节 q ( z ∣ x ) q(z|x) q ( z ∣ x ) 是不会影响似然也就是 l o g P ( x ) logP(x) l o g P ( x ) 的。所以,当我们固定住 P ( x ∣ z ) P(x|z) P ( x ∣ z ) 时,调节 q ( z ∣ x ) q(z|x) q ( z ∣ x ) 最大化下界 L b L_b L b ,KL则越小。当 q ( z ∣ x ) q(z|x) q ( z ∣ x ) 与不断逼近后验分布P ( z ∣ x ) P(z|x) P ( z ∣ x ) 时,KL散度趋于为0,l o g P ( x ) logP(x) l o g P ( x ) 就和 L b L_b L b 等价。所以最大化 l o g P ( x ) logP(x) l o g P ( x ) 就等价于最大化 L b L_b L b 。
回顾 L b L_b L b ,
L b = ∫ z q ( z ∣ x ) l o g ( P ( x ∣ z ) P ( z ) q ( z ∣ x ) ) d z = ∫ z q ( z ∣ x ) l o g ( P ( z ) q ( z ∣ x ) ) d z + ∫ z q ( z ∣ x ) l o g P ( x ∣ z ) d z = − K L ( q ( z ∣ x ) ∣ ∣ P ( z ) ) + ∫ z q ( z ∣ x ) l o g P ( x ∣ z ) d z = − K L ( q ( z ∣ x ) ∣ ∣ P ( z ) ) + E q ( z ∣ x ) [ l o g ( P ( x ∣ z ) ) ] \begin{aligned}
L_b&=\int_z q(z|x)log(\dfrac{P(x|z)P(z)}{q(z|x)})dz \\
&=\int_z q(z|x)log(\dfrac{P(z)}{q(z|x)})dz+\int_z q(z|x)logP(x|z)dz \\
&=-KL(q(z|x)\;||\;P(z)) + \int_z q(z|x)logP(x|z)dz \\
&=-KL(q(z|x)\;||\;P(z)) + E_{q(z|x)}[log(P(x|z))]
\end{aligned} L b = ∫ z q ( z ∣ x ) l o g ( q ( z ∣ x ) P ( x ∣ z ) P ( z ) ) d z = ∫ z q ( z ∣ x ) l o g ( q ( z ∣ x ) P ( z ) ) d z + ∫ z q ( z ∣ x ) l o g P ( x ∣ z ) d z = − K L ( q ( z ∣ x ) ∣∣ P ( z )) + ∫ z q ( z ∣ x ) l o g P ( x ∣ z ) d z = − K L ( q ( z ∣ x ) ∣∣ P ( z )) + E q ( z ∣ x ) [ l o g ( P ( x ∣ z ))]
显然,最大化 L b L_b L b 就是等价于最小化 K L ( q ( z ∣ x ) ∣ ∣ P ( z ) ) KL(q(z|x)\;||\;P(z)) K L ( q ( z ∣ x ) ∣∣ P ( z )) 和最大化 E q ( z ∣ x ) [ l o g ( P ( x ∣ z ) ) ] E_{q(z|x)}[log(P(x|z))] E q ( z ∣ x ) [ l o g ( P ( x ∣ z ))] 。
第一项,最小化KL散度 。我们前面已假设了 P ( z ) P(z) P ( z ) 是服从标准高斯分布的,且 q ( z ∣ x ) q(z|x) q ( z ∣ x ) 是服从高斯分布 N ( μ , σ 2 ) \mathcal N(\mu,\sigma^2) N ( μ , σ 2 ) ,于是代入计算可得:
K L ( q ( z ∣ x ) ∣ ∣ P ( z ) ) = K L ( N ( μ , σ 2 ) ∣ ∣ N ( 0 , 1 ) ) = ∫ 1 2 π σ 2 e − ( x − μ ) 2 2 σ 2 ( l o g e − ( x − μ ) 2 2 σ 2 / 2 π σ 2 e − x 2 2 / 2 π ) d x . . . 化简得到 = 1 2 1 2 π σ 2 ∫ e − ( x − μ ) 2 2 σ 2 ( − l o g σ 2 + x 2 − ( x − μ ) 2 σ 2 ) d x = 1 2 ∫ 1 2 π σ 2 e − ( x − μ ) 2 2 σ 2 ( − l o g σ 2 + x 2 − ( x − μ ) 2 σ 2 ) d x \begin{aligned}
KL(q(z|x)\;||\;P(z))=KL(\mathcal N(\mu,\sigma^2)\;||\;\mathcal N(0,1))=&\int\dfrac{1}{\sqrt{2\pi\sigma^2}}e^{\frac{-(x-\mu)^2}{2\sigma^2}} \left( log\dfrac{e^{\frac{-(x-\mu)^2}{2\sigma^2}}/\sqrt{2\pi\sigma^2}}{ e^{\frac{-x^2}{2}}/\sqrt{2\pi} } \right)dx
\\&...\text{化简得到}
\\=&\dfrac{1}{2}\dfrac{1}{\sqrt{2\pi\sigma^2}}\int e^{\frac{-(x-\mu)^2}{2\sigma^2}} \left(-log\sigma^2 +x^2-\dfrac{(x-\mu)^2}{\sigma^2} \right)dx
\\=&\dfrac{1}{2}\int \dfrac{1}{\sqrt{2\pi\sigma^2}} e^{\frac{-(x-\mu)^2}{2\sigma^2}} \left(-log\sigma^2 +x^2-\dfrac{(x-\mu)^2}{\sigma^2} \right)dx
\end{aligned} K L ( q ( z ∣ x ) ∣∣ P ( z )) = K L ( N ( μ , σ 2 ) ∣∣ N ( 0 , 1 )) = = = ∫ 2 π σ 2 1 e 2 σ 2 − ( x − μ ) 2 ⎝ ⎛ l o g e 2 − x 2 / 2 π e 2 σ 2 − ( x − μ ) 2 / 2 π σ 2 ⎠ ⎞ d x ... 化简得到 2 1 2 π σ 2 1 ∫ e 2 σ 2 − ( x − μ ) 2 ( − l o g σ 2 + x 2 − σ 2 ( x − μ ) 2 ) d x 2 1 ∫ 2 π σ 2 1 e 2 σ 2 − ( x − μ ) 2 ( − l o g σ 2 + x 2 − σ 2 ( x − μ ) 2 ) d x
对上式中的积分进一步求解,1 2 π σ 2 e − ( x − μ ) 2 2 σ 2 \dfrac{1}{\sqrt{2\pi\sigma^2}}e^{\frac{-(x-\mu)^2}{2\sigma^2}} 2 π σ 2 1 e 2 σ 2 − ( x − μ ) 2 实际就是概率密度 f ( x ) f(x) f ( x ) ,而概率密度函数的积分就是1,所以积分第一项等于− l o g σ 2 -log\sigma^2 − l o g σ 2 ;而又因为高斯分布的二阶矩就是 E ( X 2 ) = ∫ x 2 f ( x ) d x = μ 2 + σ 2 E(X^2)=\int x^2f(x)dx=\mu^2+\sigma^2 E ( X 2 ) = ∫ x 2 f ( x ) d x = μ 2 + σ 2 ,正好对应积分第二项。又根据方差的定义可知σ = ∫ ( x − μ ) d x \sigma=\int (x-\mu)dx σ = ∫ ( x − μ ) d x ,所以积分第三项为 − 1 -1 − 1 。
最终化简得到的结果如下:
K L ( q ( z ∣ x ) ∣ ∣ P ( z ) ) = K L ( N ( μ , σ 2 ) ∣ ∣ N ( 0 , 1 ) ) = 1 2 ( − l o g σ 2 + μ 2 + σ 2 − 1 ) KL(q(z|x)\;||\;P(z))=KL(\mathcal N(\mu,\sigma^2)\;||\;\mathcal N(0,1))=\dfrac{1}{2}(-log\sigma^2+\mu^2+\sigma^2-1) K L ( q ( z ∣ x ) ∣∣ P ( z )) = K L ( N ( μ , σ 2 ) ∣∣ N ( 0 , 1 )) = 2 1 ( − l o g σ 2 + μ 2 + σ 2 − 1 )
第二项,最大化期望 。也就是表明在给定 q ( z ∣ x ) q(z|x) q ( z ∣ x ) (编码器输出)的情况下 P ( x ∣ z ) P(x|z) P ( x ∣ z ) (解码器输出)的值尽可能高。具体来讲,第一步,利用encoder的神经网络计算出均值与方差,从中采样得到 z z z ,这一过程就对应式子中的 q ( z ∣ x ) q(z|x) q ( z ∣ x ) ;第二步,利用decoder的NN计算 z z z 的均值方差,让均值(或也考虑方差)越接近 x x x ,则产生 x x x 的几率 l o g P ( x ∣ z ) logP(x|z) l o g P ( x ∣ z ) 越大,对应于式子中的最大化 l o g P ( x ∣ z ) logP(x|z) l o g P ( x ∣ z ) 这一部分。
推导至此完毕。
重参数技巧
最后模型在实现的时候,有一个重参数技巧,就是我们想从高斯分布 N ( μ , σ 2 ) \mathcal N(\mu,\sigma^2) N ( μ , σ 2 ) 中采样Z时,其实是相当于从 N ( 0 , 1 ) \mathcal N(0,1) N ( 0 , 1 ) 中采样一个 ϵ \epsilon ϵ ,然后再来计算 Z = μ + ϵ × σ Z=\mu+\epsilon\times\sigma Z = μ + ϵ × σ 。这么做的原因是,采样这个操作是不可导的,而采样的结果是可导的,这样做个参数变换,Z = μ + ϵ × σ Z=\mu+\epsilon\times\sigma Z = μ + ϵ × σ 这个就可以参与梯度下降,模型就可以训练了。
参考
苏剑林:变分自编码器(一):原来是这么一回事
李宏毅老师 Machine Learning (2017,秋,台湾大学) 国语