Bayesian Invariant Risk Minimization 论文阅读

1,416 阅读16分钟

0 论文信息

1 背景介绍

  分布迁移下的泛化是机器学习的一个公开挑战。不变风险最小化 (IRM) 是通过提取不变特征来解决这一问题的一个很有前途的框架。然而,尽管 IRM 具有广阔的应用前景和广泛的应用前景,近年来的研究却发现其在深度模型中的应用效果并不理想。我们认为,失败的主要原因是深度模型倾向于过度拟合数据。具体而言,我们的理论分析表明,当过拟合发生时,IRM 退化为经验风险最小化 (ERM)。我们的经验证据也提供了支持 : 在典型环境下工作良好的 IRM 方法显著恶化,即使我们稍微扩大模型大小或减少训练数据。

2 IRM及其过拟合陷阱

2.1 Invariant Risk Minimization (IRM)

预备知识 在整篇论文中,大写字母 X 和 Y 表示随机变量;小写字母 x、y 和 w 表示样本和参数。我们假设有一组多个环境 E\mathcal{E},可以从中提取数据。在训练过程中,我们可以接触到一系列的环境,EtrE\mathcal{E}_{tr}\subset\mathcal{E};每个环境 eEtre\in\mathcal{E}_{tr} 包含 nen^e个样本,记为 De={(xie,yie)}i=1ne\mathcal{D}^e\overset{\triangle}{=}\left\{\left(x^e_i, y^e_i\right)\right\}^{n_e}_{i=1}。让 X\mathcal{X}Y\mathcal{Y} 是 X 和 Y 的空间。我们的目标是学习一个函数 f:XYf:\mathcal{X}\rightarrow\mathcal{Y},它在给定 X 的情况下预测 Y。这里的 ff 由分类器 gw()g_w(\cdot) 和特征提取器 hu()h_u(\cdot) 组成,参数分别为 w 和 u。域泛化的任务是寻找最优 w 和 u,使最坏环境的损失最小 :

minw,usupeERe(w,u)(1)\min_{w,u}\sup_{e\in\mathcal{E}}\mathcal{R}^e(w,u) \tag{1}

  在这里,Re(w,u)\mathcal{R}^e(w,u) 是来自 e 的数据的负对数似然,形式上,我们有 :

Re(w,u)=lnp(Dew,u)=i=1nelnp(yiew,hu(xie))\color{red}\mathcal{R}^e(w,u)=-\ln p(\mathcal{D}^e|w,u)=-\sum_{i=1}^{n^e}\ln p(y^e_i|w,h_u(x_i^e))

  也就是说,我们的目标是从 E\mathcal{E} 中学习最优 w 和 u,使最坏环境的可能性最大化。我们只考虑 w、u 是很好的指定的情况,如 Re(w,u)0\mathcal{R}^e(w,u)\ge0 对所有 w、u 都成立。

不变风险最小化 (IRM) IRM 旨在解决以下目标,实现式 (1) :

minw,ueEtrRe(w,u)s.t.warg minweRe(we,u),eEtr(2)\color{red}\min_{w,u}\sum_{e\in\mathcal{E}_{tr}}\mathcal{R}^e(w,u)\\ \text{s.t.} w\in\underset{w^e}{\argmin}\mathcal{R}^e(w^e,u),\forall e\in\mathcal{E}_{tr}\tag{2}

  式 (2) 中定义的 IRM 试图学习 hu()h_u(\cdot) 的特征表示,该特征表示可以导出一个分类器 gw()g_w(\cdot),该分类器同时对所有训练环境都是最优的。要做到这一点,hu()h_u(\cdot) 应该摒弃虚假特征。

IRMv1 由于式 (2) 是一个具有挑战性的双层优化问题,原作者提出 IRMv1 来近似式 (2) 的解。IRMv1如下所示 :

minw,ueEtrRe(w,u)+λwRe(w,u)2(3)\min_{w,u}\sum_{e\in\mathcal{E}_{tr}}\mathcal{R}^e(w,u)+\lambda\lVert\nabla_w\mathcal{R}^e(w,u)\rVert^2 \tag{3}

  除了 IRMv1 之外,最近还出现了其他几个 IRM 的优秀变体 : InvRat 通过一个最小-最大程序来估计惩罚;REx 使用不同环境下损失的方差作为惩罚。由于篇幅所限,详细描述请读者参考原文。

2.2 过拟合的陷阱

  在本节中,我们从理论上分析了过拟合发生时 IRM 的行为。我们的结果表明,当模型记忆训练数据时,式 (2) 中 IRM 的不变约束是成立的。那么 IRM 将不再对学习不变特征提供任何保证。我们的分析基于以下假设 :

假设 1 (有限样本量) 训练环境和样本的数量是有限的 : Etr<|\mathcal{E}_{tr}|<\inftyDe=ne<,eEtr|\mathcal{D}_e|=n^e<\infty,\forall e\in\mathcal{E}_{tr}

假设 2 (容量充足) 参数w和u有足够的能力拟合训练数据 : 存在 wˉ\bar{w}uˉ\bar{u},使 eEtr,Re(wˉ,uˉ)=0\forall e\in \mathcal{E}_{tr},\mathcal{R}^e(\bar{w},\bar{u})=0

  假设 1 在实践中是成立的,因为我们只能访问来自几个环境的有限的训练数据。假设 2 也与最近关于过参数化神经网络的研究结果一致;例如,有论文表明,即使存在强正则化,大型神经网络也可以记住所有的训练数据。

  然后我们继续定义过拟合区域。

定义1 (过拟合区域) 过拟合区域 Ω\Omega,为满足假设 2 的 wˉ\bar{w}uˉ\bar{u} 的集合 :

Ω:={wˉ,uˉRe(wˉ,uˉ)=0,eEtr}\Omega:=\left\{\bar{w},\bar{u}|\mathcal{R}^e(\bar{w},\bar{u})=0,\forall e\in\mathcal{E}_{tr}\right\}

命题 1 (通用 IRM 的失效) 在假设 1 和 2 下,IRM 在 Ω\Omega 退化为 ERM。此外,Ω\Omega 中的任何元素都是式 (2) 中定义的 IRM 的解。

证明 首先证明了 Ω\Omega 中的任意元素都是 (2) 中定义的 IRM 的静止点。设 (wˉ,uˉ)(\bar{w},\bar{u})Ω\Omega 中的任意元素。根据定义 1,我们有 :

Re(wˉ,uˉ)=0,eEtr(12)\mathcal{R}^e(\bar{w},\bar{u})=0,\forall e\in\mathcal{E}_{tr} \tag{12}

  注意到

Re(w,u)0,eEtr,w,u(13)\mathcal{R}^e(w,u)\ge0,\forall e\in\mathcal{E}_{tr},w,u \tag{13}

  所以它遵循 :

Re(w,uˉ)0,eEtr,w\mathcal{R}^e(w,\bar{u})\ge0,\forall e\in\mathcal{E}_{tr},w

  则

wˉarg minRe(w,uˉ)0,eEtr\bar{w}\in\argmin\mathcal{R}^e(w,\bar{u})\ge0,\forall e\in\mathcal{E}_{tr}

  所以 (wˉ,uˉ)(\bar{w},\bar{u}) 对 (2) 中的约束进行分层。同时,(12) 和 (13) 已经足以证明 (wˉ,uˉ)(\bar{w},\bar{u}) 是目标的最小值。然后我们得出结论 (wˉ,uˉ)(\bar{w},\bar{u}) 是 (2) 中定义的 IRM 的驻点解。

  IRM 在 Ω\Omega 中退化为 ERM 的第一个论点直接来自上面的证明。

  假设存在另一个与 (2) 中的约束匹配的 (w,u)(w',u') 集合

Ω:={(w,u)warg minwRe(w,u),eEtr,(w,u)∉Ω}\Omega':=\left\{(w',u')|w\in\underset{w}{\argmin}\mathcal{R}^e(w,u'),\forall e\in\mathcal{E}_{tr},(w',u')\not\in\Omega\right\}

  请注意,为了简单起见,Ω\Omega 中的元素被排除在 Ω\Omega' 之外。所以 ΩΩ=\Omega\cap\Omega'=\empty 并且 ΩΩ\Omega\cup\Omega' 包括所有满足 (2) 中约束的 (w,u)(w, u)。由 (12) 和 (13) 我们知道

(w,u)Ω,(wˉ,uˉ)Ω,eEtr,Re(w,u)>Re(wˉ,uˉ)=0\forall (w',u')\in\Omega',(\bar{w},\bar{u})\in\Omega,\exists e\in\mathcal{E}_{tr},\\ \mathcal{R}^e(w',u')>\mathcal{R}^e(\bar{w},\bar{u})=0

  它遵循 :

(w,u)Ω,(wˉ,uˉ)ΩeRe(w,u)>eRe(wˉ,uˉ)\forall (w',u')\in\Omega',(\bar{w},\bar{u})\in\Omega\\ \sum_e\mathcal{R}^e(w',u')>\sum_e\mathcal{R}^e(\bar{w},\bar{u})

  这意味着 Ω\Omega 中的任何元素 (wˉ,uˉ)(\bar{w},\bar{u}) 的目标都小于 Ω\Omega' 中的任何元素 (w,u)(w',u')。这意味着 (2) 中的 IRM 不会选择 Ω\Omega' 中的任何元素。

  我们已经知道 ΩΩ\Omega\cup\Omega'(w,u)(w, u) 匹配约束的集合。因此 IRM 将在 Ω\Omega 中选择任意元素。值得注意的是,根据假设 2 和定义 1,我们在 Ω\Omega 中没有施加任何不变约束。然后我们在 Ω\Omega 中证明了我们的第一个论点,即 IRM 退化为 ERM。\Box

  命题 1 表明,任何过拟合训练数据的模型都是式 (2) 中 IRM 的解,无论该模型是否使用了伪特征。这样的模型在不可见的测试环境中可能表现得非常糟糕。不幸的是,这种过拟合现象在深度神经网络中很常见。

与现有理论的联系 之前的一些工作分析了 IRM 的一些理论性质。有的表明 IRM 的样本复杂度比 ERM 差。还有工作显示了非线性函数 IRM 的难度。与之前的这些工作相比,我们的理论具有以下优点 :

  • 我们的理论直接作用于IRM的定义,它适用于各种IRM的变体。相反,这些工作只关注IRM的一个变体 IRMv1。他们的理论是否适用于其他变体仍有待探索。
  • 一些工作限制了一些特殊非线性模型的讨论,其中函数值在高密度区域边界上跳跃。很难验证这种情况是否足够普遍,可以覆盖在实践中使用的模型,即神经网络。相比之下,我们的理论建立在非常温和和可验证的假设之上。

  下面的推论1表明,IRMv1 在过拟合情况下也很难学习不变特征。

假设 3 (可微性) Re(w,u)\mathcal{R}^e(w,u) 对 w、u是可微的

推论 1 (IRMv1失效) 在已有假设 1、2、3 的情况下,(wˉ,uˉ)Ω\forall (\bar{w},\bar{u})\in\Omega 下,有以下等式 :

(wˉ,uˉ)arg minw,ueEtrRe(w,u)+λwRe(w,u)2(\bar{w},\bar{u})\in\underset{w,u}{\argmin}\sum_{e\in\mathcal{E}_{tr}}\mathcal{R}^e(w,u)+\lambda\lVert\nabla_w\mathcal{R}^e(w,u)\rVert^2

证明 根据定义

eEtrRe(w,u)+λwRe(w,u)20\sum_{e\in\mathcal{E}_{tr}}\mathcal{R}^e(w,u)+\lambda\lVert\nabla_w\mathcal{R}^e(w,u)\rVert^2\ge0

  因此,如果 wˉ,uˉ\bar{w},\bar{u} 满足

eEtrRe(wˉ,uˉ)+λwRe(wˉ,uˉ)2=0\sum_{e\in\mathcal{E}_{tr}}\mathcal{R}^e(\bar{w},\bar{u})+\lambda\lVert\nabla_w\mathcal{R}^e(\bar{w},\bar{u})\rVert^2=0

  则有

(wˉ,uˉ)arg minw,ueEtrRe(w,u)+λwRe(w,u)2(\bar{w},\bar{u})\in\underset{w,u}{\argmin}\sum_{e\in\mathcal{E}_{tr}}\mathcal{R}^e(w,u)+\lambda\lVert\nabla_w\mathcal{R}^e(w,u)\rVert^2

  根据假设 1 和 2,我们有

eEtrRe(wˉ,uˉ)=0\sum_{e\in\mathcal{E}_{tr}}\mathcal{R}^e(\bar{w},\bar{u})=0

  根据假设 3,wRe(wˉ,uˉ)\nabla_w\mathcal{R}^e(\bar{w},\bar{u}) 存在。

  如果 wRe(wˉ,uˉ)=v0\nabla_w\mathcal{R}^e(\bar{w},\bar{u})=v\ne0,我们可以得到

limϵ0Re(wˉ+ϵv,uˉ)Re(wˉ,uˉ)ϵ=v2\lim_{\epsilon\to0}\frac{\mathcal{R}^e(\bar{w}+\epsilon v,\bar{u})-\mathcal{R}^e(\bar{w},\bar{u})}{\epsilon}=\lVert v\rVert^2

  那么对于 v22\frac{\lVert v\rVert^2}{2},可以得到存在 σ>0\sigma>0,使得对于所有 t(σ,σ)t\in(-\sigma,\sigma),有

Re(wˉ+tv,uˉ)Re(wˉ,uˉ)t>v22\frac{\mathcal{R}^e(\bar{w}+t v,\bar{u})-\mathcal{R}^e(\bar{w},\bar{u})}{t}>\frac{\lVert v\rVert^2}{2}

  选择 t=σ2t=-\frac{\sigma}{2}

Re(wˉσv2,uˉ)<Re(wˉ,uˉ)σv24=σv24<0\mathcal{R}^e(\bar{w}-\frac{\sigma v}{2},\bar{u})<\mathcal{R}^e(\bar{w},\bar{u})-\frac{\sigma\lVert v\rVert^2}{4}=-\frac{\sigma\lVert v\rVert^2}{4}<0

  这与 Re\mathcal{R}^e 的定义相矛盾。

  因此,wRe(wˉ,uˉ)=0\lVert\nabla_w\mathcal{R}^e(\bar{w},\bar{u})\rVert=0.

(wˉ,uˉ)arg minw,ueEtrRe(w,u)+λwRe(w,u)2(\bar{w},\bar{u})\in\underset{w,u}{\argmin}\sum_{e\in\mathcal{E}_{tr}}\mathcal{R}^e(w,u)+\lambda\lVert\nabla_w\mathcal{R}^e(w,u)\rVert^2 \\ \Box

  推论 1 表明,任何经验损失为零的模型也是 IRMv1 的最优解。值得注意的是,这个模型仍然可以依赖虚假的特征。推论 1 的证明是命题 1 的直接结果。我们也可以证明类似的失败案例 InvRat、REx。

2.3 相关的经验证据

图 1.在 CMNIST 上使用 3 层不同隐藏维度的 MLP 训练 ERM 的图示。IRM (REx) 的惩罚是衡量的,但不适用于目标。随着 ERM 训练的进行,IRM 惩罚衰减到零,而非不变指标显示模型中存在大量虚假特征。随着更大的模型和更少的训练数据,IRM 惩罚消失得更快。

  如上所述,如果模型记住了数据,IRM 将失败。为了看到这一点,我们可视化 ERM 的训练过程。计算 IRM 的惩罚,但不应用于训练目标。同时,我们还通过非不变指标估计模型中包含的虚假特征。非不变指标定义为预测容易受到虚假特征变化影响的测试样本的百分比。零非不变指标意味着模型完全忽略了虚假特征,而较大的非不变指标代表更多虚假特征的使用。图 1 显示了我们在 CMNIST 上使用 3 层 MLP 训练 ERM 模型时的 IRM 惩罚和非不变指标。一开始,随机初始化的网络不包含虚假特征,因此非不变指标和 IRM 惩罚处于低水平。随着训练的进行,模型会快速学习虚假特征,增加非不变指标和 IRM 惩罚。然后,随着模型在虚假特征之后学习不变特征,非不变指标下降并稳定到最后。随着模型开始记忆数据,IRM 惩罚在最后消失,但非不变指标保持在 60%-70%。换句话说,该模型仍然严重依赖虚假特征,而 IRM 惩罚无法检测到它。图 1 进一步显示,随着模型容量的增加或数据集大小的减小,IRM 惩罚消失得更快。经验现象与我们在之前 2.2 中的理论结果一致:IRM 在过度拟合时失败。

3 贝叶斯不变风险最小化

  在之前的 2.2 中,已经证明过拟合对 IRM 是有害的。贝叶斯推理是一种众所周知的减轻过拟合的方法,并且被证明可以在模型错误指定的情况下实现最佳样本复杂度。在本节中,我们通过结合贝叶斯原理提出了贝叶斯不变风险最小化 (BIRM),这是 IRM 的一种新变体。

3.1 启发和公式

图 2.学习不变和非不变特征的模型示意图。节点 u 表示特征编码器 hu()h_u(\cdot)。节点 wew^e 为给定 Due\mathcal{D}_u^e 的分类器参数的后置,为 hu()h_u(\cdot) 变换的环境 e 的数据分布。节点 w 表示混合环境数据的后验,Du\mathcal{D}_u。(左)当 hu()h_u(\cdot) 对非不变特征进行编码时,每个环境都有唯一的后验分类器参数,该参数依赖于环境指标 e;(右)当 hu()h_u(\cdot) 编码不变特征时,我们的后验与 w 几乎相同,不再依赖于环境指标 e。

  为了启发我们的方法,我们在图 2 中为不变学习问题建立了一个图。节点 u 表示特征提取器 hu()h_u(\cdot)。设 Due\mathcal{D}^e_u 为环境 e 经提取器 hu()h_u(\cdot) 变换后的数据 : Due={hu(xie),ye}i=1n\mathcal{D}^e_u\overset{\triangle}{=}\left\{h_u(x_i^e),y^e\right\}_{i=1}^n。设 Du=e=1EtrDue\mathcal{D}_u\overset{\triangle}{=}\bigcup_{e=1}^{\mathcal{E}_{tr}}\mathcal{D}^e_u 表示来自混合训练环境的数据收集。图 2 中的节点 wew^e 和 w 分别代表 p(weDue)p(w^e|\mathcal{D}^e_u)p(wDu)p(w|\mathcal{D}_u),它们是给出特征表示的分类器的后置。我们在图 2 中添加了斑马条纹来区分 w 和 u 和 wew^e,因为 w 和 u 不依赖于某个环境指数。按照典型平均场变分推理的常见做法,我们假设所有 wew^e 和 w 的先验 p0(w)p_0(w) 相同。

  如果特征提取器 hu()h_u(\cdot) 学习非不变的特征,Due\mathcal{D}_u^e 的数据分布随 e不同而不同。因此后验 p(weDue)p(w^e|\mathcal{D}_u^e) 在不同的环境中是不同的。然后 wew^e 对 e 存在依赖关系,如图 2 (左) 所示。我们进一步有 (weDue)(wDu)(w^e|\mathcal{D}_u^e)\ne (w|\mathcal{D}_u),因为 Due\mathcal{D}_u^e 的数据分布不同于 Du\mathcal{D}_u 的数据分布。在这种情况下,由于 Due\mathcal{D}_u^{e'} 可以是任意的,因此模型不能推广到不可见的环境 e'。

  不变学习的目标是获得对不变特征进行编码的提取函数。在不变表示下,Due\mathcal{D}_u^e 的数据分布对所有 e 都是相同的。因此,对于每个环境,后验 p(weDue)p(w^e|\mathcal{D}_u^e) 应该是接近的,它们都进一步等价于共享后验 : p(weDue)p(wDu)p(w^e|\mathcal{D}_u^e)\approx p(w|\mathcal{D}_u)。图 2 (右) 通过删除节点 wew^e 对节点 e 的依赖来说明这种情况。

  基于上述直觉,我们提出了贝叶斯不变风险最小化 (BIRM) :

maxueEqu(w)[lnp(Dew,u)]+λ(Equ(w)[lnp(Dew,u)]Eque(we)[lnp(Dewe,u)])(4)\max_u\sum_e\mathbb{E}_{q_u(w)}\left[\ln p(\mathcal{D}^e|w,u)\right]+\lambda\left(\mathbb{E}_{q_u(w)}\left[\ln p(\mathcal{D}^e|w,u)\right]-\mathbb{E}_{q_u^e(w^e)}\left[\ln p(\mathcal{D}^e|w^e,u)\right]\right) \tag{4}

  其实在这里可以换一种写法更方便地进行理解,即通过后验使用负对数似然对于loss和惩罚项进行改写

minueEp(wDu)[lnp(Dew,u)]+λ(Ep(wDu)[lnp(Dew,u)]Ep(weDue)[lnp(Dewe,u)])\color{red}\min_u\sum_e\mathbb{E}_{p(w|\mathcal{D}_u)}\left[-\ln p(\mathcal{D}^e|w,u)\right]+\lambda\left(\mathbb{E}_{p(w|\mathcal{D}_u)}\left[-\ln p(\mathcal{D}^e|w,u)\right]-\mathbb{E}_{p(w^e|\mathcal{D}_u^e)}\left[-\ln p(\mathcal{D}^e|w^e,u)\right]\right)

  其中,qu(w)p(wDu)q_u(w)\approx p(w|\mathcal{D}_u)que(we)p(weDue)q^e_u(w^e)\approx p(w^e|\mathcal{D}_u^e),是给定 Du\mathcal{D}_uDue\mathcal{D}_u^e 的分类器的近似后验分布。有以下转换

Eque(we)[lnp(Dewe,u)]=lnp(Dewe,u)que(we)dwe,Equ(w)[lnp(Dew,u)]=lnp(Dew,u)qu(w)dw\mathbb{E}_{q_u^e(w^e)}\left[\ln p(\mathcal{D}^e|w^e,u)\right]=\int\ln p(\mathcal{D}^e|w^e,u)q_u^e(w^e){\rm d}w^e,\\ \mathbb{E}_{q_u(w)}\left[\ln p(\mathcal{D}^e|w,u)\right]=\int\ln p(\mathcal{D}^e|w,u)q_u(w){\rm d}w

  分别是来自环境 e 的数据的 que(we)q_u^e(w^e)qu(w)q_u(w) 的预期对数似然。

  请注意,近似后验 qu(w)q_u(w)que(we)q_u^e(w^e) 显式依赖于 u。式 (4) 中的第一项通过优化 u 来最大化 w 的共享后验 qu(w)q_u(w) 的预期对数似然。它鼓励你保留尽可能多的信息,以使 qu(w)q_u(w) 适应数据分布。式 (4) 的第二项要求你学习不变的特征。如果 hu()h_u(\cdot) 对非不变特征进行编码,则变换后的分布 Due\mathcal{D}_u^e 因环境而异。回想一下,que(we)q_u^e(w^e) 是给定 Due\mathcal{D}_u^e 的后验,qu(w)q_u(w) 是给定 Du\mathcal{D}_u 的后验。所以 que(we)q_u^e(w^e)Due\mathcal{D}_u^e 上可以实现比 qu(w)q_u(w) 更高的似然性。然后我们施加惩罚,要求 hu()h_u(\cdot) 丢弃非不变特征。

  请注意,式 (2) 中 IRM 的普通定义是基于 w 的单点估计,当数据不足时可能高度不稳定。 BIRM 不是点估计,而是由后验分布直接诱导的,不太容易过度拟合。

变分推理 在大型模型中,后验分布的估计并非易事。在这里,我们通过变分推理使用 que(we)q_u^e(w^e)qu(w)q_u(w) 来近似它们。给定一个分布族 Q\mathcal{Q},我们通过找到最大化证据下界 (ELBO) 的最优 qQq\in\mathcal{Q} 来近似后验分布。估计 que(we)q_u^e(w^e) 的目标函数是 :

que(we)=arg maxqQEq[lnp(Dewe,u)KL(qp0(w))](5)q_u^e(w^e)=\underset{q'\in\mathcal{Q}}{\argmax}\mathbb{E}_{q'}\left[\ln p(\mathcal{D}^e|we,u)-\text{KL}(q'||p_0(w))\right] \tag{5}
KL(qp0(w))=iq(i)lnq(i)p0(i)=q(w)lnq(w)p0(w)dw\text{KL}(q'||p_0(w))=\sum_iq'(i)\ln\frac{q'(i)}{p_0(i)}=\int q'(w)\ln\frac{q'(w)}{p_0(w)}{\rm d}w

  其中第一项是最大化后验分布的预期对数似然,第二项旨在保持 qq' 接近先验 p0(w)p_0(w)。类似地,获得 qu(w)q_u(w) 的目标函数是 :

qu(w)=arg maxqQeEq[lnp(Dew,u)KL(qp0(w))](6)q_u(w)=\underset{q'\in\mathcal{Q}}{\argmax}\sum_e\mathbb{E}_{q'}\left[\ln p(\mathcal{D}^e|w,u)-\text{KL}(q'||p_0(w))\right] \tag{6}

  根据变分推理 (平均场近似) 中的常用做法,我们选择因式分解高斯分布,即 Q={N(μ,Σ):μ=[w1,...,wd],Σ=diag(σ1,...,σd)}\mathcal{Q}=\{\mathcal{N}(\mu,\Sigma):\mu=[w_1,...,w_d]^⊤,\Sigma=\text{diag}(\sigma_1,...,\sigma_d)\},其中 d 为分类器参数 w 的维数。先验 p0(w)p_0(w) 设为均值为零的高斯分布 : N(0,σI)\mathcal{N}(0,\sigma I)。式 (5) 和式 (6) 的估计后置记为 : qu(w)=N(μ~,Σ~),que(we)=N(μ~e,Σ~e)q_u(w)=\mathcal{N}(\tilde{\mu},\tilde{\Sigma}),q^e_u(w^e)=\mathcal{N}(\tilde{\mu}^e,\tilde{\Sigma}^e)

  在变分推理的帮助下,我们最终能够优化式 (4)。具体来说,训练过程将在求解式 (5)、(6) 和 (4) 之间迭代。

  下面的命题描述了当我们学习一个不变 u 时 que(we)q^e_u(w^e)qu(w)q_u(w) 的行为。

命题 2 如果 hu()h_u(\cdot) 不提取虚假的特征,随着 nen_e\rightarrow\inftyque(we)Dqu(w)q_u^e(w^e)\overset{\mathcal{D}}{\rightarrow}q_u(w)

Equ(w)[lnp(Dew,u)]Eque(we)[lnp(Dew,u)]0\mathbb{E}_{q_u(w)}\left[\ln p(\mathcal{D}^e|w,u)\right]-\mathbb{E}_{q_u^e(w^e)}\left[\ln p(\mathcal{D}^e|w,u)\right]\rightarrow0

证明 根据定义,当 huh_u 提取不变特征时,数据在各个环境中的分布是相同的。由于模型是明确指定的,假设数据生成分布满足

p(yx)=p(yhu(x),w).p(y|x)=p(y|h_u(x),w_∗).

  然后根据 Bernstein–von Mises 定理,当 nen_e\rightarrow\infty

weN(w,Fe1/ne)w^e\rightarrow\mathcal{N}(w_∗,F_e^{-1}/n_e)

  其中 ww_* 是最优解,FeF_e 是 Fisher 信息矩阵。

  注意,当 p(xe)p(x^e) 是跨环境的不变量时,Fisher 信息矩阵 Fe=FF_e=F 是一个常数矩阵。

  因此,对于任意 e1,e2Etre_1,e_2\in|\mathcal{E}_{tr}|,我们有

we1Dwe2w^{e_1}\overset{\mathcal{D}}{\rightarrow}w^{e_2}

  因为 w 使用所有 D\mathcal{D} 的子集 (只有 nen_e 个样本),我们有

wN(w,F1/n)wDwew\rightarrow\mathcal{N}(w_∗,F^{-1}/n)\\ w\overset{\mathcal{D}}{\rightarrow}w^{e}\\ \Box

  命题 2 表明,如果 hu()h_u(\cdot) 不提取虚假特征,则惩罚为零,BIRM 只考虑模型的经验风险。否则,将诱发惩罚,鼓励 hu()h_u(\cdot) 放弃虚假特征。

3.2 减少方差的重新参数化

  注意到我们使用来自 qu(w)q_u(w)que(we)q_u^e(w^e) 的蒙特卡洛样本来估计式 (4) 中的惩罚项。一个常见的做法是通过重新参数化技巧来绘制样本 :

w=μ~+ϵΣ~,we=μ~e+ϵeΣ~e,e(7)w=\tilde{\mu}+\epsilon\tilde{\Sigma},w^e=\tilde{\mu}^e+\epsilon^e\tilde{\Sigma}^e,\forall e \tag{7}

  其中 ϵ,ϵeN(0,I),ϵϵe,eEtr\epsilon,\epsilon^e\sim\mathcal{N} (0,I),\epsilon\bot\epsilon^e,\forall e\in\mathcal{E}_{tr}。然而,在命题 2 中,这两个期望项很接近,但传统的重参数化方法在训练过程中可能会产生较大的方差。考虑我们收集 K 个样本来估计期望,从 qu(w)q_u(w) 取出 wu,1,...,wu,Kw_{u,1},...,w_{u,K} 和 从 que(we)q^e_u(w^e) 取出 wu,1e,...,wu,Kew^e_{u,1},...,w^e_{u,K};估计惩罚项的计算方法如下 :

JK(u)=1Ki=1Ke[lnp(Dewu,ie,u)lnp(Dewu,i,u)](8)J_K(u)=\frac{1}{K}\sum_{i=1}^K\sum_e\left[\ln p(\mathcal{D}^e|w_{u,i}^e,u)-\ln p(\mathcal{D}^e|w_{u,i},u)\right] \tag{8}

  JK(u)J_K(u) 的方差特征如下 :

命题 3 通过式 (7) 中的常规重参数化,当 nen_e\rightarrow\infty 时,V[JK]cK\mathbb{V}[J_K]\rightarrow \frac{c}{K},其中 c 是一个常数,V[JK]\mathbb{V}[J_K]JKJ_K 的方差。

证明 类似于命题 2,由 Bernstein–von Mises 定理,当 nen_e\rightarrow\infty 时,

wN(w,F1/ne)w\rightarrow\mathcal{N}(w_∗,F^{-1}/n_e)

  其中 ww_* 是最优解,F 是 Fisher 信息矩阵。因此

w=Op(1/ne)lnp(yx,w,u)=Op(1/ne)lnp(Dew,u)=i=1nelnp(yx,w,u)=Op(1)w=O_p(1/n_e)\\ \ln p(y|x,w,u)=O_p(1/n_e)\\ \ln p(\mathcal{D}^e|w,u)=\sum_{i=1}^{n_e}\ln p(y|x,w,u)=O_p(1)\\

  因此

JK=Op(1/K)J_K=O_p(1/K)\\ \Box

  命题 3 表明,在给定 K 的情况下,估计惩罚 JKJ_K 的方差是一个常数。在这种情况下,我们需要一个较大的 K 来使训练算法稳定。此外,在训练接近结束时,惩罚的期望接近于零 (根据命题 2),这意味着方差可以主导惩罚。

  为了解决这个问题,我们提出方差减少重参数化技巧。我们的主要直觉是对 wwwew^e 使用共享的辅助噪声变量 ϵs\epsilon_s,使抽样的随机性在相减后可以相互抵消。具体来说,我们采样 ϵsN(0,I)\epsilon_s\sim\mathcal{N}(0,I),并使用它来参数化 wuw_uwuew^e_u :

w=μ~+ϵsΣ~,we=μ~e+ϵsΣ~e,e(9)w=\tilde{\mu}+\epsilon_s\tilde{\Sigma},w^e=\tilde{\mu}^e+\epsilon_s\tilde{\Sigma}^e,\forall e \tag{9}

  我们将式 (9) 中的重参数化命名为方差减少的重参数化技巧。下面的命题说明了这种方法的优点。

命题 4 通过式 (9) 中的方差减少重参数化,当 nen_e\rightarrow\infty 时,V[JK]0\mathbb{V}[J_K]\rightarrow0,其中 V[JK]\mathbb{V}[J_K]JKJ_K 的方差

证明 通过命题 2 有

weDww_e\overset{\mathcal{D}}{\rightarrow}w

  通过我们的参数化.

wea.s.ww_e\overset{a.s.}{\rightarrow}w

  对于任意 x、y,

lnp(yx,w,u)lnp(yx,we,u)=op(1)JK=op(1/K)=0\ln p(y|x,w,u)-\ln p(y|x,w_e,u)=o_p(1)\\ J-K=o_p(1/K)=0\\ \Box

  将命题 4 与命题 3 进行比较,我们可以看到,方差减少的重参数化可以获得比传统方法更小的方差。

3.3 快速适应

  虽然贝叶斯后验法的引入是直接合理的,但是要在每一步的不同环境下找到式 (5) 的 ELBO 解是非常耗费计算的。我们进一步借鉴 MAML 的快速适应思想,以更有效的方式估计 que(We)q^e_u(W^e)。命题 2 表明,随着训练过程的进行,当 hu()h_u(\cdot) 提取的伪特征较少时,que(we)q^e_u(w^e) 会更接近 qu(w)q_u(w)。这使得 que(w)q^e_u(w) 的快速估计可以如下所示 :

que(we)=N(μμEqu(w)lnp(Dew,u),Σ)(10)q^e_u(w^e)=\mathcal{N}(\mu-\nabla_\mu\mathbb{E}_{q_u(w)}\ln p(\mathcal{D}^e|w,u),\Sigma)\tag{10}

  其中 qu(w)=N(μ,Σ)q_u(w)=\mathcal{N}(\mu,\Sigma)。在这里,que(we)q^e_u(w^e) 的均值 μe\mu e 在来自环境 e 的数据上近似为 μ\mu 的梯度下降一步。快速适应的可行性是基于命题 2 所示 que(we)q^e_u(w^e)qu(w)q_u(w) 的接近性,这使得单步估计是可行的。通过这种方法,我们不需要每次都从头估计 que(we)q^e_u(w^e)

  现有的工作提出了带不确定性的域不变学习 (DILU),它还估计了分类器的分布,以获得更好的 OOD 性能。具体来说,他们从每个环境中随机抽取具有相同标签的样本,并匹配样本的输出。然而,现有的 IRM 工作通常认为这是一个极具挑战性的任务,其中标签是有噪声的。由于标签噪声的存在,DILU 可以强制对齐不同类别样本的预测,这将阻碍因果特征的学习。

3.4 完整算法与代码实现

BIRM 的完整算法总结如下:

  而在上面的 3.2 的部分中,已经将惩罚项化简为 JKJ_KKK 即为采样次数 sampleN,通过一定的技巧将其转化为如下的重参数形式。

class EBD(nn.Module):
    def __init__(self, env_num):
      super(EBD, self).__init__()
      self.embedings = nn.Embedding(env_num, 1)
      self.env_num = env_num
      self.re_init()

    def re_init(self):
      self.embedings.weight.data.fill_(1.)

    def forward(self, e):
      return self.embedings(e.long())
    
    # 随机化嵌入向量
    def re_init_with_noise(self, noise):
        rd = torch.normal(
            torch.Tensor([1.0] * self.env_num),
            torch.Tensor([noise] * self.env_num))
        self.embedings.weight.data = rd.view(-1, 1).cuda()

    sampleN = 10
    train_penalty = 0
    train_logits = mlp(train_x)
    for i in range(sampleN):  
        ebd.re_init_with_noise(flags.prior_sd_coef/flags.data_num)
        train_logits_w = ebd(train_g).view(-1, 1)*train_logits
        train_nll = mean_nll(train_logits_w, train_y)
        grad = autograd.grad(
            train_nll * flags.envs_num, ebd.parameters(),
            create_graph=True)[0]
        train_penalty +=  1/sampleN * torch.mean(grad**2)

4 实验

  暂时略。

5 个人感想

  几天看下来,感觉这个方法整体还是从非常数学的角度展开的推导,非常佩服作者扎实的数学功底,所以在整理时也将作者的这些公式证明从附录中一并整理了。现在自己的数学水平还是有很大差距,多读论文,必要时多学习相关课程补充基本功。