Preface
In scikit-learn, gradient boosting for regression is provided by GradientBoostingRegressor. Its constructor parameters and their default values are:
from sklearn.ensemble import GradientBoostingRegressor

GradientBoostingRegressor(alpha=0.9,
                          ccp_alpha=0.0,
                          criterion='friedman_mse',
                          init=None,
                          learning_rate=0.1,
                          loss='ls',
                          max_depth=3,
                          max_features=None,
                          max_leaf_nodes=None,
                          min_impurity_decrease=0.0,
                          min_impurity_split=None,
                          min_samples_leaf=1,
                          min_samples_split=2,
                          min_weight_fraction_leaf=0.0,
                          n_estimators=100,
                          n_iter_no_change=None,
                          presort='deprecated',
                          random_state=None,
                          subsample=1.0,
                          tol=0.0001,
                          validation_fraction=0.1,
                          verbose=0,
                          warm_start=False)
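The same defaults can also be read back from an instance at runtime. A minimal sketch (not tied to the demo further below):

from sklearn.ensemble import GradientBoostingRegressor

# get_params() returns every constructor parameter with its current value;
# for a freshly created instance these are exactly the defaults listed above.
params = GradientBoostingRegressor().get_params()
for name, value in sorted(params.items()):
    print(name, '=', value)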
loss:
String, default 'ls'. The available options are:
'ls': squared loss; the model is fit by least-squares regression.
'lad': least absolute deviation, i.e. the absolute-error loss.
'huber': a combination of the two above; squared loss is used when the error is small and absolute loss when the error is large.
'quantile': quantile regression (a quantile being a given percentile of the target distribution); it uses the quantile (pinball) loss, with the quantile set through the alpha parameter (see the sketch below).
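Both 'huber' and 'quantile' rely on the alpha parameter shown in the signature above. As an illustration, here is a minimal sketch (synthetic data, not part of the original demo) that fits two quantile-loss models to obtain a rough 80% prediction interval:

import numpy as np
from sklearn.ensemble import GradientBoostingRegressor

# Synthetic 1-D regression data, purely for illustration.
rng = np.random.RandomState(0)
X = np.sort(rng.uniform(0, 10, size=(200, 1)), axis=0)
y = np.sin(X).ravel() + rng.normal(scale=0.3, size=200)

# With loss='quantile', alpha is the target quantile: 0.9 estimates an
# upper bound, 0.1 a lower bound.
upper = GradientBoostingRegressor(loss='quantile', alpha=0.9, random_state=0).fit(X, y)
lower = GradientBoostingRegressor(loss='quantile', alpha=0.1, random_state=0).fit(X, y)

X_new = np.linspace(0, 10, 5).reshape(-1, 1)
print('upper:', upper.predict(X_new))
print('lower:', lower.predict(X_new))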
criterion:
String, default 'friedman_mse'; it is the function used to measure the quality of a candidate split when growing each tree. The available options are:
'friedman_mse': mean squared error with Friedman's improvement score
'mse': plain mean squared error
'mae': mean absolute error
Apart from these two parameters, the remaining parameters, attributes, and methods have essentially the same meaning and usage as those of the GBDT classifier; a short sketch of a typical configuration follows.
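A minimal sketch (on synthetic data, not part of the original demo) showing criterion set together with a few of those shared parameters:

from sklearn.datasets import make_regression
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import train_test_split

# Illustrative data; any numeric regression data set works the same way.
X, y = make_regression(n_samples=500, n_features=10, noise=10.0, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# criterion scores candidate splits inside each tree; n_estimators,
# learning_rate and max_depth behave exactly as in the GBDT classifier.
reg = GradientBoostingRegressor(criterion='friedman_mse',
                                n_estimators=200,
                                learning_rate=0.1,
                                max_depth=3,
                                random_state=42)
reg.fit(X_train, y_train)
print('R^2 on the held-out split:', reg.score(X_test, y_test))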
Code Implementation
GitHub:
https://github.com/NLP-LOVE/ML-NLP/blob/master/Machine%20Learning/3.2%20GBDT/GBDT_demo.ipynb
import numpy as np
import pandas as pd
from sklearn.ensemble import GradientBoostingRegressor

# Load the training data; the last column of each row is the regression target.
train_feature = np.genfromtxt("train_feat.txt", dtype=np.float32)
num_feature = len(train_feature[0])
train_feature = pd.DataFrame(train_feature)
train_label = train_feature.iloc[:, num_feature - 1]
train_feature = train_feature.iloc[:, 0:num_feature - 1]

# Load the test data in the same format.
test_feature = np.genfromtxt("test_feat.txt", dtype=np.float32)
num_feature = len(test_feature[0])
test_feature = pd.DataFrame(test_feature)
test_label = test_feature.iloc[:, num_feature - 1]
test_feature = test_feature.iloc[:, 0:num_feature - 1]

gbdt = GradientBoostingRegressor(loss='ls',
                                 learning_rate=0.1,
                                 n_estimators=100,
                                 subsample=1,
                                 min_samples_split=2,
                                 min_samples_leaf=1,
                                 max_depth=3,
                                 init=None,
                                 random_state=None,
                                 max_features=None,
                                 alpha=0.9,
                                 verbose=0,
                                 max_leaf_nodes=None,
                                 warm_start=False)

gbdt.fit(train_feature, train_label)
pred = gbdt.predict(test_feature)

# Compare each prediction with its true label, then report the RMSE.
for i in range(pred.shape[0]):
    print('pred:', pred[i], ' label:', test_label[i])
print('RMSE:', np.sqrt(((pred - test_label) ** 2).mean()))
pred: 320.0008173984891 label: 320.0
pred: 360.99965033119537 label: 361.0
pred: 363.99928183902097 label: 364.0
pred: 336.0002344322584 label: 336.0
pred: 358.0000159974151 label: 358.0
RMSE: 0.0005218003748239915
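The value printed above is the square root of the mean squared error, i.e. the RMSE. It can be cross-checked with sklearn.metrics; a small sketch assuming pred and test_label from the demo are still in scope:

import numpy as np
from sklearn.metrics import mean_squared_error

# mean_squared_error gives the MSE; its square root is the RMSE printed above.
rmse = np.sqrt(mean_squared_error(test_label, pred))
print('RMSE:', rmse)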