生存分析:利用深度学习进行时间到事件的预测

1,005 阅读10分钟

再住院的实际应用

生存模型非常适合预测事件发生的时间。这些模型可用于各种用例,包括预测性维护(预测机器何时可能发生故障)、营销分析(预测客户流失)、患者监测(预测患者可能再次住院)、以及更多。

通过将机器学习与生存模型相结合,产生的模型可受益于前者的高预测能力,同时保留后者的框架和典型输出(如随时间变化的生存概率或危险曲线)。欲了解更多信息,请查看本系列的第一篇文章。

然而,在实践中,基于ML的生存模型仍然需要大量的特征工程,从而需要事先的商业知识和直觉来导致满意的结果。那么,为什么不使用深度学习模型来弥补这一差距呢?

目标

本文重点讨论如何将深度学习与生存分析框架相结合,以解决诸如预测病人(重新)住院的可能性等用例。

阅读本文后,你会明白:

  1. 如何利用深度学习进行生存分析?
  2. 生存分析中常见的深度学习模型有哪些,它们是如何工作的?
  3. 如何将这些模型具体应用到住院预测中?

本文是围绕生存分析的系列文章的第二部分。如果你对生存分析不熟悉,最好先阅读第一篇 这里.文章中描述的实验是通过使用库 scikit-survival, pycox plotly.你可以在这里找到这些代码 GitHub.

1.生存分析和深度学习:如何将它们结合起来?

1.1.问题陈述

让我们先来描述一下手头的问题。

我们对预测一个给定的病人在其健康状况的现有信息下再次住院的可能性很感兴趣。更具体地说,我们想在上次就诊后的不同时间点上估计这个概率。这样的估计对于监测病人的健康状况和减轻他们的复发风险至关重要。

这是一个典型的生存分析问题。数据由3个元素组成:

病人的基线数据,包括:

  • 人口统计学:年龄、性别、地点(农村或城市)
  • 患者病史:吸烟、酗酒、糖尿病、高血压等。
  • 实验室结果:血红蛋白、总淋巴细胞数、血小板、葡萄糖、尿素、肌酐等。
  • 关于源数据集的更多信息,请点击这里

一个时间t和一个事件指标δ∈{0;1}:

  • 如果事件发生在观察期内,t等于收集数据的时刻和观察到事件(即再入院)的时刻之间的时间,在这种情况下,δ=1。
  • 如果不是,t等于从收集数据的时刻到最后一次与病人接触(如研究结束)的时间。在这种情况下,δ=0。

image.png

图1 - 生存分析数据,作者的插图。注:病人A和C是删减的。

⚠️ 有了这样的描述,既然问题与回归任务如此相似,为什么还要使用生存分析方法?最初的论文对主要原因给出了一个相当好的解释:

"如果选择使用标准的回归方法,右删减的数据就成为一种缺失数据。它通常被删除或估算,这可能会给模型带来偏差。因此,为右删减数据建模需要特别注意,因此要使用生存模型"。资料来源[2]

1.2.DeepSurv

方法

让我们进入理论部分,对危险函数进行一下回顾。

"危险函数是指一个人在已经存活到时间t的情况下,不能再存活额外的无限长的时间δ的概率。

image.png

来源[2]

与Cox比例危害(CPH)模型类似,DeepSurv是基于危害函数是2个函数的乘积这一假设:

  • **基线危险函数:**λ_0(t)
  • 风险分数,r(x)=exp(h(x))。它模拟了在观察到的协变量的情况下,一个给定个体的危险函数是如何从基线变化的。

更多关于CPH模型的内容请参见本系列文章的第一篇

image.png

函数h(x)通常被称为对数风险函数。而这正是Deep Surv模型所要模拟的函数。

事实上,CPH模型假设h(x)是一个线性函数:h(x)= β . x。因此,拟合模型包括计算权重β以优化目标函数。然而,线性比例危害的假设在许多应用中并不成立。这证明需要一个更复杂的非线性模型,最好能够处理大量的数据。

架构

在这种情况下,DeepSurv模型如何能提供一个更好的选择?让我们从描述它开始。根据最初的论文,它是一个 "深度前馈神经网络,预测病人的协变量对其危险率的影响,参数是网络的权重θ"。[2]

它是如何工作的?

网络的输入是基线数据x。

隐层由全连接的非线性激活函数组成,然后是滤波。

最后一层是一个单节点,对隐藏特征进行线性组合。网络的输出被视为预测的对数风险函数。

来源于[2]

image.png

图2-DeepSurv架构,作者的插图,灵感来自来源[2]

由于这种架构,该模型非常灵活。超参数搜索技术通常被用来确定隐藏层的数量、每层的节点数量、辍学概率和其他设置。

要优化的目标函数呢?

  • CPH模型的训练是为了优化Cox部分似然。它包括为每个病人i在时间Ti计算事件发生的概率,考虑到所有在时间Ti仍处于风险的个体,然后将所有这些概率相乘。你可以在这里找到准确的数学公式[2]。
  • 同样,DeepSurv的目标函数是相同的部分似然的对数-负值,还有一个额外的部分,用于规范网络权重。[2]

代码样本

这是一个小代码片段,可以了解如何使用pycox库实现这种类型的模型。完整的代码可以在此处[6]库的笔记本示例中找到。

 in_features = x_train.shape[ 1 ] num_nodes = [ 32 , 32 ] out_features = 1 batch_norm = True dropout = 0.1 output_bias = False net = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm,                               dropout, output_bias=output_bias) model = CoxPH(net, tt.optim.Adam) batch_size = 256 epochs = 512 callbacks = [tt.callbacks .EarlyStopping()]详细 =True model.optimizer.set_lr( 0.01 ) log = model.fit(x_train, y_train, batch_size, epochs, callbacks, verbose,                 val_data=val, val_batch_size=batch_size) _ = model.compute_baseline_hazards() surv = model .predict_surv_df(x_test) ev = EvalSurv(surv, durations_test, events_test, censor_surv= 'km' ) ev.concordance_td()

1.3. 深击

方法

如果我们可以训练一个直接学习它们的深度神经网络,而不是对生存时间的分布做出强有力的假设,会怎样?

DeepHit模型就是这种情况。特别是,它比以前的方法带来了两个重大改进:。

  • 它不依赖于任何关于基础随机过程的假设。因此,该网络学会了对协变量和风险之间的关系随时间的演变进行建模。
  • 它可以通过多任务学习架构处理竞争性风险(例如,同时对重新住院和死亡的风险进行建模)。

架构

如本文所述[3],DeepHits遵循多任务学习模型的常见架构。它由两个主要部分组成:

  1. 一个共享子网络,模型从数据中学习对所有任务有用的一般表示。
  2. 特定任务的子网络,模型在这里学习更多特定任务的表征。

然而,DeepHit模型的架构在两个方面与典型的多任务学习模型不同:

  • 它包括初始协变量和特定任务子网络的输入之间的剩余连接。
  • 它只使用一个softmax输出层。正因为如此,该模型不学习竞争事件的边际分布,而是学习联合分布。

下面的数字显示了在两个任务上同时训练模型的情况。

DeepHit模型的输出是每个主题的一个向量y。它给出了在观察时间内的每个时间戳t,主体将经历事件k∈[1,2]的概率。

image.png

图3 - DeepHit架构,作者的插图,灵感来自于来源[4] 。

2.用例应用:这些模型在实践中的表现如何?

2.1.方法论

数据

数据集被分为三部分:训练集(60%的数据)、验证集(20%)和测试集(20%)。训练集和验证集用于在训练期间优化神经网络,测试集用于最终评估。

基准测试

深度学习模型的性能与包括CoxPH和基于ML的生存模型(梯度提升和SVM)在内的基准模型进行了比较。关于这些模型的更多信息可在本系列的第一篇文章中找到。

衡量标准

有两个指标被用来评估这些模型:

  • 一致性指数(C-index):它衡量的是模型根据个人风险评分提供可靠的生存时间排名的能力。它被计算为数据集中一致对的比例。
  • Brier评分:它是平均平方误差对右删减数据的一个随时间变化的扩展。换句话说,它代表观察到的生存状态和预测的生存概率之间的平均平方距离。

2.2.结果

从C-index来看,深度学习模型的性能大大优于基于ML的生存分析模型。此外,Deep Surval和Deep Hit模型的性能几乎没有差别。

image.png

图4 - 模型在训练集和测试集上的C指数

在Brier得分方面,Deep Surv模型从其他模型中脱颖而出。

  • 当检查Brier分数作为时间的函数的曲线时,Deep Surv模型的曲线比其他模型低,这反映了更好的准确性。

image.png

图5- 测试集上的Brier得分

  • 在考虑同一时间区间内的积分时,这一观察结果得到了证实。

image.png

图6 - 测试集上的综合布赖尔分数

请注意,Brier并不是为SVM计算的,因为这个分数只适用于能够估计生存函数的模型。

image.png

图7- 使用DeepSurv模型随机选择的患者的生存曲线

最后,深度学习模型和统计模型一样,可以用于生存分析。例如,在这里,我们可以看到随机选择的病人的生存曲线。这样的输出可以带来很多好处,特别是允许对风险最大的患者进行更有效的跟踪。

主要启示

✔️ 生存模型对于预测一个事件发生的时间非常有用。

✔️ 它们可以通过提供学习框架和技术以及有用的输出,如生存概率或随时间变化的危险曲线,帮助解决许多用例。

✔️ 在这种类型的用例中,它们甚至是不可或缺的,可以利用所有的数据,包括删减的观察值(例如,当事件在观察期间没有发生时)。

✔️ 基于ML的生存模型往往比统计模型表现更好(更多信息在这里)。然而,它们需要基于坚实的商业直觉的高质量特征工程来实现令人满意的结果。

✔️ 这就是深度学习可以弥补差距的地方。基于深度学习的生存模型,如DeepSurv或DeepHit,有可能以更少的努力获得更好的表现!

✔️ 尽管如此,这些模型也不是没有缺点的。它们需要大量的数据库进行训练,并需要对多个超参数进行微调。

参考文献

[1] Bollepalli, S.C.; Sahani, A.K.; Aslam, N.; Mohan, B.; Kulkarni, K.; Goyal, A.; Singh, B.; Singh, G.; Mittal, A.; Tandon, R.; Chhabra, S.T.; Wander, G.S.; Armoundas, A.A.一个优化的机器学习模型准确预测入院时心脏科的住院结果。诊断学2022,12,241。

[2] Katzman, J., Shaham, U., Bates, J., Cloninger, A., Jiang, T., & Kluger, Y. (2016)。DeepSurv:使用Cox Proportional Hazards深度神经网络的个性化治疗推荐系统,ArXiv

[3] Laura Löschmann、Daria Smorodina,用于生存分析深度学习,信息系统研讨会(WS19/20),2020年2月6日

[4] 李昌熙等。DeepHit:具有竞争风险的生存分析的深度学习方法。AAAI人工智能会议(2018)。

[5] 维基百科,比例风险模型

[6] Pycox库