一、预测死亡-心力衰竭患者模型建立
1.数据集简介
-
心血管疾病 (CVD) 是全球第一大死因,估计每年夺去 1790 万人的生命,占全球所有死亡人数的 31%。
-
心力衰竭是由 CVD 引起的常见事件,该数据集包含 12 个可用于预测心力衰竭死亡率的特征。
-
大多数心血管疾病可以通过使用全民策略解决烟草使用、不健康饮食和肥胖、缺乏身体活动和有害使用酒精等行为风险因素来预防。
-
患有心血管疾病或处于高心血管风险(由于存在一种或多种风险因素,如高血压、糖尿病、高脂血症或已经确定的疾病)的人需要早期检测和管理,其中机器学习模型可以提供很大帮助。
2.scikiti-survival库的简介
-
scikit-survival 是一个 基于scikit-learn构建的用于生存分析的 Python 模块。它允许在利用 scikit-learn 的强大功能的同时进行生存分析,例如,用于预处理或进行交叉验证。
-
生存分析(也称为事件发生时间或可靠性分析)的目标是在协变量和事件发生时间之间建立联系。生存分析与传统机器学习的不同之处在于,部分训练数据只能部分观察——它们被删减了。
-
例如,在临床研究中,通常会在特定时间段内监测患者,并记录在该特定时间段内发生的事件。如果患者经历了事件,则可以记录事件的确切时间——患者的记录未经审查。相反,右截尾记录指的是在研究期间保持无事件的患者,并且不知道研究结束后事件是否发生。因此,生存分析需要考虑此类数据集的这一独特特征的模型。
文档: [User Guide — scikit-survival 0.20.1 scikit-survival.readthedocs.io/en/latest/u… Guide — scikit-survival 0.20.1 scikit-survival.readthedocs.io/en/latest/u…)
3.超参数调优框架optuna库的简介
optuna 是一个十分常用的超参数调优框架,具有操作简单,嵌入式强和动态调整参数空间等优点。
二、环境构设
from IPython.display import clear_output
!pip install scikit-survival
!pip install optuna
clear_output() # 清理很长的内容
三、数据处理
1.数据查看
import pandas as pd
data=pd.read_csv('data/data209679/heart_failure_clinical_records_dataset.csv')
data.info()
data.head()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 299 entries, 0 to 298
Data columns (total 13 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 age 299 non-null float64
1 anaemia 299 non-null int64
2 creatinine_phosphokinase 299 non-null int64
3 diabetes 299 non-null int64
4 ejection_fraction 299 non-null int64
5 high_blood_pressure 299 non-null int64
6 platelets 299 non-null float64
7 serum_creatinine 299 non-null float64
8 serum_sodium 299 non-null int64
9 sex 299 non-null int64
10 smoking 299 non-null int64
11 time 299 non-null int64
12 DEATH_EVENT 299 non-null int64
dtypes: float64(3), int64(10)
memory usage: 30.5 KB
| age | anaemia | creatinine_phosphokinase | diabetes | ejection_fraction | high_blood_pressure | platelets | serum_creatinine | serum_sodium | sex | smoking | time | DEATH_EVENT | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 75.0 | 0 | 582 | 0 | 20 | 1 | 265000.00 | 1.9 | 130 | 1 | 0 | 4 | 1 |
| 1 | 55.0 | 0 | 7861 | 0 | 38 | 0 | 263358.03 | 1.1 | 136 | 1 | 0 | 6 | 1 |
| 2 | 65.0 | 0 | 146 | 0 | 20 | 0 | 162000.00 | 1.3 | 129 | 1 | 1 | 7 | 1 |
| 3 | 50.0 | 1 | 111 | 0 | 20 | 0 | 210000.00 | 1.9 | 137 | 1 | 0 | 7 | 1 |
| 4 | 65.0 | 1 | 160 | 1 | 20 | 0 | 327000.00 | 2.7 | 116 | 0 | 0 | 8 | 1 |
-
生存类数据,样本量小,使用交叉验证方法
-
✔️构建预测模型则用scikit-survival文库,这里可以预测未发生死亡事件的人群的死亡时间(从随访起点算起)。
2.X,y构建
from sksurv.util import Surv
from sksurv.ensemble import RandomSurvivalForest
from sklearn.impute import SimpleImputer
data['DEATH_EVENT']=[True if x==1 else 0 for x in data['DEATH_EVENT']]
y=Surv.from_dataframe(event='DEATH_EVENT',time='time',data=data)
cat_cols=['anaemia','diabetes','high_blood_pressure','sex','smoking']
data[cat_cols]=data[cat_cols].astype('category')
X=data.drop(['DEATH_EVENT','time'],axis=1)
X.head()
| age | anaemia | creatinine_phosphokinase | diabetes | ejection_fraction | high_blood_pressure | platelets | serum_creatinine | serum_sodium | sex | smoking | |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 75.0 | 0 | 582 | 0 | 20 | 1 | 265000.00 | 1.9 | 130 | 1 | 0 |
| 1 | 55.0 | 0 | 7861 | 0 | 38 | 0 | 263358.03 | 1.1 | 136 | 1 | 0 |
| 2 | 65.0 | 0 | 146 | 0 | 20 | 0 | 162000.00 | 1.3 | 129 | 1 | 1 |
| 3 | 50.0 | 1 | 111 | 0 | 20 | 0 | 210000.00 | 1.9 | 137 | 1 | 0 |
| 4 | 65.0 | 1 | 160 | 1 | 20 | 0 | 327000.00 | 2.7 | 116 | 0 | 0 |
四、模型构建和评价
1.超参数搜索
# pipe-line
from sklearn.pipeline import make_pipeline
from sksurv.ensemble import RandomSurvivalForest
from sklearn.preprocessing import RobustScaler,StandardScaler,MinMaxScaler,OneHotEncoder
from sklearn.model_selection import cross_val_score
from sklearn.compose import make_column_transformer
from sklearn.compose import make_column_selector as selector
import optuna
import numpy as np
def objective(trial):
n_estimators=trial.suggest_int('n_estimators',100,1000,10)
min_sample_split=trial.suggest_int('min_sample_split',1,29,2)
min_sample_leaf=trial.suggest_int('min_sample_leaf',1,29,2)
preprocessor=make_column_transformer((RobustScaler(),selector(dtype_include='number')))
rsf=make_pipeline(preprocessor, RandomSurvivalForest(n_estimators=n_estimators,
min_samples_split=10,
min_samples_leaf=15,
n_jobs=-1,
random_state=0))
scores=cross_val_score(rsf,X,y)
return np.mean(scores)
study=optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=100)
study.best_params
[I 2023-04-17 00:42:44,270] Trial 95 finished with value: 0.7495800673166991 and parameters: {'n_estimators': 230, 'min_sample_split': 15, 'min_sample_leaf': 5}. Best is trial 37 with value: 0.7540914188972966.
[I 2023-04-17 00:42:47,104] Trial 96 finished with value: 0.7488762391145961 and parameters: {'n_estimators': 280, 'min_sample_split': 13, 'min_sample_leaf': 9}. Best is trial 37 with value: 0.7540914188972966.
[I 2023-04-17 00:42:49,221] Trial 97 finished with value: 0.7502194611898544 and parameters: {'n_estimators': 200, 'min_sample_split': 13, 'min_sample_leaf': 11}. Best is trial 37 with value: 0.7540914188972966.
[I 2023-04-17 00:42:51,100] Trial 98 finished with value: 0.7536458218120432 and parameters: {'n_estimators': 160, 'min_sample_split': 21, 'min_sample_leaf': 1}. Best is trial 37 with value: 0.7540914188972966.
[I 2023-04-17 00:42:52,665] Trial 99 finished with value: 0.7523008612365734 and parameters: {'n_estimators': 120, 'min_sample_split': 25, 'min_sample_leaf': 3}. Best is trial 37 with value: 0.7540914188972966.
2.模型训练
#best_model在后续预测中使用cindex=0.73
preprocessor=make_column_transformer((RobustScaler(),selector(dtype_include='number')))
rsf_best=make_pipeline(preprocessor, RandomSurvivalForest(n_estimators=170,
min_samples_split=15,
min_samples_leaf=25,
n_jobs=-1,
random_state=0))
rsf_best.fit(X,y)
import joblib
joblib.dump(rsf_best,'rsf_best.pkl')
['rsf_best.pkl']
3.模型预测
#限制累积风险为1,获得对应的时间。
va_times=np.arange(4,241)
data_pre=data[data['DEATH_EVENT']!=True].drop(['DEATH_EVENT','time'],axis=1)
chf_funcs = rsf_best.predict_cumulative_hazard_function(data_pre)#产生对所有的test的风险函数,只需传入时间参数即可获得累积风险
outcome_period=[]
for fn in chf_funcs:#
if fn(va_times[-1])<1:#在最后的预测时间内死亡全部累计概率不到0.6
time_value=999
else:
for time in va_times:
if fn(time)>1:
time_value=time#发生结局的最短时间
break
# print(time)
outcome_period.append(time_value)
outcome_predict=data_pre.copy()
outcome_predict['outcome_period']=outcome_period
result=outcome_predict[outcome_predict['outcome_period']!=999]['outcome_period']
4.保存结果
patient_id=result.index
patient_surv_month=result.values
for i,x in zip(patient_id,patient_surv_month):
print('{}号患者死亡的时间为{}个月时。'.format(i,x))
#这里的时间计算开始是从患者入组时间开始算起,不是当下日期。
20号患者死亡的时间为235个月时。
38号患者死亡的时间为198个月时。
89号患者死亡的时间为235个月时。
96号患者死亡的时间为235个月时。
98号患者死亡的时间为235个月时。
100号患者死亡的时间为235个月时。
102号患者死亡的时间为235个月时。
112号患者死亡的时间为235个月时。
117号患者死亡的时间为198个月时。
131号患者死亡的时间为198个月时。
137号患者死亡的时间为193个月时。
155号患者死亡的时间为235个月时。
157号患者死亡的时间为235个月时。
173号患者死亡的时间为235个月时。
190号患者死亡的时间为180个月时。
198号患者死亡的时间为235个月时。
203号患者死亡的时间为196个月时。
210号患者死亡的时间为235个月时。
223号患者死亡的时间为235个月时。
224号患者死亡的时间为235个月时。
226号患者死亡的时间为235个月时。
228号患者死亡的时间为196个月时。
229号患者死亡的时间为235个月时。
247号患者死亡的时间为180个月时。
281号患者死亡的时间为207个月时。
282号患者死亡的时间为198个月时。
6.预测个案
加载存储的模型,然后进行预测
def survival_time(model,patient):
chf_funcs=model.predict_cumulative_hazard_function(patient)
for fn in chf_funcs:#
if fn(va_times[-1])<1:#在最后的预测时间内死亡全部累计概率不到0.6
time_value=999
print('该患者在241个月内未预测到因疾病原因的死亡')
else:
for time in va_times:
if fn(time)>1:
time_value=time#发生结局的最短时间
break
print('该患者预测在{}月时因疾病原因死亡'.format(time))
#加载储存的模型
model=joblib.load('rsf_best.pkl')
#输入患者数据,我们这里加载了20号患者,可以看到和前面的批量预测是一致的。
patient=data_pre[data_pre.index==20]
print(patient)
#预测死亡时间
survival_time(model,patient)
age anaemia creatinine_phosphokinase diabetes ejection_fraction \
20 65.0 1 52 0 25
high_blood_pressure platelets serum_creatinine serum_sodium sex smoking
20 1 276000.0 1.3 137 0 0
该患者预测在235月时因疾病原因死亡
本文正在参加「金石计划」