VAE理论推导

110 阅读2分钟

1. 背景

首先我们有一批X的数据样本{X1,…,Xn},我们想要根据这些样本估计出X的分布P(X),得到分布之后我们就能生成分布里所有的样本点了,这就是生成模型的终极目标。

但是直接求出这个分布是非常困难的,一种方法是利用一个隐变量来辅助估计目标分布,就像解几何题时做辅助线一样。

p(x,z)=p(zx)p(x)(1)p(x,z)=p(z|x)p(x) \tag{1}

zz就是我们引入的隐变量。

2.相关知识

2.1 用蒙特卡洛模拟估计期望

在VAE中蒙特卡洛模拟主要用来求一个分布的期望。

如果已知某个分布的概率密度函数p(x)p(x),那么xx的期望是

E[x]=xp(x)dx(2)E[x] = \int x p(x)dx \tag{2}

更一般的,我们可以写出

Exp(x)[f(x)]=f(x)p(x)dx(3)E_{x\sim p(x)}[f(x)]=\int f(x)p(x)dx \tag{3}

那么如果我们不知道分布p(x)p(x),但我们知道大量的这个分布的样本,我们要怎么求出分布的期望呢?

答案是蒙特卡洛模拟

Exp(x)[f(x)]=1nt=1nf(xi)xip(x)(4)E_{x\sim p(x)}[f(x)]=\frac{1}{n}\sum_{t=1}^nf(x_i) \\x_i \sim p(x) \tag{4}

xix_i是分布p(x)p(x)的样本,我们拥有的样本越多,期望估计的就越准确。

2.2 KL散度

KL散度是一个用来度量两个概率分布p(x)和q(x)之间差异的方法

KL(p(x)q(x))=p(x)lnp(x)q(x)dx=Exp(x)[lnp(x)q(x)](5)KL(p(x)||q(x))=\int p(x)ln\frac{p(x)}{q(x)}dx = E_{x\sim p(x)}[ln\frac{p(x)}{q(x)}] \tag{5}

KL散度是非负的且当p(x)=q(x)p(x)=q(x)时KL散度等于0,这一点的证明用到了变分法,也是VAE中的V的由来。

3. VAE理论推导

3.1 先验分布

回顾公式1

p(x,z)=p(zx)p(x)(1)p(x,z)=p(z|x)p(x) \tag{1}

虽然我们引入了隐变量,但我们对于分布pp了解的非常少,并不知道如何求p(xz)p(x|z)p(x)p(x),于是我们又引入一个先验分布qq,利用qq来表示pp

我们对qq可谓是知根知底,操作也方便了很多。举个可能不恰当的例子,假设我想要学英语,但我对英语一无所知,也不知道从何入手。但我很了解中文,那么如果我能找到一个将中文转换为英文的方式(比如查字典),就能相对容易的学会英文。qq就是中文,pp就是英文。

高斯分布就是一个常用且好用的先验分布,因为它有很多优秀的性质,比如对称性。

q(x)q(x)就可以写成

q(x,z)=q(zx)q(x)dz(6)q(x,z) = \int q(z|x)q(x)dz \tag{6}

我们的目的是希望q(x,z)q(x,z)接近p(x,z)p(x,z) 之后就可以用q(zx)q(z|x)的样本生成q(x,z)q(x,z)的样本进而生成p(x,z)p(x,z)的样本。 我们使用KL散度来使q(x,z)q(x,z)接近p(x,z)p(x,z)

3.2 目标函数

KL(p(x,z)q(x,z))=p(x,z)lnp(x,z)q(x,z)dzdx=p(x)p(xz)lnp(x)p(zx)q(x,z)dz=p(x)[p(zx)lnp(x)p(zx)q(x,z)dz]dx=p(x)[p(zx)lnp(x)+p(zx)lnp(zx)q(x,z)dz]dx=p(x)lnp(x)[p(zx)dz]dx+p(x)[p(zx)lnp(zx)q(x,z)dz]dx=Exp(x)[lnp(x)]+Exp(x)[p(zx)lnp(zx)q(x,z)dz](7)\begin{aligned} KL(p(x,z)||q(x,z))&=\int \int p(x,z)ln\frac{p(x,z)}{q(x,z)}dzdx \\ &=\int\int p(x)p(x|z)ln\frac{p(x)p(z|x)}{q(x,z)}dz\\ &= \int p(x)\big[ \int p(z|x)ln\frac{p(x)p(z|x)}{q(x,z)}dz \big]dx\\ &= \int p(x)\big[ \int p(z|x) lnp(x) +p(z|x)ln\frac{p(z|x)}{q(x,z)}dz \big]dx\\ &=\int p(x)lnp(x)\big[\int p(z|x)dz\big]dx + \int p(x)\big[\int p(z|x) ln\frac{p(z|x)}{q(x,z)}dz\big]dx\\ &=E_{x\sim p(x)}[lnp(x)] + E_{x \sim p(x)}\big[\int p(z|x) ln\frac{p(z|x)}{q(x,z)}dz\big] \tag{7} \end{aligned}

Exp(x)[lnp(x)]E_{x\sim p(x)}[lnp(x)] 是一个常数,尽管我们不知道它是什么,但他一定存在且不变。 因此最小化Exp(x)[p(zx)lnp(zx)q(x,z)dz]E_{x \sim p(x)}\big[\int p(z|x) ln\frac{p(z|x)}{q(x,z)}dz\big] 就是最小化KL(p(x,z)q(x,z))KL(p(x,z)||q(x,z))

这就是我们的目标函数。接下来就是怎么计算Exp(x)[p(zx)lnp(zx)q(x,z)dz]E_{x \sim p(x)}\big[\int p(z|x) ln\frac{p(z|x)}{q(x,z)}dz\big]

3.3 推导VAE目标函数

因为q(x,z)=q(xz)q(z)q(x,z)=q(x|z)q(z),所以有

L=Exp(x)[p(zx)lnp(zx)q(x,z)dz]=Exp(x)[p(zx)lnq(xz)dz+p(zx)lnp(zx)q(z)dz]=Exp(x)[Ezp(zx)[lnq(xz)]+KL(p(zx)q(z))](8)\begin{aligned} L &= E_{x \sim p(x)}\big[\int p(z|x) ln\frac{p(z|x)}{q(x,z)}dz\big]\\ &=E_{x\sim p(x)}\big[ -\int p(z|x)lnq(x|z)dz +\int p(z|x)ln\frac{p(z|x)}{q(z)}dz \big]\\ &=E_{x \sim p(x)}\big[-E_{z \sim p(z|x)}\big[lnq(x|z)\big] +KL(p(z|x)||q(z))\big] \tag{8} \end{aligned}

我们目前得到了公式上的目标函数,L越大越好。

q(z),q(xz),p(zx)q(z),q(x|z),p(z|x)怎么获得呢?答案就是神经网络。

3.3.1 KL散度项计算方法

我们假设q(z)N(0,1)q(z)\sim N(0,1)是标准高斯分布,p(zx)N(μ(x),σ2(x))p(z|x) \sim N(\mu(x),\sigma^2(x))也是高斯分布(但不一定是标准高斯),其均值和方差是x的函数。

p(zx)=12πσ2(x)e12zμ(x)σ(x)2(9)p(z|x) = \frac{1}{\sqrt{2\pi\sigma^2(x)}}e^{-\frac{1}{2}||\frac{z-\mu(x)}{\sigma(x)}||^2} \tag{9}

两个高斯分布的KL散度是有现成的推导公式的,我们直接拿来就能得到结果。

KL(p(zx)q(z))]=12(μ2(x)+σ2(x)ln(σ2(x))1)(10)KL(p(z|x)||q(z))\big] = \frac{1}{2}(\mu^2(x)+\sigma^2(x)-ln(\sigma^2(x))-1) \tag{10}

3.3.2 Ezp(zx)[lnq(xz)]-E_{z \sim p(z|x)}\big[lnq(x|z)\big] 计算方法

利用2.1中提到的蒙特卡洛方法可以知道

Ezp(zx)[lnq(xz)]=1nt=1nlnq(xz)zip(zx)(11)-E_{z \sim p(z|x)}\big[lnq(x|z)\big] = -\frac{1}{n}\sum_{t=1}^nlnq(x|z)\\ z_i \sim p(z|x) \tag{11}

这里n=1,因为训练需要迭代多次,只要训练steps够多,就能够满足蒙特卡洛模拟的要求。

假设q(xz)N(μ^(z),σ^2(z))q(x|z)\sim N(\hat\mu(z),\hat\sigma^2(z))

q(xz)=12πσ^2(z)e12xμ^(z)σ^(z)2(12)q(x|z) = \frac{1}{\sqrt{2\pi\hat\sigma^2(z)}}e^{-\frac{1}{2}||\frac{x-\hat\mu(z)}{\hat\sigma(z)}||^2} \tag{12}

那么

lnq(xz)=12xμ^(z)σ^(x)2+ln2π+12ln(σ^2(z))(13)-lnq(x|z) = \frac{1}{2}||\frac{x-\hat\mu(z)}{\hat\sigma(x)}||^2+ln2\pi+\frac{1}{2} ln(\hat\sigma^2(z)) \tag{13}

如果认为方差是固定的常数,那么能得到

lnq(xz)12σ2xμ^(z)2-lnq(x|z) \sim \frac{1}{2\sigma^2}||x-\hat\mu(z)||^2

这就是MSE损失函数的由来。

上面的μ^(z),\hat\mu(z),就是VAE中decoder,μ(x),σ(x)\mu(x),\sigma(x)就是VAE中的encoder

Reference

Auto-Encoding Variational Bayes

科学空间-变分自编码器(二):从贝叶斯观点出发