在上一小节中,我们分析了Mean Field Theory Variational Inference,通过平均假设来得到变分推断的理论,是一种classical VI,我们可以将其看成Coordinate Ascend。而另一种方法是Stochastic Gradient Variational Inference (SGVI)。
对于隐变量参数z和数据集x。z⟶x是Generative Model,也就是p(x∣z)和p(x,z),这个过程也被我们称为Decoder。x⟶z是Inference Model,这个过程被我们称为Encoder,表达关系也就是p(z∣x)。
SGVI参数规范
我们这节的主题就是Stochastic Gradient Variational Inference (SGVI),参数的更新方法为:
θ(t+1)=θ(t)+λ(t)∇L(q)
其中,q(z∣x)被我们简化表示为q(z),我们令q(z)是一个固定形式的概率分布,ϕ为这个分布的参数,那么我们将把这个概率写成qϕ(z)。
那么,我们需要对原等式中的表达形式进行更新,
ELBO=Eqϕ(z)[logpθ(x(i),z)−logqϕ(z)]=L(ϕ)
而,
logpθ(x(i))=ELBO+KL(q∣∣p)≥L(ϕ)
而求解目标也转换成了:
p^=argmaxϕL(ϕ)
SGVI的梯度推导
∇ϕL(ϕ)===∇ϕEqϕ[logpθ(x(i),z)−logqϕ]∇ϕ∫qϕ[logpθ(x(i),z)−logqϕ]dz∫∇ϕqϕ[logpθ(x(i),z)−logqϕ]dz+∫qϕ∇ϕ[logpθ(x(i),z)−logqϕ]dz
我们把这个等式拆成两个部分,其中:
∫∇ϕqϕ[logpθ(x(i),z)−logqϕ]dz为第一个部分;
∫qϕ∇ϕ[logpθ(x(i),z)−logqϕ]dz为第二个部分。
关于第二部分的求解
第二部分比较好求,所以我们才首先求第二部分的,哈哈哈!因为logpθ(x(i),z)与ϕ无关。
2=======∫qϕ∇ϕ[logpθ(x(i),z)−logqϕ]dz−∫qϕ∇ϕlogqϕdz−∫qϕqϕ1∇ϕqϕdz−∫∇ϕqϕdz−∇ϕ∫qϕdz−∇ϕ10
关于第一部分的求解
在这里我们用到了一个小trick,这个trick在公式(6)的推导中,我们使用过的。那就是∇ϕqϕ=qϕ∇ϕlogqϕ。所以,我们代入到第一项中可以得到:
1===∫∇ϕqϕ[logpθ(x(i),z)−logqϕ]dz∫qϕ∇ϕlogqϕ[logpθ(x(i),z)−logqϕ]dzEqϕ[∇ϕlogqϕlogpθ(x(i),z)−logqϕ]
那么,我们可以得到:
∇ϕL(ϕ)=Eqϕ[∇ϕlogqϕlogpθ(x(i),z)−logqϕ]
那么如何求这个期望呢?我们采用的是蒙特卡罗采样法,假设zl∼qϕ(z) l=1,2,⋯,L,那么有:
∇ϕL(ϕ)≈L1l=1∑L∇ϕlogqϕ(z(l))[logpθ(x(i),z)−logqϕ(z(l))]
由于第二部分的结果为0,所以第一部分的解就是最终的解。但是,这样的求法有什么样的问题呢?因为我们在采样的过程中,很有可能采到qϕ(z)⟶0的点,对于log函数来说,limx⟶0logx=∞,那么梯度的变化会非常的剧烈,非常的不稳定。对于这样的High Variance的问题,根本没有办法求解。实际上,我们可以通过计算得到这个方差的解析解,它确实是一个很大的值。事实上,这里的梯度的方差这么的大,而ϕ^⟶q(z)也有误差,误差叠加,直接爆炸,根本没有办法用。也就是不会work,那么我们如何解决这个问题?
Variance Reduction
这里采用了一种比较常见的方差缩减方法,称为Reparameterization Trick,也就是对qϕ做一些简化。
我们怎么可以较好的解决这个问题?如果我们可以得到一个确定的解p(ϵ),就会变得比较简单。因为z来自于qϕ(z∣x),我们就想办法将z中的随机变量给解放出来。也就是使用一个转换z=gϕ(ϵ,x(i)),其中ϵ∼p(ϵ)。那么这样做,有什么好处呢?原来的 ∇ϕEqϕ[⋅] 将转换为 Ep(ϵ)[∇ϕ(⋅)] ,那么不在是连续的关于 ϕ 的采样,这样可以有效的降低方差。并且,z 是一个关于 ϵ 的函数,我们将随机性转移到了 ϵ ,那么问题就可以简化为:
z∼qϕ(z∣x(i))⟶ϵ∼p(ϵ)
而且,这里还需要引入一个等式,那就是:
∣qϕ(z∣x(i))dz∣=∣p(ϵ)dϵ∣
为什么呢?我们直观性的理解一下,∫qϕ(z∣x(i))dz=∫p(ϵ)dϵ=1,并且qϕ(z∣x(i))和p(ϵ)之间存在一个变换关系。
那么,我们将改写∇ϕL(ϕ):
∇ϕL(ϕ)========∇ϕEqϕ[logpθ(x(i),z)−logqϕ]∇ϕ∫[logpθ(x(i),z)−logqϕ]qϕdz∇ϕ∫[logpθ(x(i),z)−logqϕ]p(ϵ)dϵ∇ϕEp(ϵ)[logpθ(x(i),z)−logqϕ]Ep(ϵ)∇ϕ[(logpθ(x(i),z)−logqϕ)]Ep(ϵ)∇z[(logpθ(x(i),z)−logqϕ(z∣x(i)))∇ϕz]Ep(ϵ)∇z[(logpθ(x(i),z)−logqϕ(z∣x(i)))∇ϕz]Ep(ϵ)∇z[(logpθ(x(i),z)−logqϕ(z∣x(i)))∇ϕgϕ(ϵ,x(i))]
那么我们的问题就这样愉快的解决了,p(ϵ)的采样与ϕ无关,然后对先求关于z的梯度,然后再求关于ϕ的梯度,那么这三者之间就互相隔离开了。最后,我们再对结果进行采样,ϵ(l)∼p(ϵ),l=1,2,⋯,L:
∇ϕL(ϕ)≈L1i=1∑L∇z[(logpθ(x(i),z)−logqϕ(z∣x(i)))∇ϕgϕ(ϵ,x(i))]
其中z⟵gϕ(ϵ(i),x(i))。而SGVI为:
ϕ(t+1)⟶ϕ(t)+λ(t)∇ϕL(ϕ)
小结
那么SGVI,可以简要的表述为:我们定义分布为qϕ(Z∣X),ϕ为参数,参数的更新方法为:
ϕ(t+1)⟶ϕ(t)+λ(t)∇ϕL(ϕ)
∇ϕL(ϕ)为:
∇ϕL(ϕ)≈L1i=1∑L∇z[logpθ(x(i),z)−logqϕ(z∣x(i)))∇ϕgϕ(ϵ,x(i))]
本文由mdnice多平台发布