从熵和损失函数到变分推断

1,037 阅读6分钟

0.动机

  一直以来看了很多工作,感觉都用到了比较多的数学处理,其中,熵和变分推断往往是我比较难理解的一部分。我也尝试过看一些相关的资料,但是总感觉这部分内容还是需要有一些实践和手推才能刚好理解 (毕竟符号和格式总是千奇百怪的)。忙中偷闲,以这种独特的方式来倒逼自己好好动手把这部分理解一下 (同时也尽可能写得更全一些),也算是对自己的一个考验吧。后续如果有更多内容我会进行编辑补充。

  感谢所有的参考资料及其作者,都是很不错的资料。不胜感激。同时请务必注意,很多数学推断还是要动手算的,切忌眼过就是理解

1.极大似然估计 (maximum likelihood estimation,MLE)

1.1 一种理解的角度

  在个人理解中: 极大似然估计的核心发生的事情,应该是最可能的事情。这个感觉和 Georg Wilhelm Friedrich Hegel 的那句

Was vernünftig ist, das ist wirklich; und was wirklich ist, das ist vernünftig. (English version: What is rational is actual and what is actual is rational.)

有着异曲同工之妙。

1.2 形式化

  给定一个概率分布 DD,已知其概率密度函数 (连续分布) 或概率质量函数 (离散分布) 为 fDf_D,以及一个分布参数 θ\theta,我们可以从这个分布中抽出一个具有 nn 个值的采样 X1,X2,,XnX_1,X_2,\ldots,X_n,利用 fDf_D 计算出其似然函数 :

L(θx1,x2,,xn)=L(θx)=fθ(x1,x2,,xn)\mathcal{L}(\theta|x_1,x_2,\ldots,x_n) =\mathcal{L}(\theta|\mathbf{x})= f_{\theta}(x_1,x_2,\ldots,x_n)

  特别地,当 X1,X2,,XnX_1,X_2,\ldots,X_n 相互独立时,有 L(θx1,x2,,xn)=fθ(x1,x2,,xn)=i=1nfθ(xi)\mathcal{L}(\theta|x_1,x_2,\ldots,x_n) = f_{\theta}(x_1,x_2,\ldots,x_n)=\prod_{i=1}^nf_{\theta}(x_i)

  若 DD 是离散分布,fθf_{\theta} 即是在参数为 θ\theta 时观测到这一采样的概率;若其是连续分布,fθf_{\theta} 则为 X1,X2,,XnX_1,X_2,\ldots,X_n 联合分布的概率密度函数在观测值处的取值。一旦我们获得 fθ(x1,x2,,xn)f_{\theta}(x_1,x_2,\ldots,x_n),我们就能求得一个关于 θ\theta 的估计。

  极大似然估计会寻找关于 θ\theta 的最可能的值 (即,在所有可能的 θ\theta 取值中,寻找一个值使这个采样的“可能性”最大化)。对于其想要寻找的 θ^\hat{\theta},即可以形式化表达为 :

θ^=arg maxθΘL(θx1,x2,,xn)=arg maxθΘL(θx)\hat{\theta}=\argmax_{\theta\in\Theta}\mathcal{L}(\theta|x_1,x_2,\ldots,x_n)=\argmax_{\theta\in\Theta}\mathcal{L}(\theta|\mathbf{x})

1.3 一些性质和补充 : 对数化计算技巧

  通常情况下,可以进行这样的划分 i=1mDi (m1)=i=1n{Xi}, DiDj= (ij)\bigcup_{i=1}^{m}D_i\ (m\ne1)=\bigcup_{i=1}^n\{X_i\},\ D_i\bigcap D_j=\empty\ (i\ne j),满足对于 XuDi,XvDj(ij),Xu ⁣ ⁣ ⁣Xv\forall X_u\in D_i,\forall X_v\in D_j (i\ne j), X_u\perp \!\!\! \perp X_v,这样就会不可避免的用到乘法原理,由此引出以下的对数化技巧 :

(θx)=lnL(xθ)=i=1mlnfθ(Di)\ell(\theta|\mathbf{x})=\ln\mathcal{L}(x|\theta)=\sum_{i=1}^m\ln f_{\theta}(D_i)

  由于对数是单调函数,因此最大值 (xθ)\ell(x|\theta)L(xθ)\mathcal{L}(x|\theta) 取最大时的 θ^\hat{\theta} 是相同的。如果 (xθ)\ell(x|\theta) 可微于 Θ\Theta,出现最大值 (或最小值) 的必要条件是 :

θi=0,i[k],θRk\frac{\partial \ell}{\partial\theta_i}=0,\forall i\in [k],\theta\in \mathbb{R}^k

  称之为似然方程组。对于某些模型,这些方程可以显式求解 θ^\hat{\theta},但通常没有已知或可用的最大化问题的闭式解,并且 MLE 只能通过数值优化找到。另一个问题是在有限样本中,似然方程可能存在多个根。是否识别根 θ^\hat{\theta} 的似然方程确实是一个 (局部) 最大值取决于二阶偏导数和交叉偏导数的矩阵,即所谓的 Hessian 矩阵

H(θ^)=[2θ12θ=θ^2θ1θ2θ=θ^2θ1θkθ=θ^2θ2θ1θ=θ^2θ22θ=θ^2θ2θkθ=θ^2θkθ1θ=θ^2θkθ2θ=θ^2θk2θ=θ^],\mathbf{H}(\widehat{\theta})=\left[\begin{array}{cccc} \left.\frac{\partial^2 \ell}{\partial \theta_1^2}\right|_{\theta=\widehat{\theta}} & \left.\frac{\partial^2 \ell}{\partial \theta_1 \partial \theta_2}\right|_{\theta=\widehat{\theta}} & \cdots & \left.\frac{\partial^2 \ell}{\partial \theta_1 \partial \theta_k}\right|_{\theta=\widehat{\theta}} \\ \left.\frac{\partial^2 \ell}{\partial \theta_2 \partial \theta_1}\right|_{\theta=\widehat{\theta}} & \left.\frac{\partial^2 \ell}{\partial \theta_2^2}\right|_{\theta=\widehat{\theta}} & \cdots & \left.\frac{\partial^2 \ell}{\partial \theta_2 \partial \theta_k}\right|_{\theta=\widehat{\theta}} \\ \vdots & \vdots & \ddots & \vdots \\ \left.\frac{\partial^2 \ell}{\partial \theta_k \partial \theta_1}\right|_{\theta=\widehat{\theta}} & \left.\frac{\partial^2 \ell}{\partial \theta_k \partial \theta_2}\right|_{\theta=\widehat{\theta}} & \cdots & \left.\frac{\partial^2 \ell}{\partial \theta_k^2}\right|_{\theta=\widehat{\theta}} \end{array}\right],

  是半负定的 θ^\hat{\theta},因为这表明局部凹性。方便的是,最常见的概率分布——尤其是指数族——是对数凹的。

1.4 一些性质和补充 : 受限参数空间

  虽然似然函数的域 (参数空间) 通常是欧几里德空间的有限维子集,但有时需要将额外的限制纳入估计过程。参数空间可以表示为 Θ={θ:θRk,h(θ)=0}\Theta=\left\{\theta: \theta \in \mathbb{R}^k, h(\theta)=0\right\}

  其中 h(θ)=[h1(θ),h2(θ),,hr(θ)],hRkRrh(\theta)=\left[h_1(\theta),h_2(\theta),\ldots,h_r(\theta)\right],h\in\mathbb{R}^k\rightarrow\mathbb{R}^r,估计真实参数 θΘ\theta\in\Theta 后,作为一个实际问题,意味着找到受约束的似然函数的最大值 h(θ)=0h(\theta)=0

  实践中,通常使用 Lagrange 方法施加限制,给定上述定义的限制,限制似然方程 (无处不在的 Lagrange 插值)

θλh(θ)Tθ=0,h(θ)=0\frac{\partial \ell}{\partial \theta}-\lambda\frac{h(\theta)^T}{\partial \theta}=0, h(\theta)=0

  其中 λ=[λ1,λ2,,λr]T\lambda=\left[\lambda_1,\lambda_2,\ldots,\lambda_r\right]^T 是 Lagrange 乘子。

1.5 和 Bayes 推理进行结合

P(θx)=P(θx1,x2,,xn)=f(x1,x2,,xnθ)P(θ)P(x1,x2,,xn)\mathbb{P}(\theta|\mathbf{x})=\mathbb{P}(\theta|x_1,x_2,\ldots,x_n)=\frac{f(x_1,x_2,\ldots,x_n|\theta)\mathbb{P}(\theta)}{\mathbb{P}(x_1,x_2,\ldots,x_n)}

1.6 负对数似然损失函数 (Negative Log-Likelihood Loss Function, NLL)

  在机器学习中,我们通常使用负对数似然损失函数 (Negative Log-Likelihood Loss Function, NLL) 来表示似然函数。用 Bayes 的角度来看,我们可以将每次预测看作一个概率分布 p(yx;θ)p(y|x;\theta) (在进行概率化处理之后),从 MLE 的角度思考,我们希望找到最大化似然函数的参数 θ\theta,即有

L(θx)=i=1np(yixi;θ)θ^=arg maxθL(θx)θ^=arg maxθi=1nlogp(yixi;θ)θ^=arg minθi=1nlogp(yixi;θ)\begin{aligned} \mathcal{L}(\theta\mid x)&=\prod_{i=1}^n p(y_i|x_i;\theta)\\ &\Rightarrow \hat{\theta}=\argmax_{\theta}\mathcal{L}(\theta\mid x)\\ &\Rightarrow \hat{\theta}=\argmax_{\theta}\sum_{i=1}^n \log p(y_i|x_i;\theta)\\ &\Rightarrow \hat{\theta}=\argmin_{\theta}-\sum_{i=1}^n \log p(y_i|x_i;\theta) \end{aligned}

  即为负对数似然损失函数的推导。

2.熵 (Entropy)

2.1 定义 (Shannon Entropy)

  本篇中所介绍的熵均为信息熵,在信息论中,随机变量的熵是变量可能结果固有的“信息”、“意外”或“不确定性”的平均水平。给定一个离散随机变量 XX,它采用字母表中的值 X\mathcal{X} 并满足映射 p:X[0,1]p:\mathcal{X}\rightarrow[0,1]

H(X):=xXp(x)logp(x)=E[logp(X)]\mathrm{H}(X):=-\sum_{x \in \mathcal{X}} p(x) \log p(x)=\mathbb{E}[-\log p(X)]

  其中 E[logp(X)]\mathbb{E}[-\log p(X)]XX 的期望值。熵是一个非负数,当且仅当 p(x)=0p(x)=0 时,xx 不可能发生时,H(X)=0\mathrm{H}(X)=0。熵的单位是比特 (bit)。熵的一个重要性质是,当 p(x)p(x) 的分布越接近均匀分布时,熵越大。

  简单可以将信息熵与热力学中的熵进行类比,熵是一个描述系统混乱程度的指标,当系统越混乱,熵越大。同时也有的熵的表示采用 ln\ln 进行计算,在此我们统一使用 log\log,毕竟可以使用换底公式证明两种计算方法的等价性。

2.2 信息量的定义

  在信息论中,我们用 I(x)I(x) 表示信息量,它是一个随机变量 xx 的不确定性的度量,它的值越大,不确定性越大。信息量的定义为 :

I(x)=log(p(x))I(x)=-\log(p(x))

  则 Shannon 熵也有对应的表示 :

H(X)=xXp(x)I(x)=E[I(x)]\mathrm{H}(X)=\sum_{x \in \mathcal{X}} p(x)I(x)=\mathbb{E}[I(x)]

2.3 交叉熵损失函数 (Cross-Entropy Loss)

  交叉熵 (Cross Entropy) 是 Shannon 信息论中一个重要概念,主要用于度量两个概率分布间的差异性信息。在信息论中,交叉熵是表示两个概率分布 p,qp,q 的差异,其中 pp 表示真实分布,qq 表示预测分布,那么 H(p,q)\mathrm{H}(p,q) 就称为交叉熵 :

H(p,q)=xXpilog1qi=xXp(x)logq(x)\mathrm{H}(p,q)=\sum_{x \in \mathcal{X}}p_i\cdot\log\frac{1}{q_i}=-\sum_{x \in \mathcal{X}} p(x) \log q(x)

2.4 KL 散度 (Kullback-Leibler Divergence)

  相对熵又称 KL 散度,如果我们对于同一个随机变量 xx 有两个单独的概率分布 P(x)P(x)Q(x)Q(x),我们可以使用 KL 散度 (Kullback-Leibler (KL) divergence) 来衡量这两个分布的差异,这个相当于信息论范畴的均方差。

  KL 散度的计算公式 :

DKL(PQ)=xXP(x)logP(x)Q(x)=xXP(x)logP(x)xXP(x)logQ(x)=H(P(x),Q(x))H(P(x))\begin{aligned} D_{KL}(P||Q)&=\sum_{x \in \mathcal{X}} P(x) \log \frac{P(x)}{Q(x)}\\&=\sum_{x \in \mathcal{X}} P(x) \log P(x)-\sum_{x \in \mathcal{X}} P(x) \log Q(x)\\ &=\mathrm{H}(P(x),Q(x))-\mathrm{H}(P(x)) \end{aligned}

  不难发现,当 P(x)=Q(x)P(x)=Q(x)时,DKL(PQ)=0D_{KL}(P||Q)=0,同时 KL 散度和交叉熵都是不满足对称性的,即

DKL(PQ)DKL(QP)H(P(x),Q(x))H(Q(x),P(x))\begin{aligned} D_{KL}(P||Q) &\neq D_{KL}(Q||P)\\ \mathrm{H}(P(x),Q(x)) &\neq \mathrm{H}(Q(x),P(x)) \end{aligned}

2.5 机器学习中 KL 散度的应用

  在机器学习中,我们需要评估标签值 yy 和预测值 aa 之间的差距,使用 KL 散度很适合,即 DKL(ya)D_{KL}(y||a),由于 KL 散度中的 H(P(y))-\mathrm{H}(P(y)) 不变,故在优化过程中,只需要关注交叉熵 H(y,a)\mathrm{H}(y,a) 就可以了。所以一般在机器学习中直接用交叉熵做损失函数来评估模型。

  这种情景下的交叉熵函数常用于逻辑回归 (logistic regression),也就是分类 (classification)。

  而如果我们希望预测更加混乱,则可以将 aa 与均匀分布的 pp 标记进行比较,从而达到相反的效果。

=j=1nyjlnaj\ell=-\sum_{j=1}^ny_j\ln a_j

  上式是单个样本的情况,nn 是分类个数。所以,对于批量样本的交叉熵计算公式是 :

J=i=1mj=1nyijlnaijJ=-\sum_{i=1}^m \sum_{j=1}^n y_{i j} \ln a_{i j}

  其中 mm 是样本数,nn 是分类个数。

Question : 为什么不能使用均方差做为分类问题的损失函数?

  • 回归问题通常用均方差损失函数,可以保证损失函数是个凸函数,即可以得到最优解。而分类问题如果用均方差的话,损失函数的表现不是凸函数,就很难得到最优解。而交叉熵函数可以保证区间内单调。
  • 分类问题的最后一层网络,需要分类函数,Sigmoid 或者 Softmax,如果再接均方差函数的话,其求导结果复杂,运算量比较大。用交叉熵函数的话,可以得到比较简单的计算结果,一个简单的减法就可以得到反向误差。

2.6 互信息 (Mutual Information)

  在概率论和信息论中,两个随机变量的互信息 (mutual Information,MI) 度量了两个变量之间相互依赖的程度。具体来说,对于两个随机变量,MI 是一个随机变量由于已知另一个随机变量而减少的“信息量” (单位通常为比特)。互信息的概念与随机变量的熵紧密相关,熵是信息论中的基本概念,它量化的是随机变量中所包含的“信息量”。

  设随机变量 (X,Y)X×Y(X,Y)\in\mathcal{X}\times\mathcal{Y},若他们的联合分布是 p(x,y)p(x,y),边缘分布分别是 p(x)p(x)p(y)p(y),那么,它们之间的互信息可以定义为 :

I(X;Y)=DKL(p(x,y)p(x)p(y))=xXyYp(x,y)logp(x,y)p(x)p(y)\begin{aligned} I(X;Y)&=D_{KL}(p(x,y)||p(x)\otimes p(y))\\ &=\sum_{x\in\mathcal{X}}\sum_{y\in\mathcal{Y}}p(x,y)\log\frac{p(x,y)}{p(x)p(y)} \end{aligned}

  其中 \otimes 是张量乘法。若 p(x,y)=p(x)p(y)p(x,y)=p(x)\otimes p(y),则 I(X,Y)=0I(X,Y)=0,即两个随机变量相互独立。

  同时互信息 I(X,Y)=I(Y,X)I(X,Y)=I(Y,X) 满足对称性。

  当 XXYY 连续时,互信息可以用积分的形式表示 :

I(X,Y)=YXp(x,y)logp(x,y)p(x)p(y)dxdyI(X,Y)=\int_Y\int_Xp(x,y)\log\frac{p(x,y)}{p(x)p(y)}\mathrm{d}x\mathrm{d}y

2.7 互信息的性质证明

  下面我们对于上面示意图中的直观理解进行证明。

I(X;Y)=DKL(p(x,y)p(x)p(y))=xXyYp(x,y)logp(x,y)p(x)p(y)=x,yp(x,y)logp(x,y)p(x)x,yp(x,y)logp(y)=x,yp(x)p(yx)logp(yx)x,yp(x,y)logp(y)=xp(x)(yp(yx)logp(yx))ylogp(y)(xp(x,y))=xp(x)H(YX=x)ylogp(y)p(y)=H(YX)+H(Y)=H(Y)H(YX)=H(X)H(XY)=H(X)+H(Y)H(X,Y)\begin{aligned} I(X;Y)&=D_{KL}(p(x,y)||p(x)\otimes p(y))\\ &=\sum_{x\in\mathcal{X}}\sum_{y\in\mathcal{Y}}p(x,y)\log\frac{p(x,y)}{p(x)p(y)}\\ &=\sum_{x,y}p(x,y)\log\frac{p(x,y)}{p(x)}-\sum_{x,y}p(x,y)\log p(y)\\ & =\sum_{x, y} p(x) p(y \mid x) \log p(y \mid x)-\sum_{x, y} p(x, y) \log p(y) \\ & =\sum_x p(x)\left(\sum_y p(y \mid x) \log p(y \mid x)\right)\\&\quad -\sum_y \log p(y)\left(\sum_x p(x, y)\right) \\ & =-\sum_x p(x) H(Y \mid X=x)-\sum_y \log p(y) p(y) \\ & =-H(Y \mid X)+H(Y) \\ & =H(Y)-H(Y \mid X)\\ &=H(X)-H(X\mid Y)\\ & =H(X)+H(Y)-H(X,Y)\\ \end{aligned}

3 变分推断 (Variational Inference)

3.1 变分推断的基本思想

  英文版的 Wikipedia 上将变分推断称为 Variational Bayesian Inference,因为其中使用了 Bayes 推断的思想。方便起见,我们以变分推断称呼它。

  变分贝叶斯方法是一系列用于逼近贝叶斯推理和机器学习中出现的难处理积分的技术。它们通常用于由观察变量 (通常称为“数据”) 以及未知参数和潜在变量组成的复杂统计模型,三种随机变量之间存在各种关系,如图形模型所描述的那样。一言以蔽之,即可以用 JKRD 的叙述来进行解读。

用一个简单分布拟合另一个复杂分布

  即当我们遇见棘手的复杂分布的刻画的时候,用一个简单分布来对其进行拟合。

3.2 变分推断的定义与推导

  在变分推理中,在给定 X\mathrm{X} 的前提下,一组未观察变量的后验分布 Z={Z1,Z2,,Zn}\mathrm{Z}=\{Z_1,Z_2,\ldots,Z_n\} 可用变分推断 Q(Z)Q(\mathrm{Z}) 进行逼近,即 :

Q(Z)P(ZX)Q(\mathrm{Z})\approx P(\mathrm{Z}\mid\mathrm{X})

  分布 Q(Z)Q(\mathrm{Z}) 被限制为属于比 P(ZX)P(\mathrm{Z}\mid\mathrm{X}) 形式更简单的分布族 (例如高斯分布族)。

  相似度 (或不相似度) 是用距离函数 (或称之为不相似度函数) d(Q;P)d(Q;P),因此最小化 d(Q;P)d(Q;P)

3.3 使用 KL 散度进行度量

  在上面关于 KL 散度的介绍中我们知道 KL 散度类似于距离 (但是因为不对称性不能称之为距离),当 KL 散度越小的时候,我们可以认为两个分布越近似。因此有

DKL(QP)ZQ(Z)logQ(Z)P(ZX)=EZ[logQ(Z)P(ZX)]Q(Z)=arg minQ(Z)DKL(Q(Z)P(ZX))D_{\mathrm{KL}}(Q \| P) \triangleq \sum_{\mathbf{Z}} Q(\mathbf{Z}) \log \frac{Q(\mathbf{Z})}{P(\mathbf{Z} \mid \mathbf{X})}=\mathbb{E}_{\mathbf{Z}}\left[\log \frac{Q(\mathbf{Z})}{P(\mathbf{Z} \mid \mathbf{X})}\right]\\ Q^*(\mathbf{Z})=\argmin_{Q(\mathbf{Z})}D_{\mathrm{KL}}(Q(\mathbf{Z}) \| P(\mathbf{Z} \mid \mathbf{X}))

3.4 证据下界 (Evidence lower bound) 的引入

  然而上式中的 P(ZX)P(\mathbf{Z} \mid \mathbf{X}) 是难以计算的 :

P(ZX)=P(Z,X)P(X)=P(Z,X)ZP(XZ)P(Z)dZP(\mathbf{Z} \mid \mathbf{X})=\frac{P(\mathbf{Z},\mathbf{X})}{P(\mathbf{X})}=\frac{P(\mathbf{Z},\mathbf{X})}{\int_{\mathbf{Z}}P(\mathbf{X}\mid\mathbf{Z})P(\mathbf{Z})\mathrm{d}\mathbf{Z}}

  计算难点主要在于观测变量的边缘分布 P(X)P(\mathbf{X}) (也被称作证据(evidence))。如果隐变量维度很高,那么计算开销将非常大。因此进行一些改进 (对 KL 散度的刀法)

DKL(QP)=ZQ(Z)[logQ(Z)P(Z,X)+logP(X)]=ZQ(Z)[logQ(Z)logP(Z,X)]+ZQ(Z)[logP(X)]=EZ[logQ(Z)logP(Z,X)]+EZ[logP(X)]EZ[logP(X)]=DKL(QP)EZ[logQ(Z)logP(Z,X)]DKL(QP)+L(Q)\begin{aligned} D_{\mathrm{KL}}(Q \| P)&=\sum_{\mathbf{Z}} Q(\mathbf{Z})\left[\log \frac{Q(\mathbf{Z})}{P(\mathbf{Z}, \mathbf{X})}+\log P(\mathbf{X})\right]\\&=\sum_{\mathbf{Z}} Q(\mathbf{Z})[\log Q(\mathbf{Z})-\log P(\mathbf{Z}, \mathbf{X})]+\sum_{\mathbf{Z}} Q(\mathbf{Z})[\log P(\mathbf{X})]\\ &=\mathbb{E}_{\mathbf{Z}}\left[\log Q(\mathbf{Z})-\log P(\mathbf{Z}, \mathbf{X})\right]+\mathbb{E}_{\mathbf{Z}}\left[\log P(\mathbf{X})\right]\\ \mathbb{E}_{\mathbf{Z}}\left[\log P(\mathbf{X})\right]&=D_{\mathrm{KL}}(Q \| P)-\mathbb{E}_{\mathbf{Z}}\left[\log Q(\mathbf{Z})-\log P(\mathbf{Z}, \mathbf{X})\right]\\ &\triangleq D_{\mathrm{KL}}(Q \| P)+\mathcal{L}(Q) \end{aligned}

  其中我们将 L(Q)\mathcal{L}(Q) 定义为证据下界 (Evidence lower bound)。我们可以将计算 KL 散度最小化的问题转化为计算证据下界 L(Q)\mathcal{L}(Q) 最大化的问题 :

DKL(QP)ZQ(Z)logQ(Z)P(ZX)=EZ[logQ(Z)P(ZX)]Q(Z)=arg minQ(Z)DKL(Q(Z)P(ZX))Q(Z)=arg maxQ(Z)L(Q)D_{\mathrm{KL}}(Q \| P) \triangleq \sum_{\mathbf{Z}} Q(\mathbf{Z}) \log \frac{Q(\mathbf{Z})}{P(\mathbf{Z} \mid \mathbf{X})}=\mathbb{E}_{\mathbf{Z}}\left[\log \frac{Q(\mathbf{Z})}{P(\mathbf{Z} \mid \mathbf{X})}\right]\\ Q^*(\mathbf{Z})=\argmin_{Q(\mathbf{Z})}D_{\mathrm{KL}}(Q(\mathbf{Z}) \| P(\mathbf{Z} \mid \mathbf{X}))\\ \Rightarrow Q^*(\mathbf{Z})=\argmax_{Q(\mathbf{Z})}\mathcal{L}(Q)

  正常情况这篇文章应该就要到此为止了,但是为了防止如下的情况,我打算再重新对于论文《Bayesian Invariant Risk Minimization》中的部分公式进行重新复现。(不得不再次狠狠羡慕一波林博士扎实的数理基础,以及为了配这个图我甚至用超分辨率把这个表情包扩了8倍Orz)

4.实战 Bayesian Invariant Risk Minimization 中的变分推断

  BIRM 的初步思路是将负对数似然代入 InvRat 中

minw,umax{we,eEtr}eEtr[Re(w,u)+λ(Re(w,u)Re(we,u))]minueEp(wDu)[lnp(Dew,u)]+λ(Ep(wDu)[lnp(Dew,u)]Ep(weDue)[lnp(Dewe,u)])=maxueEqu(w)[lnp(Dew,u)]+λ(Equ(w)[lnp(Dew,u)]Eque(we)[lnp(Dewe,u)])\min_{w,u}\max_{\{w^e,\forall e\in\mathcal{E}_{tr}\}}\sum_{e\in\mathcal{E}_{tr}}\left[\mathcal{R}^e(w,u)+\lambda(\mathcal{R}^e(w,u)-\mathcal{R}^e(w^e,u))\right]\\ \Rightarrow\min_u\sum_e\mathbb{E}_{p(w|\mathcal{D}_u)}\left[-\ln p(\mathcal{D}^e|w,u)\right]+\lambda\left(\mathbb{E}_{p(w|\mathcal{D}_u)}\left[-\ln p(\mathcal{D}^e|w,u)\right]-\mathbb{E}_{p(w^e|\mathcal{D}_u^e)}\left[-\ln p(\mathcal{D}^e|w^e,u)\right]\right)\\ =\max_u\sum_e\mathbb{E}_{q_u(w)}\left[\ln p(\mathcal{D}^e|w,u)\right]+\lambda\left(\mathbb{E}_{q_u(w)}\left[\ln p(\mathcal{D}^e|w,u)\right]-\mathbb{E}_{q_u^e(w^e)}\left[\ln p(\mathcal{D}^e|w^e,u)\right]\right)

  用 KL 散度操作一波,qu(w)p(wDu)q_u(w)\approx p(w|\mathcal{D}_u)que(we)p(weDue)q^e_u(w^e)\approx p(w^e|\mathcal{D}_u^e),是给定 Du\mathcal{D}_uDue\mathcal{D}_u^e 的分类器的近似后验分布。有以下转换

Eque(we)[lnp(Dewe,u)]=lnp(Dewe,u)que(we)dwe,Equ(w)[lnp(Dew,u)]=lnp(Dew,u)qu(w)dw\mathbb{E}_{q_u^e(w^e)}\left[\ln p(\mathcal{D}^e|w^e,u)\right]=\int\ln p(\mathcal{D}^e|w^e,u)q_u^e(w^e){\rm d}w^e,\\ \mathbb{E}_{q_u(w)}\left[\ln p(\mathcal{D}^e|w,u)\right]=\int\ln p(\mathcal{D}^e|w,u)q_u(w){\rm d}w

在这里,我们通过变分推理使用 que(we)q_u^e(w^e)qu(w)q_u(w) 来近似它们。给定一个分布族 Q\mathcal{Q},我们通过找到最大化证据下界 (ELBO) 的最优 qQq\in\mathcal{Q} 来近似后验分布。估计 que(we)q_u^e(w^e) 的目标函数是 :

que(we)=arg maxqQEq[lnp(Dew,u)KL(qp0(w))]q_u^e(w^e)=\underset{q'\in\mathcal{Q}}{\argmax}\mathbb{E}_{q'}\left[\ln p(\mathcal{D}^e|w,u)-\text{KL}(q'||p_0(w))\right]
KL(qp0(w))=iq(i)lnq(i)p0(i)=q(w)lnq(w)p0(w)dw\text{KL}(q'||p_0(w))=\sum_iq'(i)\ln\frac{q'(i)}{p_0(i)}=\int q'(w)\ln\frac{q'(w)}{p_0(w)}{\rm d}w

  其中第一项是最大化后验分布的预期对数似然,第二项旨在保持 qq' 接近先验 p0(w)p_0(w)。类似地,获得 qu(w)q_u(w) 的目标函数是 :

qu(w)=arg maxqQeEq[lnp(Dew,u)KL(qp0(w))]q_u(w)=\underset{q'\in\mathcal{Q}}{\argmax}\sum_e\mathbb{E}_{q'}\left[\ln p(\mathcal{D}^e|w,u)-\text{KL}(q'||p_0(w))\right]

参考资料 (References)

本文正在参加 人工智能创作者扶持计划