Invariant Risk Minimization 论文阅读

1,020 阅读22分钟

0 论文信息

1 背景介绍

  本文引入了不变风险最小化的方法 (IRM),作为一种学习范例,用于估计多个分布之间的不变相关性。为了实现这一目标,不变风险最小化学习了一种数据的表达,使得在这种数据表达之上的最优分类器可以匹配所有的训练分布。通过理论和实验,我们展示了不变风险最小化学习到的不变性如何与控制数据的因果结构相关联,并实现了分布外的泛化。

2 泛化的多面性

  我们考虑在多个训练环境 eEtre\in\mathcal{E}_{tr} 下收集的数据集 De:={(xie,yie)}i=1neD_e:= \left\{(x^e_i,y^e_i )\right\}^{n_e}_{i=1}。这些环境描述了在不同条件下测量的同一对随机变量。来自环境 e 的数据集 DeD_e 包含根据某个概率分布 P(Xe,Ye)P(X^e,Y^e) 相同且独立分布的示例。然后,我们的目标是使用这些多个数据集来学习预测器 Yf(X)Y\approx f(X),它执行很好地泛化了大量不可见但相关的环境 EtrEall\mathcal{E}_{tr}\subset\mathcal{E}_{all}。也就是说,我们希望最小化

ROOD(f)=maxeEallRe(f)R^{\text{OOD}}(f)=\max_{e\in\mathcal{E}_{all}}R^e(f)

  其中 Re(f):=EXe,Ye[l(f(Xe),Ye)]R^e(f):=\mathbb{E}_{X^e,Y^e}[\mathcal{l}(f(X^e),Y^e)] 是环境 e 下的风险。在这里,所有环境的集合 Eall\mathcal{E}_{all} 包含与我们的变量系统有关的所有可能的实验条件,包括可观察的和假设的。

示例 1 考虑结构方程模型如下

X1Gaussian(0,σ2)YX1+Gaussian(0,σ2)X2Y+Gaussian(0,1).X_1\leftarrow \text{Gaussian}(0,\sigma^2)\\Y\leftarrow X_1+\text{Gaussian}(0,\sigma^2)\\X_2\leftarrow Y+\text{Gaussian}(0, 1).

  然后,要使用最小二乘预测器 Y^e=X1eα^1+X2eα^2\hat{Y}^e=X^e_1\hat{\alpha}_1+X^e_2\hat{\alpha}_2(X1,X2)(X_1,X_2) 预测 Y,即考虑

minYX1α^1X2α^22\min\sum\lVert Y-X_1\hat{\alpha}_1-X_2\hat{\alpha}_2 \rVert^2
  • X1eX^e_1 回归,得到 α^1=1\hat{\alpha}_1=1α^2=0\hat{\alpha}_2=0
  • X2eX^e_2 回归,得到 α^1=0\hat{\alpha}_1=0α^2=2σ(e)22σ(e)2+1\hat{\alpha}_2=\frac{2\sigma(e)^2}{2\sigma(e)^2+1}
  • (X1e,X2e)(X^e_1,X^e_2) 回归,得到 α^1=1σ(e)2+1\hat{\alpha}_1=\frac{1}{\sigma(e)^2+1}α^2=σ(e)2σ(e)2+1\hat{\alpha}_2=\frac{\sigma(e)^2}{\sigma(e)^2+1}

  使用 X1X_1 回归是我们的第一个不变相关性,也即该回归预测效果不依赖于环境 e。相反,第二个和第三个回归的预测效果依赖环境的变化。这些变化的 (虚假的) 相关性不能很好的推广到测试环境中。但并不是所有的不变性都是我们所关心的,比如从空集特征到 Y 的回归是不变的,但却没有预测效果。

  Y^=1×X1+0×X2\hat{Y}=1\times X_1+0\times X_2 是唯一的在所有环境 Eall\mathcal{E}_{all} 中不变的预测规则。进一步,该预测也是跨环境的对目标变量取值的因果解释。换句话说,这对目标变量随输入的变化提供了一种准确的描述。这是令人信服的,因为不变性是一个可检验的量,我们可以通过它发现因果关系。我们将在第4节详细讨论不变性和因果性的关系。但是首先,如何学习得到不变性,因果的回归?我们先回顾现有技术的一些局限性。

  首先,我们可以直接使用所有的训练数据进行学习,使用所有特征来最小化训练误差。这就是传统的 Empirical Risk Minimization (ERM) 方法。在这个例子中,如果训练环境具有很大的 σ2(e)\sigma^2(e),那么 ERM 方法将赋予 X2X_2 一个很大的正系数,这就远离了不变性。

  第二,我们可以最小化 Rrob(f)=maxeEtrRe(f)reR^{rob}(f)=\max_{e\in\mathcal{E}_{tr}}R^e(f) −r_e,一种鲁棒性的学习策略,其中 rer_e 是一个环境基准。设置这些基准为0就表明最小化在不同环境中的最大误差。选择这些基准是为了防止对嘈杂的环境为主导的优化。例如,我们可以选择 re=V[Ye]r_e=\mathbb{V}[Y^e],来最小化不同环境间的最大解释方差。虽然很有价值,但这就等同于鲁棒性的学习会最小化环境训练错误加权平均值。即选择最优的 λe\lambda_e,使得 Rrob=eEtrλeRe(f)R^{rob}=\sum_{e\in \mathcal{E}_{tr}}\lambda_e R^e(f) 最小化。但是对于混合训练环境具有很大的 σ2(e)\sigma ^2(e),会给 X2X_2 赋予较大参数,但是测试环境可能具有较小的 σ2(e)\sigma ^2(e)

命题 2 在 KKT 可微性和限定条件下,λe0\exists\lambda_e\ge0,使得 RrobR^{rob}的最小值是 eEtrλeRe(f)\sum_{e\in\mathcal{E}_{tr}}\lambda_e R^e(f) 的一阶定常点。

  第三,我们可以采取一种自适应策略来估计在所有环境中具有相同分布的数据表达 Φ(X1,X2)\Phi \left(X_{1}, X_{2}\right)。这对于上述例子是不可能的,因为 X1X_1 的分布在不同的环境中是不同的。这就说明了为什么技术匹配的特征分布优势会增加不变性的错误形式。

  第四, 我们可以紧跟这种不变性因果预测技术。这些变量的子集用于回归每一个环境,在所有环境中都会产生相同的回归残差。匹配残差分布不适用于上述例子,因为Y的噪声随环境发生变化。

  总之,对于这个简单的例子都很难找到不变的预测。为了解决这个问题,我们提出了 IRM 方法,这是一种学习范式,可以提取跨多个环境的非线性不变预测变量,从而实现 OOD 泛化。

3 不变风险最小化算法

  用统计学的话讲,我们的目标就是学习不同训练环境中不变的相关性。对于预测问题,这就意味这需要找到一种数据表达,使得在该数据表达之上的最佳分类器在不同的环境中都相同。可按如下定义方式 :

定义 3 考虑一种数据表达 Φ:XH\Phi: \mathcal{X} \rightarrow \mathcal{H},如果有一个分类函数 w:HYw: \mathcal{H} \rightarrow \mathcal{Y} 适用于所有环境,则可导出的跨环境 E\mathcal{E} 的不变预测器 wΦw \circ \Phi,也即对于任意的 eEe \in \mathcal{E},都有 wargminwˉ:HYRe(wˉΦ)w \in \arg \min _{\bar{w}: \mathcal{H} \rightarrow \mathcal{Y}} R^{e}(\bar{w} \circ \Phi)

  为什么上述定义等价于与目标变量的相关性稳定的学习特征?对于损失函数如均方误差和交叉熵,最优的分类器可以写为条件期望。一种数据表达 Φ\Phi 可以产生的跨环境不变预测当且仅当对于 Φ(Xe)\Phi(X^e) 的所有焦点 h 处,对于任意的 e,eEe,e'\in\mathcal{E},都有 E[YeΦ(Xe)=h]=E[YeΦ(Xe)=h]\mathbb{E}\left[Y^{e} \mid \Phi\left(X^{e}\right)=h\right]=\mathbb{E}\left[Y^{e^{\prime}} \mid \Phi\left(X^{e^{\prime}}\right)=h\right]

  我们认为不变性的概念与科学中常用的归纳法是相抵触的。实际上,一些科学发现都可以追溯到发现一些不同的但潜在的相关现象,一旦用正确的变量描述,它们似乎遵循相同精确的物理定律。严格遵守这些规则表明它们在更广泛的条件下仍有效,如果牛顿的苹果和星球遵循相同方程,那么引力就是一件事。

  为了从经验数据中发现这些不变性,我们引入了 IRM 方法,不仅具有好的预测结果,还是跨环境 Etr\mathcal{E}_{\mathrm{tr}} 的不变预测器。从数学上,可转为为如下优化问题 (IRM) :

minΦ:XHw:HYeEtrRe(wΦ) subject to wargminwˉ:HYRe(wˉΦ), for all eEtr\begin{array}{ll}\min _{\Phi: \mathcal{X} \rightarrow \mathcal{H} \atop w: \mathcal{H} \rightarrow \mathcal{Y}} & \sum_{e \in \mathcal{E}_{\mathrm{tr}}} R^{e}(w \circ \Phi) \\\text { subject to } & w \in \underset{\bar{w}: \mathcal{H} \rightarrow \mathcal{Y}}{\arg \min } R^{e}(\bar{w} \circ \Phi), \text { for all } e \in \mathcal{E}_{\mathrm{tr}}\end{array}

  这是一个有挑战性的两级优化问题,我们将其转化为另一个版本 (IRMv1) :

minΦ:XYeEtrRe(Φ)+λww=1.0Re(wΦ)2\min _{\Phi: \mathcal{X} \rightarrow \mathcal{Y}} \sum_{e \in \mathcal{E}_{\mathrm{tr}}} R^{e}(\Phi)+\lambda \cdot\left\|\nabla_{w \mid w=1.0} R^{e}(w \cdot \Phi)\right\|^{2}

  其中 Φ\Phi 是整个不变预测器,w=1.0w = 1.0 是一个标量和一个固定的虚拟分类器,梯度形式惩罚是用来衡量每个环境 e 中虚拟分类器的最优性,λ[0,)\lambda\in [0,\infty)是预测能力 (ERM) 和预测 1Φ(x)1\cdot \Phi(x) 不变性的平衡调节参数。

3.1 从 (IRM) 到 (IRMv1)

3.1.1 将约束作为惩罚项

  我们将 (IRM) 中的硬性约束转化为如下的惩罚性损失 :

LIRM(Φ,w)=eEtrRe(wΦ)+λD(w,Φ,e)(1)L_{\mathrm{IRM}}(\Phi, w)=\sum_{e \in \mathcal{E}_{\mathrm{tr}}} R^{e}(w \circ \Phi)+\lambda \cdot \mathbb{D}(w, \Phi, e)\tag{1}

  其中函数 D(w,Φ,e)\mathbb{D}(w, \Phi, e) 表示了 ww 使得 Re(wΦ)R^{e}(w \circ \Phi) 达到最小化的程度,λ\lambda性的超参数。在实际应用中,我们希望 D(w,Φ,e)\mathbb{D}(w, \Phi, e) 关于 Φ\Phiww 是可微的。

3.1.2 对于线性分类器 ww 选择合适的惩罚项 D\mathbb{D}

  下面我们考虑 ww 为线性分类器这一特殊情况。当给定数据表达 Φ\Phi,我们可以由wΦeargminwˉRe(wˉΦ)w_{\Phi}^{e} \in \arg \min _{\bar{w}} R^{e}(\bar{w} \circ \Phi) 写出 :

minwΦeΦ(Xe)Ye2wΦe=EXe[Φ(Xe)Φ(Xe)]1EXe,Ye[Φ(Xe)Ye](2)\min\lVert w_{\Phi}^{e}\circ\Phi(X^e)-Y^e\rVert^2\\ w_{\Phi}^{e}=\mathbb{E}_{X^{e}}\left[\Phi\left(X^{e}\right) \Phi\left(X^{e}\right)^{\top}\right]^{-1} \mathbb{E}_{X^{e}, Y^{e}}\left[\Phi\left(X^{e}\right) Y^{e}\right]\tag{2}

图 1.在我们的示例 1 中,不同的不变性度量导致不同的优化场景。测量最优分类器 Ddist\mathbb{D}_{dist} 之间距离的朴素方法会导致不连续的惩罚 (实心蓝色未正则化,橙色虚线正则化)。相比之下,惩罚 Dlin\mathbb{D}_{lin} 没有表现出这些问题。

  且我们希望这两个线性分类器的差异越小越好,即 Ddist (w,Φ,e)=wwΦe2\mathbb{D}_{\text {dist }}(w, \Phi, e)=\left\|w-w_{\Phi}^{e}\right\|^{2}。我们将该方法用到3.1中的实例中,令 Φ(x)=xDiag([1,c])\Phi (x)=x \cdot Diag([1,c])w=[1,0]w=[1,0],则 c 控制了这个数据表达多大程度上依赖 X2X_2。我们做出不变性损失随 c 的变化图见图 1,发现 Ddist \mathbb{D}_{\text {dist }}c=0c=0 处是不连续的,而当 c 趋于 0 而不等于 0 时,利用最小二乘法计算 wΦew_{\Phi}^{e} 的第二个量将趋于无穷,因此出现了图 1 中蓝线的情况。图 1 中黄线表明在最小二乘中添加强的正则化不能解决这一问题。

Ddist (w,Φ,e)=wwΦe2=wEXe[Φ(Xe)Φ(Xe)]1EXe,Ye[Φ(Xe)Ye]2(3)\mathbb{D}_{\text {dist }}(w, \Phi, e)=\left\|w-w_{\Phi}^{e}\right\|^{2}=\\ \lVert w-\mathbb{E}_{X^{e}}\left[\Phi\left(X^{e}\right) \Phi\left(X^{e}\right)^{\top}\right]^{-1} \mathbb{E}_{X^{e}, Y^{e}}\left[\Phi\left(X^{e}\right) Y^{e}\right]\rVert^2\tag{3}

  为了解决这些问题,我们将最小二乘求 wΦew_{\Phi}^{e} 中的矩阵求逆去除,并按如下方式计算不变性损失 :

Dlin=(EXe[Φ(Xe)Φ(Xe)])2Ddist (w,Φ,e)=EXe[Φ(Xe)Φ(Xe)]wEXe,Ye[Φ(Xe)Ye]2(4)\mathbb{D}_{\operatorname{lin}}=\left(\mathbb{E}_{X^{e}}\left[\Phi\left(X^{e}\right) \Phi\left(X^{e}\right)^{\top}\right]\right)^2\mathbb{D}_{\text {dist }}(w, \Phi, e) \\ =\left\|\mathbb{E}_{X^{e}}\left[\Phi\left(X^{e}\right) \Phi\left(X^{e}\right)^{\top}\right] w-\mathbb{E}_{X^{e}, Y^{e}}\left[\Phi\left(X^{e}\right) Y^{e}\right]\right\|^{2}\tag{4}

  按照这种方式,得到图 1 绿线所示的情况。可见 Dlin\mathbb{D}_{\operatorname{lin}} 是平滑的 (它是 Φ\Phiw w 的多项式函数)。并且,当且仅当 wΦeargminwˉRe(wˉΦ)w_{\Phi}^{e} \in \arg \min _{\bar{w}} R^{e}(\bar{w} \circ \Phi) 时,Dlin(w,Φ,e)=0\mathbb{D}_{\operatorname{lin}}(w, \Phi, e)=0

3.1.3 固定线性分类器 ww

  即使在使用 Dlin\mathbb{D}_{\operatorname{lin}} 最小化 (Φ,w)(\Phi,w) 时,我们也会遇到一个问题。当考虑一对 (γΦ,1γw)(\gamma\Phi,\frac{1}{\gamma}w) 时,通过让 γ\gamma 趋于零,可以让 Dlin\mathbb{D}_{\operatorname{lin}} 趋于零而不影响 ERM 项。这个问题的出现是因为 (1) 严重过度参数化。我们通过 Dlin\mathbb{D}_{\operatorname{lin}} 最小化选择出的 (Φ,w)(\Phi , w) 是不唯一的,实际上对于可逆映射 Ψ\Psi,我们可以重写不变预测器为 :

wΦ=(wΨ1)w~(ΨΦ)Φ~w \circ \Phi=\underbrace{\left(w \circ \Psi^{-1}\right)}_{\tilde{w}} \circ \underbrace{(\Psi \circ \Phi)}_{\tilde{\Phi}}

  这意味着我们可以任意选择非零 w~\tilde{w} 作为不变预测器。因此,我们可以将搜索限制在给定 w~\tilde{w} 的所有环境最优分类的数据表达上。即 :

LIRM,w=wˉ(Φ)=eEtrRe(w~Φ)+λDlin(w~,Φ,e)(5)L_{\mathrm{IRM}, w=\bar{w}}(\Phi)=\sum_{e \in \mathcal{E}_{\mathrm{tr}}} R^{e}(\tilde{w} \circ \Phi)+\lambda \cdot \mathbb{D}_{\mathrm{lin}}(\tilde{w}, \Phi, e)\tag{5}

  当 λ\lambda \to \infty 时,对于线性 w~\tilde{ w},上式的解 (Φλ,w~)\left(\Phi_{\lambda}^{*}, \tilde{w}\right) 将趋于 (IRM) 的解 (Φ,w~)\left(\Phi^{*}, \tilde{w}\right)

3.1.4 固定分类器 w~\tilde{ w} 也可满足监视不变性

  前文我们提出 w~=(1,0,,0)\tilde{w}=(1,0, \ldots, 0) 是一个有效的分类器选择,这种情况下只有一部分的数据起作用。我们通过给出线性不变预测器的完整特征来说明这个悖论。下面的理论中的矩阵 ΦRp×d\Phi \in \mathbb{R}^{p \times d},为数据特征函数,向量 wRpw \in \mathbb{R}^{p} 为最优分类器,v=Φwv=\Phi^{\top} w 为预测向量 wΦw \circ \Phi

定理 4 对于所有 eEe \in \mathcal{E},令 Re:RdRR^{e}: \mathbb{R}^{d} \rightarrow \mathbb{R} 为凸可微损失函数。一个向量 vRdv \in \mathbb{R}^{d} 可以写为 v=Φwv=\Phi^{\top} w,其中 ww 对于所有环境 e,使得 Re(wΦ)R^{e}(w \circ \Phi) 同时达到最小,当且仅当对于所有环境 e,vRe(v)=0v^{\top} \nabla R^{e}(v)=0。所以,任何线性不变预测器可以被分解为不同秩的线性表达。具体地说,我们可以将搜索限制在 ΦR1×d\Phi\in\mathbb{R}^{1\times d} 的矩阵上,并使 w~R1\tilde{w}\in\mathbb{R}^1 的为固定标量 1.0。则可以将 (5) 转化为 :

LIRM,w=1.0(Φ)=eEtrRe(Φ)+λDlin(1.0,Φ,e)(6)L_{\mathrm{IRM}, w=1.0}\left(\Phi^{\top}\right)=\sum_{e \in \mathcal{E}_{\mathrm{tr}}} R^{e}\left(\Phi^{\top}\right)+\lambda \cdot \mathbb{D}_{\operatorname{lin}}\left(1.0, \Phi^{\top}, e\right)\tag{6}

  后文将证明,不管我们是否限制 IRM 搜索秩为 1 的 Φ\Phi ^\top,这种形式的分解将会引入高秩的数据表达矩阵,且是分布外泛化的关键。

3.1.5 推广到一般损失和多元输出

  章节 3.1.4通过加入不变性损失和均方误差得到最终的 IRMv1 模型,可以写出一般的风险方程 D(1.0,Φ,e)=ww=1.0Re(wΦ)2\mathbb{D}(1.0, \Phi, e)=\left\|\nabla_{w \mid w=1.0} R^{e}(w \cdot \Phi)\right\|^{2},其中 Φ\Phi 是一种可能的非线性数据表达。这种表达在任何损失下都最优匹配于常值分类器 w=1.0w= 1.0。如果 Φ\Phi 返回的目标空间 Y\mathcal{Y} 具有多个输出,我们将它们全部乘以标量分类器 w=1.0w = 1.0

  补充推导

Re(wΦ)=12(Φ(X)wY)(Φ(X)wY)ww=1.0Re(wΦ)=Φ(X)(Φ(X)wY)=Φ(X)Φ(X)wΦ(X)YD(1.0,Φ,e)=ww=1.0Re(wΦ)2=Φ(X)Φ(X)wΦ(X)Y2R^{e}(w \cdot \Phi)=\frac{1}{2}(\Phi(X)^{\top}w-Y)^{\top}(\Phi(X)^{\top}w-Y)\\ \nabla_{w|w=1.0} R^{e}(w \cdot \Phi)=\Phi(X)\cdot(\Phi(X)^{\top}w-Y)=\Phi(X)\Phi(X)^{\top}w-\Phi(X)Y\\ \mathbb{D}(1.0, \Phi, e)=\left\|\nabla_{w \mid w=1.0} R^{e}(w \cdot \Phi)\right\|^{2}=\lVert\Phi(X)\Phi(X)^{\top}w-\Phi(X)Y\rVert^2

图 2.不变线性预测量 v=ΦTwv=\Phi^Tw 的解与表示正交条件 vRe(v)=0v^{\top}\nabla R^e(v)=0 的椭球的交点重合。

3.2 执行细节

  当使用小批量梯度下降估计目标 (IRMv1) 时,可以得到平方估计范数的无偏估计 :

k=1b[ww=1.0(wΦ(Xke,i),Yke,i)ww=1.0(wΦ(Xke,j),Yke,j)]\sum_{k=1}^{b}\left[\nabla_{w \mid w=1.0} \ell\left(w \cdot \Phi\left(X_{k}^{e, i}\right), Y_{k}^{e, i}\right) \cdot \nabla_{w \mid w=1.0} \ell\left(w \cdot \Phi\left(X_{k}^{e, j}\right), Y_{k}^{e, j}\right)\right]

  其中 (Xe,i,Ye,i)(X^{e,i},Y^{e,i})(Xe,j,Ye,j)(X^{e,j},Y^{e,j}) 是环境 e 中的两个大小为 b 的随机小批量样本,\ell 为损失函数,PyTorch 例子见附件 D。

3.3 关于非线性不变 ww

  假设不变最优分类器 w 是线性的有多严格?一种说法是只要给予足够灵活的数据表达 Φ\Phi,就可以将不变预测器写为 1.0Φ1.0 \cdot \Phi。然而,强制执行线性不变性可能使得非不变预测惩罚 Dlin\mathbb{D}_{\mathrm{lin}} 等于 0。例如,空数据表达 Φ0(Xe)=0\Phi_0(X^e)=0 允许任何 w 为最优值。但是,当 E[Ye]0\mathbb{E}[Y^e]\ne 0 时,这样产生的预测器 wΦw \circ \Phi 不是不变的。ERM 项会丢弃这种无效的预测器。通常,最小化 ERM 项 Re(w~Φ)R^e(\tilde{w}\circ\Phi) 将驱动 Φ\Phi 以至于将 w~\tilde{w} 在所有预测器中达到最优,尽管 w~\tilde{w}是线性的。

  针对这个研究,我们也为未来的的研究提出了几个问题。是否存在不会被 ERM 和 IRM 丢弃的非不变预测器?如果将 w 放宽到可从非线性中选取将有什么好处?我们如何构造非线性不变量不变性的惩罚函数 D\mathbb{D}

4 不变性,因果性和泛化

  新提出的 IRM 方法使得在训练环境 Etr\mathcal{E}_{tr} 中具有更低的误差和不变特性。什么时候这些条件可以将不变性推广到所有环境中呢?更重要的时,什么时候这些条件可以使得在全部环境 Eall\mathcal{E}_{all} 中具有更低的误差,并导致分布外的泛化呢?并且在一个更基础的水平,统计不变性和分布外的泛化如何与因果理论中的概念相关?

  到目前为止,我们已经忽略了不同的环境应该如何关联以启用分布外泛化。这个问题的答案源于因果关系理论。我们首先假设来自所有环境的数据共享相同的底层结构方程模型 (SEM) :

定义 5 控制生成向量 X=(X1,,Xd)X=\left(X_{1}, \ldots, X_{d}\right) 的结构方程模型 C:=(S,N)\mathcal{C}:=(\mathcal{S}, N) 是一组结构方程 :

Si:Xifi(Pa(Xi),Ni)\mathcal{S}_{i}: X_{i} \leftarrow f_{i}\left(\operatorname{Pa}\left(X_{i}\right), N_{i}\right)

  其中 Pa(Xi){X1,,Xd}\{Xi}\mathrm{Pa}\left(X_{i}\right) \subseteq\left\{X_{1}, \ldots, X_{d}\right\} \backslash\left\{X_{i}\right\} 被称为 XiX_i 的双亲, NiN_i 是独立于噪声的随机变量。如果 XiPa(Xj)X_i\in Pa(X_j),可记为“ XiX_i 导致 XjX_j ”。我们可以据此来绘制因果图,每个 XiX_i 看作节点,如果 XiPa(Xj)X_i\in Pa(X_j) ,则就有从 XiX_iXjX_j 的一条边。我们假设该图是无环的。

  根据因果图的拓扑顺序,运行结构方程 (SEM)(SEM) C\mathcal{C},我们可以从观测分布 P(X)P(X) 的得到一些样本。同样,我们还可以以不同的方式操纵 (干预) 一个唯一的SEM,以 e 为指标,来得到不同但相关的 SEMsSEMs Ce\mathcal{C}^e

定义 6 考虑一个 SEMSEM C=(S,N)\mathcal{C}=(S,N)。用干预 e 作用到 C\mathcal{C}上 (包括替换一个或几个方程) 以得到干预 SEMSEM Ce=(Se,Ne)\mathcal{C}^e=(S^e,N^e),结构方程为 :

Sie:Xiefie(Pae(Xie),Nie)S_i^e:X_i^e\gets f_i^e(P_a^e(X_i^e),N_i^e)

  若 SiSieS_i\ne S_i^e 或者 NiNieN_i \ne N_i^e,则变量 XeX^e 是一种干预。

  类似的,通过运行干预 SEMSEM Ce\mathcal{C}^e 的结构方程,我们可以从干预分布 P(Xe)P(X^e) 中得到一些样本。例如我们可以考虑在例 1中干预 X2X_2 ,控制它为趋于 0 的常数,因此将 X2X_2 的结构方程替换为 X2e0X_2^e\gets 0。每个干预 e 都产生了一个干预分布为 P(Xe,Ye)P(X^e,Y^e) 的新环境 e。有效的干预 e 不会损坏太多的目标变量 Y 的信息,从而形成了环境集 Eall\mathcal{E}_{all}

  先前的工作考虑的是有效的干预不会改变Y的结构方程,因为对方程的任意干预都不可能预测。在这个工作中,我们也允许改变Y的噪声,因为在真实问题中会出现变化的噪声水平,这些并不会影响最优的预测规则。我们将其形式化如下 :

定义 7 考虑一个 SEMSEM C\mathcal{C} 控制随机向量 (X1,...,Xd,Y)(X_1,...,X_d,Y),以及基于 X 预测 Y 的学习目标。那么,所有的环境集合 Eall(C)\mathcal{E}_{all}(\mathcal{C}) 由干预产生的所有干预分布 P(Xe,Ye)P(X^e,Y^e) 得到。只要 (i)因果图是无环的,(ii) E[YePa(Y)]=E[YPa(Y)]\mathbb{E}\left[Y^{e} \mid \mathrm{Pa}(Y)\right]=\mathbb{E}[Y \mid \mathrm{Pa}(Y)],(iii) V[YePa(Y)]\mathbb{V}\left[Y^{e} \mid \operatorname{Pa}(Y)\right] 保持有限方差,则该干预 eEall(C)e\in \mathcal{E}_{all}(\mathcal{C}) 是有效的。

  如果在定义 ROODR^{OOD} 中考虑环境特定的基线,条件(iii)可以去除,与哪些出现在鲁棒性学习目标 RrobR^{rob} 相似。我们留下一些分布外泛化的其它量化作为以后的工作。

  先前定义了因果性和不变性之间建立的基础联系。另外,可以证明一个预测 v:XYv : \mathcal{X}\to \mathcal{Y} 是跨环境 Eall(C)\mathcal{E}_{all}(\mathcal{C}) 的不变预测,当且仅当它能达到最佳的 ROODR^{OOD} ,当且仅当它只使用Y的直接因果双亲来预测,也即, v(x)=ENY[fY(Pa(Y),NY)]v(x)=\mathbb{E}_{N_{Y}}\left[f_{Y}\left(\mathrm{Pa}(Y), N_{Y}\right)\right]。本节的其它部分将根据这些思想去展示如何利用跨环境的不变性实现所有环境中的分布外的泛化。

4.1 IRM 的泛化理论

  IRM 的目的就是建立一种可以产生域泛化的预测,也即,实现在整个环境 Eall\mathcal{E} _{all} 中具有更低的误差。为此,IRM 致力于在环境 Etr\mathcal{E}_{tr} 中同时减少误差以及保证不变性。这两者之间的桥梁由如下两步实现 : 第一步,可以证明 Etr\mathcal{E}_{tr} 环境中更低的误差和不变性将导致 Eall\mathcal{E} _{all} 中更低的误差。这是因为,一旦估算出在环境 Eall\mathcal{E}_{all} 中数据表达 Φ\Phi 产生的不变预测 wΦw \circ \PhiwΦw \circ \Phi 的误差将控制在标准误差界中。第二步,我们测试其余条件使得在环境 Eall\mathcal{E}_{all} 中具有更低的误差,即在什么条件下,训练环境 Etr\mathcal{E}_{tr} 中的不变性意味着所有环境 Eall\mathcal{E}_{all} 中的不变性?

  对于线性 IRM,我们回答这个问题的起点是不变因果预测理论 (ICP) 。这里,作者证明了 ICP 恢复目标不变性只要数据 (i) 是高斯分布的,(ii) 满足线性的 SEM,(iii) 从特定类型的干预中得到。定理 9 表明即使上述三个假设都不成立,IRM 也能学到这种不变性。特别是我们允许非高斯数据,处理作为具有稳定和虚假相关性的变量的线性转换产生的观察结果,不需要特定类型的干预或因果图的存在。

  定理的设定如下。 YeY^e 有一个不变相关性变量 Z1eZ_1^e ,它是一个未观察的潜在变量,具有线性关系为 Ye=Z1eγ+ϵeY^e=Z_1^e\cdot \gamma+\epsilon^eϵe\epsilon^e 独立于 Z1eZ_1^e。我们能观测到的是 XeX^e,它是 Z1eZ^e_1 和另一个与 Z1eZ^e_1ϵe\epsilon^e 任意相关的变量 Z2eZ^e_2 的干扰组合。简单的使用 XeX^e 回归将不计后果的利用了 Z2eZ_2^e (因为它给出了关于 ϵe\epsilon^eYeY^e 额外的虚假的信息)。为了实现分布外的泛化,数据表达必须丢弃 Z2eZ_2^e 且保留 Z1eZ_1^e

  在展示定理 9 之前,我们需要先做一些假设。为了学习有用的不变性,必须要求训练环境具有一定程度的多样性。一方面,从大数据集中随机抽取两个子集样本并不会导致环境的多样性,因为这两个子集服从相同的分布。另一方面,以任意变量为条件将大数据集分割可以产生多样性的环境,但是可能会引入虚假相关性且破坏我们需要的不变性。因此,我们需要包含足够多样性且满足基本不变性的训练环境。我们将这种多样性需求形式化为需要环境处于 linear general position。

假设 8 训练环境 Etr\mathcal{E}_{tr} 在 linear general position 的程度为r, Etr>dr+dr|\mathcal{E}_{tr}|>d-r+\frac{d}{r}rNr\in \mathbb{N},且对于所有的非零 xRdx\in \mathbb{R}^d :

dim(span({EXe[XeXe]xEXe,ϵe[Xeϵe]}eEtr))>dr\operatorname{dim}\left(\operatorname{span}\left(\left\{\mathbb{E}_{X^{e}}\left[X^{e} X^{e \top}\right] x-\mathbb{E}_{X^{e}, \epsilon^{e}}\left[X^{e} \epsilon^{e}\right]\right\}_{e \in \mathcal{E}_{t r}}\right)\right)>d-r

  直观上,这种 linear general position 的假设限制了训练环境共线性的程度。每个处在 linear general position 的新环境都将其不变解空间减少一个自由度。幸运的是,理论 10 表明不满足一个 linear general position 的叉积 EXe[XeXe]E_{X^e}[X^e {X^e}^\top] 集合为0。使用这种 linear general position 的假设,我们通过 IRM 学习的不变性可以从训练环境转化到全部环境。

  下面这个定理表明,如果在 Etr\mathcal{E}_{tr} 中找到一个秩为 r 的数据表达 Φ\Phi 导出的不变预测 wΦw \circ \Phi,且 Etr\mathcal{E}_{tr} 在 linear general position 的程度为r,那么 wΦw \circ \Phi 将是整个环境 Eall\mathcal{E}_{all} 中的不变预测。

定理 9 假设

Ye=Z1eγ+ϵe,Z1eϵe,E[ϵe]=0Xe=S(Z1e,Z2e)Y^{e}=Z_{1}^{e} \cdot \gamma+\epsilon^{e}, \quad Z_{1}^{e} \epsilon^{e}, \quad \mathbb{E}\left[\epsilon^{e}\right]=0\\ X^{e}=S\left(Z_{1}^{e}, Z_{2}^{e}\right)

  其中 γRc\gamma \in \mathbb{R}^cZ1eZ_1^eRc\mathbb{R}^c 中取值, Z2eZ_2^eRq\mathbb{R}^q 中取值,且 SRd×(c+q)S\in \mathbb{R}^{d\times (c+q)}。假设 SSZ1Z_1 分量是可逆的 : 那么存在 S~Rc×d\tilde{S}\in\mathbb{R}^{c\times d} 使得 S~(S(z1,z2))=z1\tilde{S}(S(z_1,z_2))=z_1。令 ΦRd×d\Phi\in \mathbb{R}^{d\times d} 的秩 r>0r>0。那么,至少 dr+drd-r+\frac{d}{r} 训练环境在 linear general position 中的程度为 r,我们有

wΦ=Φ(X)wΦEXe[XeXe]Φw=ΦEXe,Ye[XeYe]w\circ \Phi=\Phi(X)^{\top}w \\ \Phi\mathbb{E}_{X^{e}}\left[X^{e} X^{e^{\top}}\right] \Phi^{\top} w=\Phi \mathbb{E}_{X^{e}, Y^{e}}\left[X^{e} Y^{e}\right]

  对所有的 eEtre\in \mathcal{E}_{tr} 成立,当且仅当 Φ\Phi 导出的 Φw\Phi^{\top} w 是所有环境中的不变量。

  这个假设是线性的,中心误差,且噪声 ϵe\epsilon^e 与因果变量 Z1Z_1 是独立的,意味着不变性 E[YeZ1e=z1]=z1γ\mathbb{E}[Y^e|Z_1^e=z_1]=z_1\cdot \gamma 。在ICP中,我们允许在 ϵe\epsilon ^e 和非因果变量 Z2eZ_2^e 间的相关性,这导致 ERM 吸收了虚假相关性 (在例 1中, S=IS=IZ2e=X2eZ_2^e=X_2^e )。

  另外,我们的结果包含一些新颖之处。第一,我们并不假设数据是高斯分布的,这个存在的因果图或训练环境是由特定的干扰类型引发的。第二,结果可以扩展到 “加扰设置”,即 SIS\ne I 。这些情况中的因果关系没有定义观测特征 XX,但是在IRM中需要对潜在变量 (Z1,Z2)(Z_1,Z_2) 进行恢复和过滤。第三,我们表明表达 Φ\Phi 具有更高的秩就需要生成更少的训练环境。这是很好的,因为更高秩的表达将破坏更少的学习问题的信息。

  我们以两个重要观测来结束本小节。第一,鲁棒性学习会在训练环境内得到概括,而IRM的不变性学习将获得向外推断能力。我们可以从例1观察到,使用两个训练环境,鲁棒性的学习在 σ[10,20]\sigma\in [10,20] 时表现很好,而 IRM 的学习对于所有 σ\sigma 都表现很好。最后,对于训练环境的协方差IRM是一个微分函数。因此,当数据近似服从一个不变模型,IRM 应返回一个近似不变的解,对于轻度模型的错误具有鲁棒性。这与基于阈值系统的常见因果发现方法相反。

4.2 非线性情况中环境的数量

  与线性情况相同,我们可以为 IRM 提供非线性机制的保障。也即,我们可以假设每个约束 ww=1.0Re(wΦ)=0\left\|\nabla_{w \mid w=1.0} R^{e}(w \cdot \Phi)\right\|=0 都会从可能的结果 Φ\Phi 中移除一个自由度。那么,对于一个充分多的各种训练环境,我们可以得到不变预测器。不幸的是,我们不能说明这种 “nonlinear general positon” 假设,也不能证明它能在所有环境中都可用,因为定理 10 只是针对的线性情况。我们将其作为未来的工作。

  一般而言,定理9是消极的,因为它要求训练环境的数量与表达矩阵 Φ\Phi 中参数的数量成线性比例关系。幸运的是,我们在实验中观察到,通常两个环境就可以充分恢复不变性了。我们相信这些问题中 E[YeΦ(Xe)]\mathbb{E}[Y^e|\Phi(X^e)] 不能从两个不同的环境 eee\ne e’ 中完全匹配,除非 Φ\Phi 提取的是因果不变性。在大的 ww 族中找 ww 不变性应该允许丢弃一些很少训练环境下的更多的不变性。总之,从很少的环境中学习到不变性,是朝着不变性理论迈进的很有前途的工作。

4.3 因果性 vs 不变性

  我们促进不变性作为因果性的主要特征。当然,我们不是这样做的先驱。为了预测一个干预的结果,我们依赖 (i) 我们干预的性质,(ii) 在干预后这些性质假定不变。Pearl's 在因果图上的 do-运算是一个框架,能告诉我们什么条件在干预后保持不变。Rubin's 可执行扮演同样的角色。它通常被描述为一种因果机制的自治,是一种特殊的干预后的不变性。大量的哲学著作研究了不变性和因果关系的联系,一些机器学习的作品中也提到了类似的工作。

  因果关系的不变性观点超越了一些因果图处理的某些难题。例如,理想气体方程 PV=nRTPV=nRT 或牛顿万有引力方程很难使用结构方程模型描述,但在实验条件下是不变性的杰出例子。当收集气体或天体数据时,这些定理的普遍性将表现为不变的相关性,这将得到一种跨环境的有效预测,以及科学理论的概念。

  另一个支持因果关系的不变性观点的动机是研究机器学习问题。例如,考虑图像分类任务。这里,观察到的变量是成百上千的像素。控制它们的因果图是什么?一个合理的假设是因果关系并不会发生在像素之间,而是发生在相机捕获的真实概念之间。在这些情况下,图片中的不变相关是真实世界中的因果关系的代理。为了发现这些不变相关,我们需要一些方法能够将观察到的像素分解为更接近因果机制的潜在变量,例如 IRM。在少数情况下,我们对控制所有变量的完整因果图感兴趣。而是,我们关注的通常是因果不变性能够提升在新的分布样本中的泛化性。

5 实验

6 个人感想

  感觉数学处理方面并没有完全看明白,希望后续有空在阅读后面论文的闲暇之时有空翻过来重新理解重新感悟。这篇文章中的 IRM 都还主要应用在线性情况下,非线性情况并无过多赘述。

7 代码与补充说明

  这是论文原文中给出的示意代码。

import torch 
from torch.autograd import grad 

def compute_penalty (losses , dummy_w ): 
	g1 = grad(losses[0::2].mean (), dummy_w , create_graph=True)[0] 
	g2 = grad(losses[1::2].mean (), dummy_w , create_graph=True)[0] 
	return (g1 * g2).sum () 

def example_1 (n =10000,d =2,env =1): 
	x = torch.randn(n, d) * env 
	y = x + torch.randn(n, d) * env 
	z = y + torch.randn(n, d)
	return torch.cat((x, z), 1), y.sum(1, keepdim = True) 

phi = torch.nn.Parameter(torch.ones(4, 1))
dummy_w = torch.nn.Parameter(torch.Tensor ([1.0])) 
opt = torch.optim.SGD([phi], lr = 1e-3) 
mse = torch.nn.MSELoss(reduction = " none")
environments = [example_1(env = 0.1), example_1(env = 1.0)]

for iteration in range (50000):
	error = 0
	penalty = 0
	for x_e , y_e in environments: 
		p = torch.randperm(len(x_e))
		error_e = mse(x_e[p]@ phi * dummy_w, y_e[p])
		penalty += compute _penalty(error_e, dummy_w) 
		error += error_e.mean ()

	opt.zero_grad ()
	(1e-5 * error + penalty).backward () 
	opt.step ()
	
	if iteration % 1000 == 0:
		print ( phi )

  这是网上的一种主流的 IRMv1 代码实现

class IRM(ERM):
    """Invariant Risk Minimization"""

    @staticmethod
    def _irm_penalty(logits, y):
        device = "cuda" if logits[0][0].is_cuda else "cpu"
        scale = torch.tensor(1.).to(device).requires_grad_()
        loss_1 = F.cross_entropy(logits[::2] * scale, y[::2])
        loss_2 = F.cross_entropy(logits[1::2] * scale, y[1::2])
        grad_1 = autograd.grad(loss_1, [scale], create_graph=True)[0]
        grad_2 = autograd.grad(loss_2, [scale], create_graph=True)[0]
        result = torch.sum(grad_1 * grad_2)
        return result

  可以看到,两段代码中都涉及到一个所谓的 “虚假参数” ww (dummy_w/scale),在此我们重新回顾并且重写原文中的推导过程。

D(w,Φ,e)=EXe[Φ(Xe)Φ(Xe)]wEXe,Ye[Φ(Xe)Ye]2=ww=1.0Re(wΦ)2\mathbb{D}(w, \Phi, e)=\left\|\mathbb{E}_{X^{e}}\left[\Phi\left(X^{e}\right) \Phi\left(X^{e}\right)^{\top}\right] w-\mathbb{E}_{X^{e}, Y^{e}}\left[\Phi\left(X^{e}\right) Y^{e}\right]\right\|^{2}\\=\left\|\nabla_{w \mid w=1.0} R^{e}(w \cdot \Phi)\right\|^{2}

  我们设虚假参数 II11,则可进行转化。

Re(w,Φ)=Re(w,Φ,I)=12(Φ(X)wIY)(Φ(X)wIY)Re(w,Φ,I)I=(Φ(X)w)(Φ(X)wY)=w(Φ(X)Φ(X)wΦ(X)Y)1wIRe(w,Φ,I)=Φ(X)Φ(X)wΦ(X)Y\begin{aligned} R^e(w,\Phi)&=R^e(w,\Phi,I)=\frac{1}{2}(\Phi(X)^{\top}w\cdot I-Y)^{\top}(\Phi(X)^{\top}w\cdot I-Y)\\ \frac{\partial R^e(w,\Phi,I)}{\partial I}&=(\Phi(X)^{\top}w)^{\top}\cdot(\Phi(X)^{\top}w-Y)\\ &=w^{\top}\left(\Phi(X)\Phi(X)^{\top}w-\Phi(X)Y\right)\\ \frac{1}{|w|}\left\lVert\nabla_I R^e(w,\Phi,I)\right\rVert&=\left\lVert\Phi(X)\Phi(X)^{\top}w-\Phi(X)Y\right\rVert \end{aligned}

  即我们在训练的过程中,如果限制住 w|w| 的大小 (限制在一定范围内),可以认为将 λ\lambda 变为 wλ|w|\lambda,这样在易得性和数学性之间取得平衡。