随机森林总结 - RandomForestRegressor(二)
根据菜菜的课程进行整理,方便记忆理解
代码位置如下:
RandomForestRegressor
class sklearn.ensemble.RandomForestRegressor (
n_estimators=’warn’,criterion=’mse’,max_depth=None,min_samples_split=2,min_samples_leaf=1, min_weight_fraction_leaf=0.0,max_features=’auto’,max_leaf_nodes=None,min_impurity_decrease=0.0,min_impurity_split=None,bootstrap=True,oob_score=False,n_jobs=None,random_state=None, verbose=0, warm_start=False)
重要参数
criterion
回归树衡量分枝质量的指标,支持的标准有三种:
- 输入
mse使用均方误差mean squared error(MSE),父节点和叶子节点之间的均方误差的差额将被用来作为特征选择的标准,这种方法通过使用叶子节点的均值来最小化L2损失 - 输入
friedman_mse使用费尔德曼均方误差,这种指标使用弗里德曼针对潜在分枝中的问题改进后的均方误差 - 输入
mae使用绝对平均误差MAE(mean absolute error),这种指标使用叶节点的中值来最小化L1损失
其中N是样本数量,i是每一个数据样本,fi是模型回归出的数值,yi是样本点i实际的数值标签。所以MSE的本质,其实是样本真实数据与回归结果的差异。在回归树中,MSE不只是我们的分枝质量衡量指标,也是我们最常用的衡量回归树回归质量的指标,当我们在使用交叉验证,或者其他方式获取回归树的结果时,我们往往选择均方误差作为我们的评估(在分类树中这个指标是score代表的预测准确率)。在回归中,我们追求的是,MSE越小越好。
然而,回归树的接口score返回的是R平方,并不是MSE。R平方被定义如下:
其中u是残差平方和(MSE * N),v是总平方和,N是样本数量,i是每一个数据样本,fi是模型回归出的数值,yi是样本点i实际的数值标签。y帽是真实数值标签的平均数。R平方可以为正为负(如果模型的残差平方和远远大于模型的总平方和,模型非常糟糕,R平方就会为负),而均方误差永远为正。
值得一提的是,虽然均方误差永远为正,但是sklearn当中使用均方误差作为评判标准时,却是计算”负均方误差“(neg_mean_squared_error) 。这是因为sklearn在计算模型评估指标的时候,会考虑指标本身的性质,均方误差本身是一种误差,所以被sklearn划分为模型的一种损失(loss),因此在sklearn当中,都以负数表示。真正的均方误差MSE的数值,其实就是neg_mean_squared_error去掉负号的数字。
重要属性和接口
- 随机森林回归并没有predict_proba这个接口,因为对于回归来说,并不存在一个样本要被分到某个类别的概率问题,因此没有predict_proba这个接口
随机森林回归用法
和决策树完全一致,除了多了参数n_estimators
import sklearn
from sklearn.ensemble import RandomForestRegressor
from sklearn.datasets import load_boston
from sklearn.model_selection import cross_val_score
boston = load_boston()
rfr = RandomForestRegressor(n_estimators=100,random_state=0)
score = cross_val_score(rfr,boston.data,boston.target,cv=10,scoring="neg_mean_squared_error")
score
"""
array([-10.72900447, -5.36049859, -4.74614178, -20.84946337,
-12.23497347, -17.99274635, -6.8952756 , -93.78884428,
-29.80411702, -15.25776814])
"""
- 通过
metrics.SCORERS查看sklearn的打分机制,包括常见的accuracy,f1,neg_mean_squared_error等
sorted(sklearn.metrics.SCORERS.keys())
"""
['accuracy',
'adjusted_mutual_info_score',
'adjusted_rand_score',
'average_precision',
'balanced_accuracy',
'completeness_score',
'explained_variance',
'f1',
'f1_macro',
'f1_micro',
'f1_samples',
'f1_weighted',
'fowlkes_mallows_score',
'homogeneity_score',
'jaccard',
'jaccard_macro',
'jaccard_micro',
'jaccard_samples',
'jaccard_weighted',
'max_error',
'mutual_info_score',
'neg_brier_score',
'neg_log_loss',
'neg_mean_absolute_error',
'neg_mean_gamma_deviance',
'neg_mean_poisson_deviance',
'neg_mean_squared_error',
'neg_mean_squared_log_error',
'neg_median_absolute_error',
'neg_root_mean_squared_error',
'normalized_mutual_info_score',
'precision',
'precision_macro',
'precision_micro',
'precision_samples',
'precision_weighted',
'r2',
'recall',
'recall_macro',
'recall_micro',
'recall_samples',
'recall_weighted',
'roc_auc',
'roc_auc_ovo',
'roc_auc_ovo_weighted',
'roc_auc_ovr',
'roc_auc_ovr_weighted',
'v_measure_score']
"""