Causual Inference in DNN(1)Learning Representations for Counterfactual Inference

108 阅读3分钟

Highlight

这个论文集合用来整理因果推断在深度学习中的使用。

第一篇论文叫 Learning Representations for Counterfactual Inference
发表于:ICML 2016
下载链接:proceedings.mlr.press/v48/johanss…

简述

随着智能领域的数据积累,观测性研究变得越来越重要。 作者提出了一种新的反事实推理算法框架,该框架带来了领域适应和表示学习的想法。该论文主要关注观测性研究中的反事实问题,例如,对于一个确定的糖尿病病人和两种给定的抗糖尿病药物A和B,哪一种药物治疗效果更好。该论文采用的方法是在反事实推理和领域适应之间建立联系, 在为具有不同干预的人群学习的表示分布之间强制执行相似性来引入一种正则化形式。

这篇文章提出了一种利用领域适应和深度神经网络表示学习的框架方法来进行反事实结果推理。主要的贡献在以下三方面:

  1. 公式化反事实推理问题为领域适应问题,更具体一点,转化为协变量转变问题。
  2. 利用深度神经网络表示,线性模型和变量选择来进行反事实推理。
  3. 利用reweighting samples的方法使treatment和control groups distribution balanced

这篇文章提出了反事实推理中经常出现的问题:

  1. 只能观测到事实结果(factual outcome),无法直接观测反事实。
  2. 反事实分布可能与事实分布不同,无法直接利用网络进行训练。

方案

对此,文章提出了以下解决方案:

  1. 使用训练集上的误差最小化和正则化来获得对事实表示的低误差预测。
  2. 通过计算与各自实验和对照组中最接近的观察结果接近的结果,获得反事实的低误差预测。
  3. 通过最小化差异距离来实现分布平衡,差异距离是领域适应的距离度量。

image.png 损失函数表示为: image.png 其中第一和第二项分别表示factual error和counterfactual error。第三项是对两种分布discrepancy distance的量化。另外,如果表征网络是神经网络,那么模型就叫做balancing neural network(BNN);否则直接用balancing reweight的方法就叫做balancing linear regression(BLR)。他们都是用的同一个框架balancing counterfactual regression。因为我们重点关注的是神经网络,因此下面重点介绍BNN的构造:

image.png

BNN网络为标准前向全连接层神经网络。上图中drd_r代表学习输入表征的隐藏层, d0d_0结合drd_r的输出和treatment t生成预测 h([ϕ(x),t])h([\phi(x),t]),然后和不同组中距离最小的couterfactual ground truth进行比较计算出loss。另一条线计算的是discrepancy disc(PϕF,PϕCF)disc(P_{\phi}^F,P_{\phi}^{CF})。训练就是要使这两个值尽可能的小。

实验

实验方面采用经典数据集Infant Health and Development Program(IHDP)。这个数据集具有来自真实随机实验的协变量,研究了高质量的儿童保育和家访对未来认知测试分数的影响。通过移除部分实验组,人为引入数据分布的不平衡。作者进行了100次重复的超参数选择和1000次重复实验,得到以下结果。需要说明的是,下图中ϵITE\epsilon_{ITE}指的是RMSE of the estimate individual treatment effect. ϵATE\epsilon_{ATE}指的是absolute error in estimated average treatment effect。数值越低越好。 image.png

参考

  1. zhuanlan.zhihu.com/p/425331915
  2. zhuanlan.zhihu.com/p/474401295
  3. proceedings.mlr.press/v48/johanss…