决策树算法
决策树是一种常用的监督学习算法,广泛用于分类和回归问题。它的结构类似于一棵树,每个节点代表一个特征的分裂条件,分支代表决策规则,叶节点代表预测结果。
1. 决策树的基本概念
- 节点:
- 根节点:树的起点,包含所有样本数据。
- 内部节点:分裂条件,用于将数据划分为不同的子集。
- 叶节点:预测结果,代表数据的类别或回归值。
- 分裂:
- 数据集在每个内部节点根据某个特征及其值划分为不同的子集。
- 目标是使子集尽可能“纯”(即包含单一类别的数据)。
- 路径:
- 从根节点到叶节点的决策序列。
2. 决策树的核心思想
决策树通过递归地分裂数据,找到使子集“最纯”的特征和分裂点。
- 目标:最大化子集的纯度。
- 评价指标:常见的指标包括:
- 信息增益(Information Gain):
- 测量分裂后系统不确定性的减少程度。
- 基于熵的概念,公式为: 其中,是数据集 D的熵, 是按特征 A的值 v划分的数据子集。
- 基尼指数(Gini Index):
- 测量分类的不纯度,值越小越纯。
- 定义为: 表示第 k类的样本比例。
- 均方误差(Mean Squared Error, MSE):
- 用于回归树,最小化每个分裂后的误差平方和。
- 信息增益(Information Gain):
3. 决策树的生成
- 初始化:
- 将整个数据集作为根节点。
- 递归分裂:
- 对每个节点:
- 计算每个特征的分裂指标(如信息增益或基尼指数)。
- 选择指标最大的特征和分裂点。
- 根据分裂点将数据划分为子集。
- 如果满足停止条件(如达到最大深度或纯度),停止分裂。
- 对每个节点:
- 停止条件:
- 达到最大深度。
- 子集中样本数小于最小样本数。
- 没有可以进一步分裂的特征。
4. 决策树的剪枝
剪枝的目的是防止过拟合,使模型更具泛化能力。
预剪枝(Pre-Pruning)
- 在树生成过程中,提前停止分裂。
- 常用条件:
- 达到最大树深度(
maxDepth)。 - 分裂后信息增益小于阈值(
minInfoGain)。 - 节点样本数少于最小值(
minInstancesPerNode)。
- 达到最大树深度(
后剪枝(Post-Pruning)
- 生成完整决策树后,从底向上删除不必要的节点。
- 通过交叉验证评估剪枝对模型性能的影响。
- 方法:
- 成本复杂度剪枝(Cost Complexity Pruning):
- 平衡模型的复杂度与误分类率: 是误分类率,是叶节点数,是正则化参数。
- 成本复杂度剪枝(Cost Complexity Pruning):
5. 决策树的优缺点
优点
- 简单直观,易于理解和解释。
- 可处理数值型和类别型数据。
- 不需要特征缩放(如归一化)。
- 对异常值和缺失值不敏感。
- 可处理非线性数据。
缺点
- 容易过拟合,特别是树深度较大时。
- 对于复杂数据,单一决策树性能可能欠佳(可以使用集成方法,如随机森林)。
- 分裂时可能受到噪声影响(通过剪枝缓解)。
6. 决策树的变种
- 分类树(Classification Tree):
- 输出类别标签。
- 使用信息增益或基尼指数选择分裂点。
- 回归树(Regression Tree):
- 输出连续值。
- 使用均方误差(MSE)选择分裂点。
- CART(Classification And Regression Tree):
- 同时支持分类和回归。
- 使用二分法分裂特征。
7. 决策树的实际应用
- 分类任务:
- 垃圾邮件分类。
- 客户行为预测。
- 医学诊断。
- 回归任务:
- 房价预测。
- 销售量预测。
- 特征重要性分析:
- 决策树可以计算特征的重要性,帮助理解数据。
# -*- coding: utf-8 -*-
# @Time : 2024/10/2 12:00
# @Author : pblh123@126.com
# @File : pyspark_classfication.py
# @Describe : todo
import os
import warnings
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.feature import VectorAssembler, StringIndexer, VectorIndexer, IndexToString
from pyspark.sql import SparkSession
from pyspark.sql.types import DoubleType
from utils.window_Utils import windows_enviroment_set
# 过滤警告信息
warnings.simplefilter("ignore")
windows_enviroment_set()
def decistiontree():
# 指定CSV文件路径
path = r"D:\PycharmProjects\2024\pyspark\datas\iris.txt"
# 使用Spark读取CSV文件并推断列的数据类型,然后重命名列
df_raw = spark.read.option("inferSchema", "true").option("sep",","). \
csv(path).toDF("c0", "c1", "c2", "c3", "label")
# 将列的数据类型转换为Double
df_double = df_raw.select(
df_raw["c0"].cast(DoubleType()),
df_raw["c1"].cast(DoubleType()),
df_raw["c2"].cast(DoubleType()),
df_raw["c3"].cast(DoubleType()),
df_raw["label"]
)
# 创建VectorAssembler并设置输入列和输出列
assembler = VectorAssembler(inputCols=["c0", "c1", "c2", "c3"], outputCol="features")
# 使用VectorAssembler将特征列组装成特征向量
df = assembler.transform(df_double).select("features", "label")
# 创建StringIndexer用于标签列
labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(df)
# 创建VectorIndexer用于特征列
featureIndexer = VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(df)
# 创建IndexToString用于将预测的标签转换回原始标签
labelConverter = IndexToString(inputCol="prediction", outputCol="predictedLabel", labels=labelIndexer.labels)
# 随机拆分数据集为训练集和测试集
trainingData, testData = df.randomSplit([0.7, 0.3])
# 创建DecisionTreeClassifier
dtClassifier = DecisionTreeClassifier() \
.setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures")
# 创建Pipeline,设置阶段
dtPipeline = Pipeline(stages=[labelIndexer, featureIndexer, dtClassifier, labelConverter])
# 训练Pipeline模型
dtPipelineModel = dtPipeline.fit(trainingData)
# 进行预测
dtPredictions = dtPipelineModel.transform(testData)
# 选择需要展示的列
selected = dtPredictions.select("predictedLabel", "label", "features")
# 显示前10行
selected.show(10)
# 创建多类别分类评估器并计算准确率
evaluator = MulticlassClassificationEvaluator(labelCol="indexedLabel", predictionCol="prediction")
dtAccuracy = evaluator.evaluate(dtPredictions)
print("Model Accuracy:", dtAccuracy)
# 获取DecisionTreeClassificationModel
treeModelClassifier = dtPipelineModel.stages[2]
# 打印学习到的分类树模型
print("Learned classification tree model:\n" + treeModelClassifier.toDebugString)
if __name__ == '__main__':
# 1. 创建SparkSession
spark = SparkSession.builder \
.appName("PysparkmlliblogisticRegression_spark341") \
.master("local[2]") \
.getOrCreate()
# sparkcontext
sc = spark.sparkContext
spark.sparkContext.setLogLevel("WARN")
# 2. spark业务代码
decistiontree()
# 3. 关闭sparkSession, sparkcontext
sc.stop()
spark.stop()