1. Overview
When working with probabilistic models, we can look at a problem from two different perspectives: the frequentist view and the Bayesian view. From the frequentist view, the problem is treated as an optimization problem, under the assumption that the best model parameters are fixed constants. Recall linear regression (where least squares defines the loss function), support vector machines (which reduce to a constrained optimization problem), and the EM algorithm (which solves for the model parameters iteratively). These methods share one trait: they all search the parameter space for an optimal parameter, so each ultimately becomes an optimization problem.
From the Bayesian view, however, the problem becomes an integration problem. The model parameters are no longer fixed constants but follow a distribution. Given a set of observed samples $X$, we want to evaluate a new sample $\hat{x}$, which requires:
$$
p(\hat{x}\mid X)=\int_{\theta}p(\hat{x},\theta\mid X)\,\mathrm{d}\theta=\int_{\theta}p(\hat{x}\mid\theta,X)\,p(\theta\mid X)\,\mathrm{d}\theta\overset{\hat{x}\,\perp\,X\mid\theta}{=}\int_{\theta}p(\hat{x}\mid\theta)\,p(\theta\mid X)\,\mathrm{d}\theta=E_{\theta\mid X}[p(\hat{x}\mid\theta)]
$$
If the new sample is independent of the dataset given the parameters, this inference problem amounts to taking the expectation of the likelihood under the posterior distribution of the parameters. The core of the inference problem is therefore solving for the parameter posterior. Inference methods divide into:
Exact inference
Approximate inference (when the posterior cannot be solved exactly):
① deterministic approximation, e.g. variational inference
② stochastic approximation, e.g. MCMC: Metropolis–Hastings (MH), Gibbs sampling
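When the posterior happens to be tractable, the predictive expectation above can be checked directly by Monte Carlo. A minimal sketch, assuming a conjugate Beta-Bernoulli model (the prior hyperparameters and counts below are arbitrary illustrative choices):

```python
import numpy as np

# Predictive p(x_hat | X) = E_{theta|X}[ p(x_hat | theta) ] on a toy
# Beta-Bernoulli model, where the posterior is available in closed form:
# prior Beta(a, b) + k successes out of n  =>  posterior Beta(a+k, b+n-k),
# and the exact predictive is p(x_hat = 1 | X) = (a+k) / (a+b+n).
rng = np.random.default_rng(0)
a, b, k, n = 1.0, 1.0, 7, 10

theta = rng.beta(a + k, b + n - k, size=1_000_000)  # theta ~ p(theta | X)
mc_pred = np.mean(theta)        # since p(x_hat=1 | theta) = theta, average it
exact = (a + k) / (a + b + n)
print(mc_pred, exact)           # both ≈ 0.6667
```

Averaging $p(\hat{x}\mid\theta)$ over posterior samples is exactly the expectation $E_{\theta\mid X}[p(\hat{x}\mid\theta)]$ from the derivation above.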
2. Deriving the formulas
We have the following data:
$X$: observed variable
$Z$: latent variable $+$ parameters
$(X, Z)$: complete data
We write $Z$ for the collection of latent variables and parameters (note this differs from before: here $Z$ is latent variables plus parameters). We then rewrite the probability $p(X)$ and introduce a distribution $q(Z)$; here $X$ refers to a single sample:
$$
\log p(X)=\log p(X,Z)-\log p(Z\mid X)=\log\frac{p(X,Z)}{q(Z)}-\log\frac{p(Z\mid X)}{q(Z)}
$$
Integrating both sides against $q(Z)$:
$$
\text{LHS}=\int_{Z}q(Z)\,\log p(X)\,\mathrm{d}Z=\log p(X)\int_{Z}q(Z)\,\mathrm{d}Z=\log p(X)
$$

$$
\text{RHS}=\underbrace{\int_{Z}q(Z)\log\frac{p(X,Z)}{q(Z)}\,\mathrm{d}Z}_{\mathrm{ELBO}\ (\text{evidence lower bound})}\ \underbrace{-\int_{Z}q(Z)\log\frac{p(Z\mid X)}{q(Z)}\,\mathrm{d}Z}_{KL(q(Z)\,\|\,p(Z\mid X))}=\underbrace{L(q)}_{\text{variational}}+\underbrace{KL(q\,\|\,p)}_{\geq 0}
$$
The distribution $q$ is meant to approximate the posterior $p$. Our goal is to find a $q$ as close to $p$ as possible, i.e. to make $KL(q\,\|\,p)$ as small as possible, which is equivalent to making $L(q)$ as large as possible (note that $q(Z)$ really means $q(Z\mid X)$; we simply abbreviate it as $q(Z)$):
$$
\tilde{q}(Z)=\underset{q(Z)}{\operatorname{argmax}}\,L(q)\ \Rightarrow\ \tilde{q}(Z)\approx p(Z\mid X)
$$
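The decomposition $\log p(X)=L(q)+KL(q\,\|\,p)$ can be verified numerically on a toy discrete model. A minimal sketch (the joint table below is an arbitrary assumption):

```python
import numpy as np

# Numeric check of  log p(X) = ELBO + KL(q || p(Z|X))  on a tiny discrete model.
# p_joint[z] holds p(X = x_obs, Z = z) for an assumed 3-state latent Z.
p_joint = np.array([0.10, 0.25, 0.05])       # p(x_obs, z) for z = 0, 1, 2
log_px = np.log(p_joint.sum())               # log evidence  log p(x_obs)
post = p_joint / p_joint.sum()               # true posterior p(z | x_obs)

q = np.array([0.5, 0.3, 0.2])                # any valid q(z)
elbo = np.sum(q * np.log(p_joint / q))       # L(q) = E_q[log p(x,z) - log q(z)]
kl = np.sum(q * np.log(q / post))            # KL(q || p(z|x)) >= 0
print(elbo + kl, log_px)                     # identical up to rounding
```

Because $KL\geq 0$, the ELBO alone is always a lower bound on $\log p(X)$, and it touches the bound exactly when $q$ equals the posterior.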
$Z$ is a high-dimensional random variable. In variational inference we make the following assumption on $q(Z)$ (mean-field variational inference): we split the dimensions of the multivariate variable into $M$ groups, with the groups mutually independent:
$$
q(Z)=\prod_{i=1}^{M}q_{i}(Z_{i})
$$
To solve, we fix $q_{i}(Z_{i}),\ i\neq j$ and solve for $q_{j}(Z_{j})$. We first split $L(q)$ into two parts:
$$
L(q)=\underbrace{\int_{Z}q(Z)\log p(X,Z)\,\mathrm{d}Z}_{①}-\underbrace{\int_{Z}q(Z)\log q(Z)\,\mathrm{d}Z}_{②}
$$
For term ①:
$$
\begin{aligned}
①&=\int_{Z}\prod_{i=1}^{M}q_{i}(Z_{i})\log p(X,Z)\,\mathrm{d}Z_{1}\mathrm{d}Z_{2}\cdots\mathrm{d}Z_{M}\\
&=\int_{Z_{j}}q_{j}(Z_{j})\left(\int_{Z-Z_{j}}\prod_{i\neq j}^{M}q_{i}(Z_{i})\log p(X,Z)\prod_{i\neq j}\mathrm{d}Z_{i}\right)\mathrm{d}Z_{j}\\
&=\int_{Z_{j}}q_{j}(Z_{j})\,E_{\prod_{i\neq j}^{M}q_{i}(Z_{i})}[\log p(X,Z)]\,\mathrm{d}Z_{j}
\end{aligned}
$$
For term ②:
$$
\begin{aligned}
②&=\int_{Z}q(Z)\log q(Z)\,\mathrm{d}Z\\
&=\int_{Z}\prod_{i=1}^{M}q_{i}(Z_{i})\sum_{i=1}^{M}\log q_{i}(Z_{i})\,\mathrm{d}Z\\
&=\int_{Z}\prod_{i=1}^{M}q_{i}(Z_{i})\left[\log q_{1}(Z_{1})+\log q_{2}(Z_{2})+\cdots+\log q_{M}(Z_{M})\right]\mathrm{d}Z
\end{aligned}
$$

For the first of these summands:

$$
\begin{aligned}
\int_{Z}\prod_{i=1}^{M}q_{i}(Z_{i})\log q_{1}(Z_{1})\,\mathrm{d}Z
&=\int_{Z_{1}\cdots Z_{M}}q_{1}(Z_{1})q_{2}(Z_{2})\cdots q_{M}(Z_{M})\,\log q_{1}(Z_{1})\,\mathrm{d}Z_{1}\mathrm{d}Z_{2}\cdots\mathrm{d}Z_{M}\\
&=\int_{Z_{1}}q_{1}(Z_{1})\log q_{1}(Z_{1})\,\mathrm{d}Z_{1}\cdot\underbrace{\int_{Z_{2}}q_{2}(Z_{2})\,\mathrm{d}Z_{2}}_{=1}\cdot\underbrace{\int_{Z_{3}}q_{3}(Z_{3})\,\mathrm{d}Z_{3}}_{=1}\cdots\underbrace{\int_{Z_{M}}q_{M}(Z_{M})\,\mathrm{d}Z_{M}}_{=1}\\
&=\int_{Z_{1}}q_{1}(Z_{1})\log q_{1}(Z_{1})\,\mathrm{d}Z_{1}
\end{aligned}
$$

In general, $\int_{Z}\prod_{i=1}^{M}q_{i}(Z_{i})\log q_{k}(Z_{k})\,\mathrm{d}Z=\int_{Z_{k}}q_{k}(Z_{k})\log q_{k}(Z_{k})\,\mathrm{d}Z_{k}$, so

$$
②=\sum_{i=1}^{M}\int_{Z_{i}}q_{i}(Z_{i})\log q_{i}(Z_{i})\,\mathrm{d}Z_{i}=\int_{Z_{j}}q_{j}(Z_{j})\log q_{j}(Z_{j})\,\mathrm{d}Z_{j}+C
$$

where $C$ collects the $i\neq j$ terms, which are constant while $q_{j}$ is being optimized.
We can now combine ① − ②:
Writing $E_{\prod_{i\neq j}^{M}q_{i}(Z_{i})}[\log p(X,Z)]$ as $\log\hat{p}(X,Z_{j})$, term ① becomes

$$
①=\int_{Z_{j}}q_{j}(Z_{j})\,\log\hat{p}(X,Z_{j})\,\mathrm{d}Z_{j}
$$

so that

$$
①-②=\int_{Z_{j}}q_{j}(Z_{j})\,\log\frac{\hat{p}(X,Z_{j})}{q_{j}(Z_{j})}\,\mathrm{d}Z_{j}+C,\qquad
\int_{Z_{j}}q_{j}(Z_{j})\,\log\frac{\hat{p}(X,Z_{j})}{q_{j}(Z_{j})}\,\mathrm{d}Z_{j}=-KL\big(q_{j}(Z_{j})\,\|\,\hat{p}(X,Z_{j})\big)\leq 0
$$
The maximum is attained when $q_{j}(Z_{j})=\hat{p}(X,Z_{j})$.
3. Revisiting the EM algorithm
Recall that in the generalized EM algorithm we fix $\theta$ and then solve for the $q$ closest to $p$; variational inference applies here. We have:
$$
\log p_{\theta}(X)=\underbrace{\mathrm{ELBO}}_{L(q)}+\underbrace{KL(q\,\|\,p)}_{\geq 0}\ \geq\ L(q)
$$
Then solve for $q$:
$$
\hat{q}=\underset{q}{\operatorname{argmin}}\,KL(q\,\|\,p)=\underset{q}{\operatorname{argmax}}\,L(q)
$$
Applying a mean-field-style variational inference yields the result below. Here $Z_{i}$ denotes not the $i$-th dimension of $Z$ but one of several groups of mutually independent variables.
$$
\begin{aligned}
\log q_{j}(Z_{j})&=E_{\prod_{i\neq j}^{M}q_{i}(Z_{i})}[\log p_{\theta}(X,Z)]\\
&=\int_{Z_{1}}\int_{Z_{2}}\cdots\int_{Z_{j-1}}\int_{Z_{j+1}}\cdots\int_{Z_{M}}q_{1}q_{2}\cdots q_{j-1}q_{j+1}\cdots q_{M}\,\log p_{\theta}(X,Z)\,\mathrm{d}Z_{1}\mathrm{d}Z_{2}\cdots\mathrm{d}Z_{j-1}\mathrm{d}Z_{j+1}\cdots\mathrm{d}Z_{M}
\end{aligned}
$$
One round of updates proceeds as follows:
$$
\begin{aligned}
\log\hat{q}_{1}(Z_{1})&=\int_{Z_{2}}\cdots\int_{Z_{M}}q_{2}\cdots q_{M}\,\log p_{\theta}(X,Z)\,\mathrm{d}Z_{2}\cdots\mathrm{d}Z_{M}\\
\log\hat{q}_{2}(Z_{2})&=\int_{Z_{1}}\int_{Z_{3}}\cdots\int_{Z_{M}}\hat{q}_{1}q_{3}\cdots q_{M}\,\log p_{\theta}(X,Z)\,\mathrm{d}Z_{1}\mathrm{d}Z_{3}\cdots\mathrm{d}Z_{M}\\
&\ \ \vdots\\
\log\hat{q}_{M}(Z_{M})&=\int_{Z_{1}}\cdots\int_{Z_{M-1}}\hat{q}_{1}\cdots\hat{q}_{M-1}\,\log p_{\theta}(X,Z)\,\mathrm{d}Z_{1}\cdots\mathrm{d}Z_{M-1}
\end{aligned}
$$
We see that each $q_{j}(Z_{j})$ is obtained by fixing the remaining $q_{i}(Z_{i})$ and solving for it, so the updates can be iterated by coordinate ascent. The derivation above is for a single sample, but it applies to a dataset as well.
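The coordinate ascent above can be sketched on the classic toy case of a factorized Gaussian approximating a correlated 2-D Gaussian, where the mean-field updates have a known closed form (the target mean and precision matrix below are arbitrary assumptions):

```python
import numpy as np

# Target: a correlated 2-D Gaussian p(z) = N(mu, Lambda^{-1}) (precision Lambda).
# Mean-field family: q(z) = q1(z1) * q2(z2), each factor Gaussian.
# The closed-form coordinate ascent updates for the factor means are:
#   m1 <- mu1 - Lambda11^{-1} * Lambda12 * (m2 - mu2)
#   m2 <- mu2 - Lambda22^{-1} * Lambda21 * (m1 - mu1)
mu = np.array([1.0, -1.0])
Lam = np.array([[2.0, 0.9],
                [0.9, 1.5]])   # precision matrix (positive definite)

m = np.zeros(2)                # initial means of q1, q2
for _ in range(50):            # coordinate ascent: update one factor at a time
    m[0] = mu[0] - Lam[0, 1] * (m[1] - mu[1]) / Lam[0, 0]
    m[1] = mu[1] - Lam[1, 0] * (m[0] - mu[0]) / Lam[1, 1]

print(m)  # converges to the true posterior mean [1, -1]
```

Each update holds the other factor fixed, exactly as in the $\hat{q}_{1},\ldots,\hat{q}_{M}$ sweep above; the iteration converges to the exact means, though the factorized $q$ still underestimates the joint covariance.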
Note that in variational inference the parameter $\theta$ is a random variable, so $Z$ includes both the latent variables and the parameters $\theta$. In the generalized EM algorithm, by contrast, $\theta$ is assumed to have an optimal constant value; although the mean-field idea is also applied there, $Z$ includes only the latent variables, and $\theta$ is held fixed in this step, which corresponds to the E-step of generalized EM.
Mean-field variational inference has some drawbacks:
(1) the assumption is strong and breaks down in highly complex situations;
(2) the multiple integrals inside the expectation are computationally heavy and may be intractable.
4. Stochastic gradient variational inference (SGVI)
The direct-gradient approach
The process from $Z$ to $X$ is called the generative or decoding process, corresponding to a decoder (generating the visible $X$ from the invisible $Z$). The process from $X$ to $Z$ is called the inference or encoding process, corresponding to an encoder (inferring the invisible $Z$ from the visible $X$). Mean-field variational inference leads to a coordinate ascent algorithm, but its assumption can be too strong and the integrals may be intractable. Besides coordinate ascent, another optimization method is gradient ascent, from which we would like to derive an alternative variational inference algorithm.
First assume $q(Z)=q_{\phi}(Z)$, a distribution governed by the parameter $\phi$. Then:
$$
\underset{q(Z)}{\operatorname{argmax}}\,L(q)=\underset{\phi}{\operatorname{argmax}}\,L(\phi)
$$
where $L(\phi)=E_{q_{\phi}}[\log p_{\theta}(X,Z)-\log q_{\phi}(Z)]$, and $X$ here denotes a single sample.
Next we take the gradient $\nabla_{\phi}$ with respect to $\phi$:
$$
\begin{aligned}
\nabla_{\phi}L(\phi)&=\nabla_{\phi}E_{q_{\phi}}[\log p_{\theta}(X,Z)-\log q_{\phi}(Z)]\\
&=\nabla_{\phi}\int q_{\phi}(Z)\,[\log p_{\theta}(X,Z)-\log q_{\phi}(Z)]\,\mathrm{d}Z\\
&=\underbrace{\int \nabla_{\phi}q_{\phi}(Z)\,[\log p_{\theta}(X,Z)-\log q_{\phi}(Z)]\,\mathrm{d}Z}_{①}+\underbrace{\int q_{\phi}(Z)\,\nabla_{\phi}[\log p_{\theta}(X,Z)-\log q_{\phi}(Z)]\,\mathrm{d}Z}_{②}
\end{aligned}
$$

Term ② vanishes:

$$
\begin{aligned}
②&=\int q_{\phi}(Z)\,\nabla_{\phi}\big[\underbrace{\log p_{\theta}(X,Z)}_{\text{independent of }\phi}-\log q_{\phi}(Z)\big]\,\mathrm{d}Z\\
&=-\int q_{\phi}(Z)\,\nabla_{\phi}\log q_{\phi}(Z)\,\mathrm{d}Z
=-\int q_{\phi}(Z)\,\frac{1}{q_{\phi}(Z)}\,\nabla_{\phi}q_{\phi}(Z)\,\mathrm{d}Z\\
&=-\int \nabla_{\phi}q_{\phi}(Z)\,\mathrm{d}Z
=-\nabla_{\phi}\int q_{\phi}(Z)\,\mathrm{d}Z=-\nabla_{\phi}1=0
\end{aligned}
$$

Therefore, using ${\color{Red}{\nabla_{\phi}q_{\phi}(Z)=q_{\phi}(Z)\,\nabla_{\phi}\log q_{\phi}(Z)}}$:

$$
\begin{aligned}
\nabla_{\phi}L(\phi)=①&=\int {\color{Red}{\nabla_{\phi}q_{\phi}(Z)}}\,[\log p_{\theta}(X,Z)-\log q_{\phi}(Z)]\,\mathrm{d}Z\\
&=\int {\color{Red}{q_{\phi}(Z)\,\nabla_{\phi}\log q_{\phi}(Z)}}\,[\log p_{\theta}(X,Z)-\log q_{\phi}(Z)]\,\mathrm{d}Z\\
&=E_{q_{\phi}}\big[\big(\nabla_{\phi}\log q_{\phi}(Z)\big)\big(\log p_{\theta}(X,Z)-\log q_{\phi}(Z)\big)\big]
\end{aligned}
$$
This expectation can be approximated by Monte Carlo sampling to obtain the gradient, and the parameters are then updated by gradient ascent:
$$
Z^{(l)}\sim q_{\phi}(Z),\qquad
E_{q_{\phi}}\big[\big(\nabla_{\phi}\log q_{\phi}(Z)\big)\big(\log p_{\theta}(X,Z)-\log q_{\phi}(Z)\big)\big]\approx\frac{1}{L}\sum_{l=1}^{L}\big(\nabla_{\phi}\log q_{\phi}(Z^{(l)})\big)\big(\log p_{\theta}(X,Z^{(l)})-\log q_{\phi}(Z^{(l)})\big)
$$
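The estimator above has the score-function (REINFORCE) form $\frac{1}{L}\sum_{l}\nabla_{\phi}\log q_{\phi}(Z^{(l)})\,f(Z^{(l)})$. A minimal sketch with an assumed toy integrand $f(z)=z^{2}$ standing in for $\log p_{\theta}-\log q_{\phi}$, chosen so the exact gradient is known:

```python
import numpy as np

# Score-function (REINFORCE) estimate of  d/dphi E_{z ~ N(phi, 1)}[f(z)]
# with the toy choice f(z) = z^2, so exactly E[f] = phi^2 + 1 and grad = 2*phi.
# Estimator: (1/L) * sum_l  grad_phi log q_phi(z_l) * f(z_l),
# where for q_phi = N(phi, 1) the score is  grad_phi log q_phi(z) = z - phi.
rng = np.random.default_rng(0)
phi, L = 1.5, 2_000_000
z = rng.normal(phi, 1.0, size=L)       # z^(l) ~ q_phi
grad_est = np.mean((z - phi) * z**2)   # score-function estimator
print(grad_est)                        # ≈ 2 * phi = 3.0 (noisy)
```

Even on this easy one-dimensional problem the estimate needs millions of samples to be accurate, since its variance shrinks only like $1/L$.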
There is a problem, however: the summand contains log terms such as $\log p_{\theta}$, so if direct sampling draws points where $q_{\phi}(Z)$ is close to $0$, the log values become extremely unstable. In other words, the variance of direct sampling is very large and a huge number of samples is needed. Moreover, if the estimated gradient already carries a large error, the resulting $\hat{\phi}$ (the parameter of $q(Z)$) will also carry a large error, and the error propagates layer by layer, so the final result can be poor. To address the high variance, a trick known as the reparameterization trick is used.
The reparameterization trick
We define $Z=g_{\phi}(\varepsilon,X)$ with $\varepsilon\sim p(\varepsilon)$; for $Z\sim q_{\phi}(Z\mid X)$ we then have $\left|q_{\phi}(Z\mid X)\,\mathrm{d}Z\right|=\left|p(\varepsilon)\,\mathrm{d}\varepsilon\right|$. The point is to shift the randomness of $Z$ onto $\varepsilon$, so that the gradient operator can be moved inside the expectation:
$$
\begin{aligned}
\nabla_{\phi}L(\phi)&=\nabla_{\phi}E_{q_{\phi}}[\log p_{\theta}(X,Z)-\log q_{\phi}(Z)]\\
&=\nabla_{\phi}\int q_{\phi}(Z)\,[\log p_{\theta}(X,Z)-\log q_{\phi}(Z)]\,\mathrm{d}Z\\
&=\nabla_{\phi}\int [\log p_{\theta}(X,Z)-\log q_{\phi}(Z)]\,{\color{Red}{q_{\phi}(Z)\,\mathrm{d}Z}}\\
&=\nabla_{\phi}\int [\log p_{\theta}(X,Z)-\log q_{\phi}(Z)]\,{\color{Red}{p(\varepsilon)\,\mathrm{d}\varepsilon}}\\
&=\nabla_{\phi}E_{p(\varepsilon)}[\log p_{\theta}(X,Z)-\log q_{\phi}(Z)]\\
&=E_{p(\varepsilon)}[\nabla_{\phi}(\log p_{\theta}(X,Z)-\log q_{\phi}(Z))]\\
&=E_{p(\varepsilon)}[\nabla_{Z}(\log p_{\theta}(X,Z)-\log q_{\phi}(Z))\,\nabla_{\phi}Z]\\
&=E_{p(\varepsilon)}\big[\nabla_{Z}\big(\log p_{\theta}(X^{(i)},Z)-\log q_{\phi}(Z\mid X^{(i)})\big)\,\nabla_{\phi}g_{\phi}(\varepsilon^{(l)},X^{(i)})\big]
\end{aligned}
$$
The second-to-last step uses the chain rule:

$$
\frac{\partial f}{\partial \phi}=\frac{\partial f}{\partial z}\cdot\frac{\partial z}{\partial \phi},\qquad z=g(\phi)
$$
In the last step every $Z$ can be read as $g_{\phi}(\varepsilon^{(l)},X^{(i)})$, $l=1,2,\ldots,L$, where $X^{(i)}$ is the $i$-th sample; the full expression is only written out in this final step.
Monte Carlo sampling of the expression inside the final brackets, followed by averaging, yields the gradient. The sampling is now done from $p(\varepsilon)$.
The SGVI update is:
$$
\phi^{t+1}\leftarrow\phi^{t}+\lambda^{t}\,\nabla_{\phi}L(\phi)
$$
This is standard gradient ascent; Monte Carlo sampling methods will be introduced in a later article.
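Putting the reparameterization trick and the gradient ascent update together, here is a minimal SGVI sketch on an assumed conjugate toy model (chosen so the exact posterior is known and the result can be checked; all constants are illustrative):

```python
import numpy as np

# SGVI with the reparameterization trick on a toy conjugate model:
#   prior  Z ~ N(0, 1),  likelihood  X | Z ~ N(Z, 1),  one observation x.
# The exact posterior is N(x/2, 1/2), so we can verify the answer.
# Variational family q_phi = N(mu, sigma^2); reparameterize Z = mu + sigma*eps.
# Here grad_Z log p(x, Z) = -Z + (x - Z) = x - 2Z, and the Gaussian entropy
# term contributes 1/sigma to the sigma gradient.
rng = np.random.default_rng(0)
x, mu, sigma, lr, L = 2.0, 0.0, 1.0, 0.05, 200

for t in range(2000):
    eps = rng.normal(size=L)                        # eps ~ p(eps) = N(0, 1)
    z = mu + sigma * eps                            # Z = g_phi(eps)
    dlogp_dz = x - 2.0 * z                          # grad_Z log p(x, Z)
    grad_mu = np.mean(dlogp_dz)                     # pathwise: dZ/dmu = 1
    grad_sigma = np.mean(dlogp_dz * eps) + 1.0 / sigma  # dZ/dsigma = eps
    mu += lr * grad_mu                              # gradient *ascent* on ELBO
    sigma = max(sigma + lr * grad_sigma, 1e-3)
    lr *= 0.999                                     # decaying step size lambda^t

print(mu, sigma)  # ≈ x/2 = 1.0 and sqrt(1/2) ≈ 0.707
```

Because the randomness lives in $\varepsilon$ rather than $Z$, the same draws can be differentiated through $g_{\phi}$, which is what makes the estimator low-variance compared with the score-function version.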
Summary
The EM algorithm solves parameter estimation in the presence of latent variables (it is an optimization method), whereas VI solves the posterior inference problem, producing a probability distribution. SGVI builds on VI: by assuming a parametric form for the distribution, it turns distribution estimation into parameter estimation.