开发也能看懂的大模型:决策树

1,446 阅读8分钟

前情提要:后端开发自学AI知识,内容比较浅显,偏实战;仅适用于入门了解,解决日常沟通障碍。

决策树(Decision Tree)是一种树形结构模型,用于解决分类和回归问题。它通过一系列的规则将数据逐步划分,最终实现目标预测。决策树模型直观、易解释,是许多机器学习方法(如随机森林和梯度提升树)的基础。

1. 决策树的基本概念

  • 节点(Node)

    • 根节点:整个树的起点,包含全部数据。
    • 内部节点:每次分裂的节点,根据某个特征值进行划分。
    • 叶节点:树的终点,表示最终的分类或预测值。
  • 分裂规则: 每个内部节点选择一个特征和一个划分条件,将数据分为两部分,以最大化分类或回归效果。

  • 树的构建过程

    1. 从根节点开始,对数据集选择最佳特征进行划分。
    2. 对划分后的每个子集递归重复上述步骤。
    3. 当满足停止条件时(如达到树的最大深度或叶节点数据纯度足够高),停止分裂。

2. 决策树分类示例

以一个简单的分类问题为例:

  • 数据集包含特征:天气(晴、多云、雨)、温度(高、适中、低)、湿度(高、正常)、风(弱、强)。
  • 目标:预测是否适合外出活动。

树的部分结构:

根节点:天气
       ├── 晴 -> 湿度
       │      ├── 高 -> 不适合
       │      └── 正常 -> 适合
       ├── 多云 -> 适合
       └── 雨 -> 风
              ├── 弱 -> 适合
              └── 强 -> 不适合

通过上述树结构可以清晰地进行预测,比如:

  • 如果天气为“晴”,湿度为“高”,那么预测结果为“不适合”。

3. 决策树的构建核心

3.1 划分标准

决策树通过某种标准选择最佳特征进行分裂。

image.png

3.2 停止条件

  1. 达到预设的最大深度。
  2. 叶节点样本数少于某个阈值。
  3. 数据纯度达到一定标准。

4. 决策树的优缺点

优点

  • 易解释:树形结构直观,易于理解。
  • 无特征缩放要求:对输入数据无需归一化或标准化。
  • 处理非线性数据:通过分裂特征空间,能很好地适配非线性关系。
  • 支持多种数据类型:可以处理离散和连续数据。

缺点

  • 过拟合:如果树过深,可能会记住训练集中的噪声。
  • 对小变化敏感:输入数据的微小变化可能导致树结构的显著变化。
  • 无法建模复杂关系:当数据关系复杂时,单棵决策树可能性能不足。

5. 决策树的优化方法

  1. 剪枝(Pruning)

    • 预剪枝:在构建树的过程中限制树的深度或叶节点的最小样本数。
    • 后剪枝:构建完全树后,通过剪掉冗余节点来降低复杂度。
  2. 特征选择优化

    • 通过增加正则化项(如限制信息增益比或基尼系数阈值)避免过拟合。
  3. 集成方法

    • 随机森林:通过集成多棵决策树,提升模型的泛化能力。
    • 梯度提升树(GBDT) :利用多棵树的误差递归提升预测能力。

决策树应用场景

决策树是一种直观且高效的机器学习模型,广泛应用于分类回归规则推导等领域。以下是决策树常见的应用场景及其特点。

1. 分类问题

决策树常用于对离散目标变量进行分类,适合以下场景:

  • 医疗诊断

    • 通过患者症状和检查结果预测疾病类型。
    • 例如:根据是否发烧、咳嗽、胸部疼痛等信息预测是否患有肺炎。
  • 金融风控

    • 判断信用卡申请人是否具备还款能力(信用评级)。
    • 例如:根据用户收入、债务、还款历史等信息分类“高风险”或“低风险”。
  • 客户细分

    • 根据用户行为、兴趣等属性,将用户划分为不同的群体,用于精准营销。
    • 例如:根据购买记录预测用户是“新客户”还是“忠实客户”。
  • 垃圾邮件检测

    • 基于邮件的内容特征(关键词出现频率、发件人地址等)判断邮件是否为垃圾邮件。

2. 回归问题

决策树也可以处理连续目标变量的预测,常见场景包括:

  • 房地产价格预测

    • 根据房屋面积、位置、卧室数量等特征预测房价。
  • 能源消耗预测

    • 根据天气、使用时间、设备类型等特征预测某区域的能耗。
  • 股票市场分析

    • 通过历史交易数据、市场指标等特征预测股票价格走势。

3. 决策支持

决策树具有规则清晰、解释性强的特点,适合用于帮助人们制定决策:

  • 法律和规则自动化

    • 用于模拟法律规则,自动判定案件是否符合某些法律条文。
  • 项目管理

    • 根据项目预算、团队规模、风险系数等信息,决定项目是否启动。
  • 教育领域

    • 帮助制定学生学习策略。根据学习成绩、出勤率等信息,预测学生是否需要额外辅导。

4. 时间敏感场景

决策树具有较高的预测速度,适合时间敏感的应用:

  • 实时推荐系统

    • 在电商平台上,根据用户浏览记录实时推荐商品。
  • 交通管理

    • 在智慧交通系统中,根据实时流量数据预测拥堵区域,并调整信号灯。

案例:鸢尾花分类

数据集概述

鸢尾花数据集包含150个样本,每个样本有四个特征:花萼长度、花萼宽度、花瓣长度和花瓣宽度。该数据集分为三类,分别代表三种不同的鸢尾花:Iris setosa、Iris versicolor 和 Iris virginica。

数据集地址:github.com/scikit-lear…

每个样本有四个特征,这些特征都是实数,分别为:

  1. 花萼长度(sepal length in cm)
  2. 花萼宽度(sepal width in cm)
  3. 花瓣长度(petal length in cm)
  4. 花瓣宽度(petal width in cm)

这些特征描述了鸢尾花的外观属性,是连续值。

image.png

该数据集的目标是将样本分为三类鸢尾花物种之一

  1. Setosa
  2. Versicolour
  3. Virginica

这三个类别的样本在数据集中均匀分布,每类各有50个样本。

步骤:

  1. 加载数据集: 鸢尾花数据集是一个开源的数据集,可以在很多机器学习库中直接获取,比如在Python的scikit-learn库中。
  2. 数据预处理: 通常情况下,鸢尾花数据集不需要太多预处理,因为它已经被很好地整理过。但是,对于其他数据集,你可能需要进行诸如缺失值处理和标准化等步骤。
  3. 构建决策树模型: 使用scikit-learn库中的DecisionTreeClassifier来创建决策树模型。
  4. 训练模型: 将数据集分为训练集和测试集,然后用训练集来训练模型。
  5. 评估模型: 使用测试集来评估模型的性能,例如计算准确率。
  6. 可视化决策树: 可以使用graphviz库来可视化决策树,以便更好地理解决策过程。

Python代码示例:

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn import tree
import matplotlib.pyplot as plt

# 加载数据
iris = load_iris()
X, y = iris.data, iris.target

# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 创建决策树分类器
clf = DecisionTreeClassifier()

# 训练模型
clf.fit(X_train, y_train)

# 预测
y_pred = clf.predict(X_test)

# 评估模型
accuracy = accuracy_score(y_test, y_pred)
print(f"模型准确率: {accuracy:.2f}")

# 可视化决策树
plt.figure(figsize=(12,8))
tree.plot_tree(clf, feature_names=iris.feature_names, class_names=iris.target_names, filled=True)
# plt.show()
plt.savefig('output.png')  # 保存为文件

结果:

image.png

image.png

可视化结果解释

可视化决策树中通常会显示以下信息:

  1. 节点(Node) :每个节点代表数据集的一个子集。根节点包含整个数据集,然后根据特征的某个值将其分成两个或多个子节点。
  2. 分裂条件(Split Condition) :在每个非叶节点上,你会看到一个特征及其阈值。这是分裂数据的条件。例如,如果节点上写着petal length (cm) <= 2.45,这意味着数据根据花瓣长度是否小于或等于2.45来分裂。
  3. 样本数(Samples) :每个节点都会标示通过该节点的数据样本数量。比如,samples = 50表示该节点包含了50个样本。
  4. 类别分布(Value) :对于分类问题,每个节点还会显示不同类别的样本数量。例如,value = [10, 40]意味着这个节点中有10个样本属于第一类,有40个样本属于第二类。
  5. 基尼不纯度(Gini Impurity) :这是衡量节点中样本混合程度的一种指标。基尼指数越低,节点越“纯”,意味着大多数样本都是同一类的。
  6. 叶节点(Leaf Node) :没有子节点的节点就是叶节点。叶节点给出最终的决策或类别。在分类任务中,这表示为具有最多样本的类别。

假设某个节点的描述如下:

  • petal length (cm) <= 2.45
  • gini = 0.5
  • samples = 100
  • value = [50, 50]
  • class = setosa

这可以解读为:

  • 数据在这个节点被分裂成两部分,依据是花瓣长度是否小于或等于2.45厘米。
  • 基尼系数为0.5,表明数据可能均匀地分布在不同的类别之间。
  • 该节点处理了100个样本。
  • 有50个样本属于第一类(如setosa),50个样本属于其他类别。
  • 在这种情况下,虽然显示class = setosa,但实际上这个节点并不纯。

通过决策树,你可以逐步追踪决策路径,从而理解模型如何对输入数据进行分类。