1. 背景
首先我们有一批X的数据样本{X1,…,Xn},我们想要根据这些样本估计出X的分布P(X),得到分布之后我们就能生成分布里所有的样本点了,这就是生成模型的终极目标。
但是直接求出这个分布是非常困难的,一种方法是利用一个隐变量来辅助估计目标分布,就像解几何题时做辅助线一样。
p(x,z)=p(z∣x)p(x)(1)
z就是我们引入的隐变量。
2.相关知识
2.1 用蒙特卡洛模拟估计期望
在VAE中蒙特卡洛模拟主要用来求一个分布的期望。
如果已知某个分布的概率密度函数p(x),那么x的期望是
E[x]=∫xp(x)dx(2)
更一般的,我们可以写出
Ex∼p(x)[f(x)]=∫f(x)p(x)dx(3)
那么如果我们不知道分布p(x),但我们知道大量的这个分布的样本,我们要怎么求出分布的期望呢?
答案是蒙特卡洛模拟
Ex∼p(x)[f(x)]=n1t=1∑nf(xi)xi∼p(x)(4)
xi是分布p(x)的样本,我们拥有的样本越多,期望估计的就越准确。
2.2 KL散度
KL散度是一个用来度量两个概率分布p(x)和q(x)之间差异的方法
KL(p(x)∣∣q(x))=∫p(x)lnq(x)p(x)dx=Ex∼p(x)[lnq(x)p(x)](5)
KL散度是非负的且当p(x)=q(x)时KL散度等于0,这一点的证明用到了变分法,也是VAE中的V的由来。
3. VAE理论推导
3.1 先验分布
回顾公式1
p(x,z)=p(z∣x)p(x)(1)
虽然我们引入了隐变量,但我们对于分布p了解的非常少,并不知道如何求p(x∣z)和p(x),于是我们又引入一个先验分布q,利用q来表示p。
我们对q可谓是知根知底,操作也方便了很多。举个可能不恰当的例子,假设我想要学英语,但我对英语一无所知,也不知道从何入手。但我很了解中文,那么如果我能找到一个将中文转换为英文的方式(比如查字典),就能相对容易的学会英文。q就是中文,p就是英文。
高斯分布就是一个常用且好用的先验分布,因为它有很多优秀的性质,比如对称性。
q(x)就可以写成
q(x,z)=∫q(z∣x)q(x)dz(6)
我们的目的是希望q(x,z)接近p(x,z) 之后就可以用q(z∣x)的样本生成q(x,z)的样本进而生成p(x,z)的样本。
我们使用KL散度来使q(x,z)接近p(x,z)
3.2 目标函数
KL(p(x,z)∣∣q(x,z))=∫∫p(x,z)lnq(x,z)p(x,z)dzdx=∫∫p(x)p(x∣z)lnq(x,z)p(x)p(z∣x)dz=∫p(x)[∫p(z∣x)lnq(x,z)p(x)p(z∣x)dz]dx=∫p(x)[∫p(z∣x)lnp(x)+p(z∣x)lnq(x,z)p(z∣x)dz]dx=∫p(x)lnp(x)[∫p(z∣x)dz]dx+∫p(x)[∫p(z∣x)lnq(x,z)p(z∣x)dz]dx=Ex∼p(x)[lnp(x)]+Ex∼p(x)[∫p(z∣x)lnq(x,z)p(z∣x)dz](7)
Ex∼p(x)[lnp(x)] 是一个常数,尽管我们不知道它是什么,但他一定存在且不变。
因此最小化Ex∼p(x)[∫p(z∣x)lnq(x,z)p(z∣x)dz]
就是最小化KL(p(x,z)∣∣q(x,z))
这就是我们的目标函数。接下来就是怎么计算Ex∼p(x)[∫p(z∣x)lnq(x,z)p(z∣x)dz]
3.3 推导VAE目标函数
因为q(x,z)=q(x∣z)q(z),所以有
L=Ex∼p(x)[∫p(z∣x)lnq(x,z)p(z∣x)dz]=Ex∼p(x)[−∫p(z∣x)lnq(x∣z)dz+∫p(z∣x)lnq(z)p(z∣x)dz]=Ex∼p(x)[−Ez∼p(z∣x)[lnq(x∣z)]+KL(p(z∣x)∣∣q(z))](8)
我们目前得到了公式上的目标函数,L越大越好。
但q(z),q(x∣z),p(z∣x)怎么获得呢?答案就是神经网络。
3.3.1 KL散度项计算方法
我们假设q(z)∼N(0,1)是标准高斯分布,p(z∣x)∼N(μ(x),σ2(x))也是高斯分布(但不一定是标准高斯),其均值和方差是x的函数。
p(z∣x)=2πσ2(x)1e−21∣∣σ(x)z−μ(x)∣∣2(9)
两个高斯分布的KL散度是有现成的推导公式的,我们直接拿来就能得到结果。
KL(p(z∣x)∣∣q(z))]=21(μ2(x)+σ2(x)−ln(σ2(x))−1)(10)
3.3.2 −Ez∼p(z∣x)[lnq(x∣z)] 计算方法
利用2.1中提到的蒙特卡洛方法可以知道
−Ez∼p(z∣x)[lnq(x∣z)]=−n1t=1∑nlnq(x∣z)zi∼p(z∣x)(11)
这里n=1,因为训练需要迭代多次,只要训练steps够多,就能够满足蒙特卡洛模拟的要求。
假设q(x∣z)∼N(μ^(z),σ^2(z))
q(x∣z)=2πσ^2(z)1e−21∣∣σ^(z)x−μ^(z)∣∣2(12)
那么
−lnq(x∣z)=21∣∣σ^(x)x−μ^(z)∣∣2+ln2π+21ln(σ^2(z))(13)
如果认为方差是固定的常数,那么能得到
−lnq(x∣z)∼2σ21∣∣x−μ^(z)∣∣2
这就是MSE损失函数的由来。
上面的μ^(z),就是VAE中decoder,μ(x),σ(x)就是VAE中的encoder
Reference
Auto-Encoding Variational Bayes
科学空间-变分自编码器(二):从贝叶斯观点出发