11.Expectation_Maximization(EM算法): 解决含有隐变量的统计问题

568 阅读5分钟

一、概述

  1. 介绍

概率模型有时既包含观测变量(observed variable),又包含隐变量(latent variable)。当概率模型只包含观测变量时,那么给定观测数据,就可以直接使用极大似然估计法或者贝叶斯估计法进行模型参数的求解。然而如果模型包含隐变量,就不能直接使用这些简单的方法了。EM算法就是用来解决这种含有隐变量的概率模型参数的极大似然参数估计法。这里只讨论极大似然估计,极大后验估计与其类似。

  1. 算法

EM算法的输入如下:

XX :观测数据

ZZ : 末观测数据 (隐变量)

p(x,zθ)p(x, z \mid \theta) : 联合分布

p(zx,θ)p(z \mid x, \theta) :后验分布

θ\theta :parameter

在算法运行开始时需要选择模型的初始化参数 θ(0)\theta^{(0)} 。EM算法是一种迭代更新的算法,其计算公式为:

θt+1=argmaxEzx,θt[logp(x,zθ)]θ=argmaxθzlogp(x,zθ)p(zx,θt)dz\begin{gathered} \theta^{t+1}=\underset{\theta}{\operatorname{argmax} E_{z \mid x, \theta^t}[\log p(x, z \mid \theta)]} \\ =\underset{\theta}{\operatorname{argmax}} \int_z \log p(x, z \mid \theta) \cdot p\left(z \mid x, \theta^t\right) \mathrm{d} z \end{gathered}

这个公式包含了迭代的两步:

  • ①E step: 计算 p(x,zθ)p(x, z \mid \theta) 在概率分布 p(zx,θt)p\left(z \mid x, \theta^t\right) 下的期望;

  • ②M step: 计算使这个期望最大化的参数得到下一个EM步骤的输入。

总结来说,EM算法包含以下步骤:

  • ①选择初始化参数θ(0)\theta ^{(0)}
  • ②E step;
  • ③M step;
  • ④重复②③步直至收敛。

二、EM算法的收敛性

现在要证明迭代求得的 θt\theta^t 序列会使得对应的 p(xθt)p\left(x \mid \theta^t\right) 是单调递增的 (如果 p(xθt)p\left(x \mid \theta^t\right) 是单调递 增的,那么训练数据的似然就是单调递增的),也就是说要证明 p(xθt)p(xθt+1)p\left(x \mid \theta^t\right) \leq p\left(x \mid \theta^{t+1}\right) 。首先我们有:

logp(xθ)=logp(x,zθ)logp(zx,θ)\log p(x \mid \theta)=\log p(x, z \mid \theta)-\log p(z \mid x, \theta)

接下来等式两边同时求关于 p(zx,θt)p\left(z \mid x, \theta^t\right) 的期望:

 左边 =zp(zx,θt)logp(xθ)dz=logp(xθ)zp(zx,θt)dz=logp(xθ) 右边 =zp(zx,θt)logp(x,zθ)dz记作 Q(θ,θt)zp(zx,θt)logp(zx,θ)dz记作 H(θ,θt)\begin{gathered} \text { 左边 }=\int_z p\left(z \mid x, \theta^t\right) \cdot \log p(x \mid \theta) \mathrm{d} z \\ =\log p(x \mid \theta) \int_z p\left(z \mid x, \theta^t\right) \mathrm{d} z \\ =\log p(x \mid \theta) \\ \text { 右边 }=\underbrace{\int_z p\left(z \mid x, \theta^t\right) \cdot \log p(x, z \mid \theta) \mathrm{d} z}_{\text {记作 } Q\left(\theta, \theta^t\right)}-\underbrace{\int_z p\left(z \mid x, \theta^t\right) \cdot \log p(z \mid x, \theta) \mathrm{d} z}_{\text {记作 } H\left(\theta, \theta^t\right)} \end{gathered}

因此有:

logp(xθ)=zp(zx,θt)p(x,zθ)dzzp(zx,θt)logp(zx,θ)dz\log p(x \mid \theta)=\int_z p\left(z \mid x, \theta^t\right) \cdot p(x, z \mid \theta) \mathrm{d} z-\int_z p\left(z \mid x, \theta^t\right) \cdot \log p(z \mid x, \theta) \mathrm{d} z

这里定义了 Q(θ,θt)Q\left(\theta, \theta^t\right) ,称为 Q\mathrm{Q} 函数 ( Q\mathrm{Q} function),这个函数也就是上面的概述中迭代公式里 用到的函数,因此满足 Q(θt+1,θt)Q(θt,θt)Q\left(\theta^{t+1}, \theta^t\right) \geq Q\left(\theta^t, \theta^t\right)

接下来将上面的等式两边 θ\theta 分别取 θt+1\theta^{t+1}θt\theta^t 并相减:

logp(xθt+1)logp(xθt)=[Q(θt+1,θt)Q(θt,θt)][H(θt+1,θt)H(θt,θt)]\log p\left(x \mid \theta^{t+1}\right)-\log p\left(x \mid \theta^t\right)=\left[Q\left(\theta^{t+1}, \theta^t\right)-Q\left(\theta^t, \theta^t\right)\right]-\left[H\left(\theta^{t+1}, \theta^t\right)-H\left(\theta^t, \theta^t\right)\right]

我们需要证明 logp(xθt+1)logp(xθt)0\log p\left(x \mid \theta^{t+1}\right)-\log p\left(x \mid \theta^t\right) \geq 0 ,同时已知Q(θt+1,θt)Q(θt,θt)0Q\left(\theta^{t+1}, \theta^t\right)-Q\left(\theta^t, \theta^t\right) \geq 0,现在来观察H(θt+1,θt)H(θt,θt) : H\left(\theta^{t+1}, \theta^t\right)-H\left(\theta^t, \theta^t\right) \text { : }

H(θt+1,θt)H(θt,θt)=zp(zx,θt)log  p(zx,θt+1)dzzp(zx,θt)log  p(zx,θt)dz=zp(zx,θt)logp(zx,θt+1)p(zx,θt)dzlogzp(zx,θt)p(zx,θt+1)p(zx,θt)dz=logzp(zx,θt+1)dz=log  1=0H(\theta ^{t+1},\theta ^{t})-H(\theta ^{t},\theta ^{t})\\ =\int _{z}p(z|x,\theta ^{t})\cdot log\; p(z|x,\theta ^{t+1})\mathrm{d}z-\int _{z}p(z|x,\theta ^{t})\cdot log\; p(z|x,\theta ^{t})\mathrm{d}z\\ =\int _{z}p(z|x,\theta ^{t})\cdot log\frac{p(z|x,\theta ^{t+1})}{p(z|x,\theta ^{t})}\mathrm{d}z\\ \leq log\int _{z}p(z|x,\theta ^{t})\frac{p(z|x,\theta ^{t+1})}{p(z|x,\theta ^{t})}\mathrm{d}z\\ =log\int _{z}p(z|x,\theta ^{t+1})\mathrm{d}z\\ =log\; 1\\ =0

这里的不等号应用了Jensen不等式:

logjλjyjjλjlog  yj,其中λj0jλj=1log\sum _{j}\lambda _{j}y_{j}\geq \sum _{j}\lambda _{j}log\; y_{j},其中\lambda _{j}\geq 0,\sum _{j}\lambda _{j}=1

也可以使用KL散度来证明 zp(zx,θt)logp(zx,θt+1)p(zx,θt)dz0\int_z p\left(z \mid x, \theta^t\right) \cdot \log \frac{p\left(z \mid x, \theta^{t+1}\right)}{p\left(z \mid x, \theta^t\right)} \mathrm{d} z \leq 0 ,两个概率分布 P(x)P(x)Q(x)Q(x) 的KL散度是恒 0\geq 0 的,定义为:

DKL(PQ)=ExP[logP(x)Q(x)]D_{K L}(P \| Q)=E_{x \sim P}\left[\log \frac{P(x)}{Q(x)}\right]

因此有:

zp(zx,θt)logp(zx,θt+1)p(zx,θt)dz=KL(p(zx,θt)p(zx,θt+1))0\int_z p\left(z \mid x, \theta^t\right) \cdot \log \frac{p\left(z \mid x, \theta^{t+1}\right)}{p\left(z \mid x, \theta^t\right)} \mathrm{d} z=-K L\left(p\left(z \mid x, \theta^t\right)|| p\left(z \mid x, \theta^{t+1}\right)\right) \leq 0

因此得证 logp(xθt+1)logp(xθt)0\log p\left(x \mid \theta^{t+1}\right)-\log p\left(x \mid \theta^t\right) \geq 0 。这说明使用EM算法迭代更新参数可以使得 logp(xθ)\log p(x \mid \theta) 逐步增大。

另外还有其他定理保证了EM的算法收敛性。首先对于 θi(i=1,2,)\theta^i(i=1,2, \cdots) 序列和其对应的对数似然序列 L(θt)=logp(xθt)(t=1,2,)L\left(\theta^t\right)=\log p\left(x \mid \theta^t\right)(t=1,2, \cdots) 有如下定理:

  • ①如果 p(xθ)p(x \mid \theta) 有上界,则 L(θt)=logp(xθt)L\left(\theta^t\right)=\log p\left(x \mid \theta^t\right) 收敛到某一值 LL^*

  • ②在函数 Q(θ,θ)Q\left(\theta, \theta^{\prime}\right)L(θ)L(\theta) 满足一定条件下,由EM算法得到的参数估计序列 θt\theta^t 的收敛值 θ\theta^*L(θ)L(\theta) 的稳定点。

三、EM算法的导出

  1. ELBO+KL散度的方法

对于前面用过的式子,首先引入一个新的概率分布q(z)q(z)

log  p(xθ)=log  p(x,zθ)log  p(zx,θ)=log  p(x,zθ)q(z)log  p(zx,θ)q(z)    q(z)0log\; p(x|\theta )=log\; p(x,z|\theta )-log\; p(z|x,\theta )\\ =log\; \frac {p(x,z|\theta )}{q(z)}-log\; \frac{p(z|x,\theta )}{q(z)}\; \; q(z)\neq 0

以上引入一个关于zz的概率分布q(z)q(z),然后式子两边同时求对q(z)q(z)的期望:

左边=zq(z)log  p(xθ)dz=log  p(xθ)zq(z)dz=log  p(xθ)右边=zq(z)log  p(x,zθ)q(z)dzELBO(evidence  lower  bound)zq(z)log  p(zx,θ)q(z)dzKL(q(z)p(zx,θ))左边=\int _{z}q(z)\cdot log\; p(x|\theta )\mathrm{d}z=log\; p(x|\theta )\int _{z}q(z)\mathrm{d}z=log\; p(x|\theta )\\ 右边=\underset{ELBO(evidence\; lower\; bound)}{\underbrace{\int _{z}q(z)log\; \frac{p(x,z|\theta )}{q(z)}\mathrm{d}z}}\underset{KL(q(z)||p(z|x,\theta ))}{\underbrace{-\int _{z}q(z)log\; \frac{p(z|x,\theta )}{q(z)}\mathrm{d}z}}

因此我们得出 logp(xθ)=ELBO+KL(qp)\log p(x \mid \theta)=E L B O+K L(q \| p) ,由于KL散度恒 0\geq 0 ,因此logp(xθ)ELBO\log p(x \mid \theta) \geq E L B O ,则 ELBOE L B O 就是似然函数 logp(xθ)\log p(x \mid \theta) 的下界。使得logp(xθ)=ELBO\log p(x \mid \theta)=E L B O 时,就必须有 KL(qp)=0K L(q \| p)=0 ,也就是 q(z)=p(zx,θ)q(z)=p(z \mid x, \theta) 时。在

每次迭代中我们取 q(z)=p(zx,θt)q(z)=p\left(z \mid x, \theta^t\right) ,就可以保证 logp(xθt)\log p\left(x \mid \theta^t\right)ELBOE L B O 相等,也就是:

log  p(xθ)=zp(zx,θt)log  p(x,zθ)p(zx,θt)dzELBOzp(zx,θt)log  p(zx,θ)p(zx,θt)dzKL(p(zx,θt)p(zx,θ))log\; p(x|\theta )=\underset{ELBO}{\underbrace{\int _{z}p(z|x,\theta ^{t})log\; \frac {p(x,z|\theta )}{p(z|x,\theta ^{t})}\mathrm{d}z}}\underset{KL(p(z|x,\theta ^{t})||p(z|x,\theta ))}{\underbrace{-\int _{z}p(z|x,\theta ^{t})log\; \frac{p(z|x,\theta )}{p(z|x,\theta ^{t})}\mathrm{d}z}}

θ=θt\theta=\theta^t 时, logp(xθt)\log p\left(x \mid \theta^t\right) 取ELBO,即:

log  p(xθt)=zp(zx,θt)log  p(x,zθt)p(zx,θt)dzELBOzp(zx,θt)log  p(zx,θt)p(zx,θt)dz=0=ELBOlog\; p(x|\theta ^{t})=\underset{ELBO}{\underbrace{\int _{z}p(z|x,\theta ^{t})log\; \frac{p(x,z|\theta ^{t})}{p(z|x,\theta ^{t})}\mathrm{d}z}}\underset{=0}{\underbrace{-\int _{z}p(z|x,\theta ^{t})log\; \frac{p(z|x,\theta ^{t})}{p(z|x,\theta ^{t})}\mathrm{d}z}}=ELBO

也就是说 logp(xθ)\log p(x \mid \theta)ELBOE L B O 都是关于 θ\theta 的函数,且满足 logp(xθ)ELBO\log p(x \mid \theta) \geq E L B O ,也就 是说 logp(xθ)\log p(x \mid \theta) 的图像总是在 ELBOE L B O 的图像的上面。

对于 q(z)q(z) ,我们取q(z)=p(zx,θt)q(z)=p\left(z \mid x, \theta^t\right) ,这也就保证了只有在 θ=θt\theta=\theta^tlogp(xθ)\log p(x \mid \theta)ELBOE L B O 才会相等,因 此使 ELBOE L B O 取极大值的 θt+1\theta^{t+1} 一定能使得 logp(xθt+1)logp(xθt)\log p\left(x \mid \theta^{t+1}\right) \geq \log p\left(x \mid \theta^t\right) 。该过程如下图 所示:

22097296-7ad3b7a07c4a078c.webp

然后我们观察一下ELBOELBO取极大值的过程:

θt+1=argmaxθELBO=argmaxθzp(zx,θt)log  p(x,zθ)p(zx,θt)dz=argmaxθzp(zx,θt)log  p(x,zθ)dzargmaxθzp(zx,θt)p(zx,θt)dzθ无关=argmaxθzp(zx,θt)log  p(x,zθ)dz=argmaxθEzx,θt[log  p(x,zθ)]\theta ^{t+1}=\underset{\theta }{argmax}ELBO \\ =\underset{\theta }{argmax}\int _{z}p(z|x,\theta ^{t})log\; \frac{p(x,z|\theta )}{p(z|x,\theta ^{t})}\mathrm{d}z\\ =\underset{\theta }{argmax}\int _{z}p(z|x,\theta ^{t})log\; p(x,z|\theta )\mathrm{d}z-\underset{与\theta 无关}{\underbrace{\underset{\theta }{argmax}\int _{z}p(z|x,\theta ^{t})p(z|x,\theta ^{t})\mathrm{d}z}}\\ {\color{Red}{=\underset{\theta }{argmax}\int _{z}p(z|x,\theta ^{t})log\; p(x,z|\theta )\mathrm{d}z}} \\ {\color{Red}{=\underset{\theta }{argmax}E_{z|x,\theta ^{t}}[log\; p(x,z|\theta )]}}

由此我们就导出了EM算法的迭代公式。

  1. ELBO+Jensen不等式的方法

首先要具体介绍一下Jensen不等式:对于一个凹函数 f(x)f(x)(国内外对凹凸函数的定义恰好相反,这里的凹函数指的是国外定义的凹函数),我们查看其图像如下:

22097296-2a5c5a4baee9ec29.webp

t[0,1]c=ta+(1t)bϕ=tf(a)+(1t)f(b)t\in [0,1]\\ c=ta+(1-t)b\\ \phi =tf(a)+(1-t)f(b)

凹函数恒有 f(c)ϕmathrm ,也就是 f(ta+(1t)b)tf(a)+(1t)f(b)f(c) \geq \phi \mathrm{~ , 也 就 是 ~} f(t a+(1-t) b) \geq t f(a)+(1-t) f(b) ,当 t=12t=\frac{1}{2} 时有 f(a2+b2)f(a)2+f(b)2f\left(\frac{a}{2}+\frac{b}{2}\right) \geq \frac{f(a)}{2}+\frac{f(b)}{2} ,可以理解为对于凹函数来说 先求期望再求函数值\geq 先求函数值再求期望,即 f(E)E[f]f(E) \geq E[f]

上面的说明只是对Jensen不等式的一个形象的描述,而非严谨的证明。接下来应用Jensen不等式来导出EM算法:

logp(xθ)=logzp(x,zθ)dz=logzp(x,zθ)q(z)q(z)dz=logEq(z)[p(x,zθ)q(z)]Eq(z)[logp(x,zθ)q(z)]ELBO\begin{gathered} \log p(x \mid \theta)=\log \int_z p(x, z \mid \theta) \mathrm{d} z \\ =\log \int_z \frac{p(x, z \mid \theta)}{q(z)} \cdot q(z) \mathrm{d} z \\ =\log E_{q(z)}\left[\frac{p(x, z \mid \theta)}{q(z)}\right] \\ \geq \underbrace{E_{q(z)}\left[\log \frac{p(x, z \mid \theta)}{q(z)}\right]}_{E L B O} \end{gathered}

这里应用了Jensen不等式得到了上面出现过的 ELBOE L B O ,这里的 f(x)f(x) 函数也就是 log\log 函数, 显然这是一个凹函数。当 logP(x,zθ)q(z)\log \frac{P(x, z \mid \theta)}{q(z)} 这个函数是一个常数时会取得等号,利用这一点我们 也同样可以得到 q(z)=p(zx,θ)q(z)=p(z \mid x, \theta) 时能够使得 logp(xθ)=ELBO\log p(x \mid \theta)=E L B O 的结论:

p(x,zθ)q(z)=Cq(z)=p(x,zθ)Czq(z)dz=z1Cp(x,zθ)dz1=1Czp(x,zθ)dzC=p(xθ)C代入q(z)=p(x,zθ)Cq(z)=p(x,zθ)p(xθ)=p(zx,θ)\frac{p(x,z|\theta )}{q(z)}=C\\ \Rightarrow q(z)=\frac{p(x,z|\theta )}{C}\\ \Rightarrow \int _{z}q(z)\mathrm{d}z=\int _{z}\frac{1}{C}p(x,z|\theta )\mathrm{d}z\\ \Rightarrow 1=\frac{1}{C}\int _{z}p(x,z|\theta )\mathrm{d}z\\ \Rightarrow C=p(x|\theta )\\ 将C代入q(z)=\frac{p(x,z|\theta )}{C}得\\ {\color{Red}{q(z)=\frac{p(x,z|\theta )}{p(x|\theta )}=p(z|x,\theta )}}

这种方法到这里就和上面的方法一样了,总结来说就是:

log  p(xθ)Eq(z)[logp(x,zθ)q(z)]ELBOlog\; p(x|\theta )\geq \underset{ELBO}{\underbrace{E_{q(z)}[log\frac{p(x,z|\theta )}{q(z)}]}}

上面的不等式在q(z)=p(zxθ)q(z)=p(z|x|\theta )时取等号,因此在迭代更新过程中取q(z)=p(zx,θt)q(z)=p(z|x,\theta ^{t})接下来的推导过程就和第1种方法一样了。

四、广义EM算法

上面介绍的EM算法属于狭义的EM算法,它是广义EM的一个特例。在上面介绍的EM算法的E步中我们假定q(z)=p(zx,θt)q(z)=p(z|x,\theta ^{t}),但是如果这个后验p(zx,θt)p(z|x,\theta ^{t})无法求解,那么必须使⽤采样(MCMC)或者变分推断等⽅法来近似推断这个后验。前面我们得出了以下关系:

logp(xθ)=zq(z)logp(x,zθ)q(z)dzzq(z)logp(zx,θ)q(z)dz=ELBO+KL(qp)\log p(x \mid \theta)=\int_z q(z) \log \frac{p(x, z \mid \theta)}{q(z)} \mathrm{d} z-\int_z q(z) \log \frac{p(z \mid x, \theta)}{q(z)} \mathrm{d} z=E L B O+K L(q \| p)

当我们对于固定的 θ\theta ,我们希望 KL(qp)K L(q \| p) 越小越好,这样就能使得 ELBOE L B O 更大:

固定θ,q^=argminqKL(qp)=argmaxqELBO固定 \theta, \hat{q}=\underset{q}{\operatorname{argmin}} K L(q \| p)=\underset{q}{\operatorname{argmax}} E L B O

ELBOE L B O 是关于 qqθ\theta 的函数,写作 L(q,θ)L(q, \theta) 。以下是广义EM算法的基本思路:

E step: qt+1=argmaxL(q,θt)q^{t+1}=\operatorname{argmax} L\left(q, \theta^t\right) M step: θt+1=argmaxqL(qt+1,θ)\theta^{t+1}=\underset{q}{\operatorname{argmax}} L\left(q^{t+1}, \theta\right)

再次观察一下 ELBOE L B O :

ELBO=L(q,θ)=Eq[log  p(x,z)log  q]=Eq[log  p(x,z)]Eq[log  q]H[q]ELBO=L(q,\theta )=E_{q}[log\; p(x,z)-log\; q]\\ =E_{q}[log\; p(x,z)]\underset{H[q]}{\underbrace{-E_{q}[log\; q]}}

因此,我们看到,⼴义 EM 相当于在原来的式⼦中加⼊熵H[q]H[q]这⼀项。

五、EM的变种

EM 算法类似于坐标上升法,固定部分坐标,优化其他坐标,再⼀遍⼀遍的迭代。如果在 EM 框架中,⽆法求解zz后验概率,那么需要采⽤⼀些变种的 EM 来估算这个后验:

①基于平均场的变分推断,VBEM/VEM

②基于蒙特卡洛的EM,MCEM

“开启掘金成长之旅!这是我参与「掘金日新计划 · 2 月更文挑战」的第 8 天,点击查看活动详情