决策树 / 判断树 / Decision Tree

252 阅读3分钟

一、概述

  1. 是什么:
    • 一种用于分类和回归的非参数监督学习方法
  2. 目标:
    • 通过学习从数据特征中推断出的简单决策规则,创建一个能够预测目标变量值的模型
  3. 特点:
    • 树越深,决策规则越复杂,模型越拟合数据
    • 决策树的生成是一个递归过程
  4. 优点:
    • 能通过可视化方式呈现
    • 数据准备工作较少
    • 能处理数值和多输出问题、分类数据
  5. 缺点:
    • 可能产生过拟合问题
    • 不稳定
    • 预测不平滑也不连续
  6. 度量样本集合纯度的常用指标
    • 信息熵: image.png
    • 基尼值: image.png
  7. 决策树的划分属性选择指标:这些对决策树的尺寸有较大影响,但对泛化影响有限
    • 信息增益: image.png
    • 增益率:对可取值较少的属性有所偏好 image.png
    • 基尼指数: image.png
  8. 处理“过拟合”的方法:
    • 剪枝:
      • 基本策略:
        • 预剪枝:
          • 是什么:
            • 在决策树生成过程中,对每个结点在划分前先进行估计,若当前结点的划分不能带来决策树泛化性能提升,则停止划分并将当前结点标记为叶结点
        • 后剪枝:
          • 是什么:
            • 先从训练集生成一颗完整的决策树,然后自底向上地对非叶节点进行考察,若将该结点对应地子树替换为叶结点能带来决策树泛化性能提升,则将该子树替换为叶结点
  9. 判断决策树泛化性能是否提升:
    • 留出法
  10. 对连续属性的结点进行划分:
    • 连续属性离散化技术:
      • 二分法 image.png
  11. 多变量决策树
    • 算法举例:
      • OCI
  12. 决策树学习算法里最著名的代表:
    • ID3
    • CART

二、基于决策树算法的模型

  1. DecisonTreeClassifier
    • 是什么:
      • 一个用于分类问题的模型
    • 接收的输入:
      • 两个数组:
        • 一个形状为(n_samples,n_features)的数组X:用于保存训练样本
        • 一个形状为(n_samples)的整型数组Y,用于保存训练样本的类别标签
    • 用途:
      • 可以将数据集分成不同的类别或标签
    • 特点:
      • 在训练时,寻找最佳的特征和分割点,将数据划分成不同的类别,使得每个子集内的样本尽可能属于同一类别
    • 适用于:
      • 离散型输出(比如:预测某个物体是哪个类别)
  2. DecisionTreeRegressor
    • 是什么:
      • 用于回归问题的模型
    • 接受的输入同DecisionTreeClassifier
    • 用途:
      • 预测连续型输出
    • 特点:
      • 在训练时,寻找最佳的特征和分割点,使得每个子集内的样本的输出值尽可能接近
    • 适用于:
      • 连续型输出,(比如:预测一个数值)

三、SKlearn里如何构建可视化决策树

# 加载数据集
from sklearn.datasets import 数据集名称
# 加载决策树分类器
from sklearn.tree import DecisionTreeClassifier
# 加载绘图工具
import matplotlib.pyplot as plt
# 加载导出决策树为文本的工具
from sklearn.tree import export_text
# 保存数据集
自定义数据集名称 = 加载的数据集名称()
# X,y两个数组
X, y = 自定义数据集名称.data, 自定义数据集名称.target
# 创建分类器实例
分类器名称 = tree.DecisionTreeClassifier()
# 将分类器特征和目标标签传递给决策树分类器,用于训练决策树模型
分类器名称 = 分类器名称.fix(X,y)
# 预测样本类别
分类器名称.predict()
# 预测每个类别的概率
类别概率.predict_proba()
# 使用plot_tree函数绘制决策树
tree.plot_tree(分类器名称, feature_names = 数据集.feature_names, calss_names = 数据集.target_names, filled = True)
plt.show()
# 使用export_tree函数以文本格式导出决策树
any_name = export_text(分类器名称, feature_names=数据集名称.feature_name)
print(any_name)

参考资源:
  • 西瓜书周志华《机器学习》--决策树
  • SKlearn官网--决策树
  • chatgpt