机器学习:sklearn中xgboost模块的XGBClassifier函数(分类)

参考网址

blog.csdn.net/levy_cui/ar…

前言

我们都知道,xgboost算法有2大类接口。
一个是陈天奇团队开发的原生接口xgboost。
另一个是sklearn中的xgboost接口。
本文,是讲解sklearn接口中的分类函数XGBClassifier。
复制代码

常规参数

  • booster

    • gbtree 树模型做为基分类器(默认),及弱学习器的类型,这里默认是cart分类回归决策树
    • gbliner 线性模型做为基分类器
  • silent

    • silent=0时,不输出中间过程(默认)
    • silent=1时,输出中间过程
  • nthread

    • nthread=-1时,使用全部CPU进行并行运算(默认)
    • nthread=1时,使用1个CPU进行运算。
  • scale_pos_weight

    • 正样本的权重,在二分类任务中,当正负样本比例失衡时,设置正样本的权重,模型效果更好。
    • 例如,当正负样本比例为1:10时,scale_pos_weight=10,这样设置之后,正负样本比例就一样了。

模型参数

  • n_estimatores

    • 含义:总共迭代的次数,即决策树的个数
    • 调参:这个参数要大一点好,在不断增大该参数的时候,模型效果会变好, 但是当增大到一定的值的时候,模型效果就不会有明显的效果了,还可能会变坏。
  • early_stopping_rounds

    • 含义:在验证集上,当连续n次迭代,分数没有提高后,提前终止训练。
    •     这里的n次就是early_stopping_rounds值。
      复制代码
    • 调参:防止overfitting。
  • max_depth

    • 含义:树的深度,默认值为6,典型值3-10。
    • 调参:值越大,越容易过拟合;值越小,越容易欠拟合。
  • min_child_weight

    • 含义:默认值为1,。
    • 调参:值越大,越容易欠拟合;值越小,越容易过拟合(值较大时,避免模型学习到局部的特殊样本)。
  • subsample

    • 含义:训练每棵树时,使用的数据占全部训练集的比例。默认值为1,典型值为0.5-1。
    • 调参:防止overfitting。
  • colsample_bytree

    • 含义:训练每棵树时,使用的特征占全部特征的比例。默认值为1,典型值为0.5-1。
    • 调参:防止overfitting。

学习任务参数

  • learning_rate

    • 含义:学习率,控制每次迭代更新权重时的步长,默认0.3。
    • 调参:值越小,训练越慢。
    • 典型值为0.01-0.2。
  • objective 目标函数

    • 回归任务

      • reg:linear (默认)
      • reg:logistic
    • 二分类

      • binary:logistic   返回概率 
      • binary:logitraw  返回类别
    • 多分类

      • multi:softmax  num_class=n   返回类别
      • multi:softprob   num_class=n  返回概率
    • rank:pairwise

  • eval_metric

    • 回归任务(默认rmse)

      • rmse--均方根误差
      • mae--平均绝对误差
    • 分类任务(默认error)

      • auc--roc曲线下面积
      • error--错误率(二分类)
      • merror--错误率(多分类)
      • logloss--负对数似然函数(二分类)
      • mlogloss--负对数似然函数(多分类)
      •  
  • gamma

    • 惩罚项系数,指定节点分裂所需的最小损失函数下降值。
    • 调参:
  • alpha

    • L1正则化系数,默认为1
  • lambda

    • L2正则化系数,默认为1

 

代码主要函数:

-   载入数据:load_digits()
-   数据拆分:train_test_split()
-   建立模型:XGBClassifier()
-   模型训练:fit()
-   模型预测:predict()
-   性能度量:accuracy_score()
-   特征重要性:plot_importance()
复制代码

demo

# -*- coding: utf-8 -*-

# load module
from xgboost.sklearn import XGBClassifier
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from xgboost import plot_importance
import matplotlib.pyplot as plt

# load datasets
data = datasets.load_digits()
print(data.data)
print(data.target)

# data analysis
print(data.data.shape)
print(data.target.shape)

# data split
x_train, x_test, y_train, y_test = train_test_split(data.data,
                                                    data.target,
                                                    test_size=0.3,
                                                    random_state=33)

# fit model for train data
# 所有的参数都是放在XGBClassifier()这个类中,得到初始化的模型对象xgb_class_model
# 建立模型
xgb_class_model = XGBClassifier(learning_rate=0.1,
                      n_estimators=1000,  # 树的个数--1000棵树建立xgboost
                      max_depth=6,  # 树的深度
                      min_child_weight=1,  # 叶子节点最小权重
                      gamma=0.,  # 惩罚项中叶子结点个数前的参数
                      subsample=0.8,  # 随机选择80%样本建立决策树
                      colsample_btree=0.8,  # 随机选择80%特征建立决策树
                      objective='multi:softmax',  # 指定目标函数,多分类
                      scale_pos_weight=1,  # 解决样本个数不平衡的问题
                      random_state=27  # 随机数
                      )
# 训练模型
xgb_class_model.fit(x_train,
                    y_train,
                    eval_set=[(x_test, y_test)],
                    eval_metric="mlogloss",
                    early_stopping_rounds=10,
                    verbose=True)


# plot feature importance
fig, ax = plt.subplots(figsize=(15, 15))
plot_importance(model,
                height=0.5,
                ax=ax,
                max_num_features=64)
plt.show()

# make prediction for test data
y_pred = xgb_class_model.predict(x_test)  
# 这里是直接给出类型,predict_proba()函数是给出属于每个类别的概率。

# model evaluate
accuracy = accuracy_score(y_test, y_pred)
print("accuarcy: %.2f%%" % (accuracy * 100.0))

复制代码
分类:
人工智能
标签: