持续创作,加速成长!这是我参与「掘金日新计划 · 6 月更文挑战」的第14天,点击查看活动详情
前面了解了SVM了,它可以用于执行分类和回归、线性和非线性任务。现在我们继续看看什么是决策树?
决策树
首先,决策树也能够执行分类和回归,还能够进行多输出(SVM不行)。他功能强大,能够拟合复杂的数据集。话不多说,我们边用边看,让我们在可视化的下慢慢探究。
这里的数据集还是需要Scikit-learn.datasets下的load_iris()导入鸢尾花后,使用我们的决策树来预测一下,下面是核心代码:
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
iris = load_iris()
# 这边只需要花瓣长度
X = iris.data[:, 2:]
y = iris.target
tree_clf = DecisionTreeClassifier(max_depth=2, random_state=42)
tree_clf.fit(X, y)
现在啊,这些代码简直都是太熟悉了,我就简单讲讲DecisionTreeClassifier()里的参数:
criteria这个是用来区分实例的方法,可选{"gini", "entropy", "log_loss"},默认是ginimax_depth生成的树的最大深度
我一般对不太了解的模型,都是自己先运行代码,查看是怎么样的,如果不能直接看出来,就去查看官网文档,毕竟那里是讲的最详细的,同时配有不少例子可以查看。最后实在还不行的,就配合这一些其他人分享的讲解文档再看看。
使用dot工具查看树
接下来,对于树的可视化和之间的有一点不同,画出决策树需要借助一个在Graphviz包下的dot命令行工具。
我们先生成一个.dot文件,这个是用于画树的一种语法:
from graphviz import Source
from sklearn.tree import export_graphviz
export_graphviz(
# 我们训练好的模型
tree_clf,
# 地址和生成的文件名
out_file=os.path.join(IMAGES_PATH, "iris_tree.dot"),
# 需要知道特征名和分类的类名
feature_names=iris.feature_names[2:],
class_names=iris.target_names,
rounded=True,
# 设置为True表示可能有多类和多输出
filled=True
)
上面的作用就是在IMAGES_PATH目录下,生成一个iris_tree.dot文件。该文件下是用dot语法,我们还不能直接看,需要借助dot工具将其转成图片。
我们可以通过anaconda先安装好graphviz,本人的系统环境是windows,先使用是conda创建了一个虚拟环境,安装好后,我们就可以在anaconda目录\envs\环境名\Script下找到dot.bat这样我们就可以使用这个工具将我们的生成的dot文件变成pdf或者png。我的位置如下:
图1 dot位置
接下来我们进入anaconda prompt,进入我们的虚拟环境下,我这边是tf2,然后cd到我们dot文件保存的地址下,执行以下命令:
图2 执行dot命令
最后就会生成如下图:
图3 iris数据集的决策树
预测
我们看一下图3的决策树,他都做了啥:
-
先看深度为0的根节点,他其实有一个判断的-
花瓣长度是否大于2.45厘米,满足就是山鸢尾(setosa),接下来就是一个基尼系数(gini),这个下面将。然后是样品数(samples=150),然后value,他就是告诉我们现在有3个类,在当前节点下每个类的训练的实例数量。然后将判断的结果分成了两个子节点。 -
接下来进入深度1的左子节点,他是满足
花瓣长度小于等于2.45厘米的结果,他有50个样品,全都是来自value[0],预测为山鸢尾(setosa),所以gini=0。 -
深度1的右子节点则是满足
花瓣长度大于于2.45厘米,同时如果花瓣宽度小于等于1.75厘米,满足就是变色鸢尾(versicolor),有100个样本实例。 -
再看深度2的左子节点,就是满足3中的要求,但是它49个实例来自value[1],5个来自value[2],所以gini=0.168,预测为变色鸢尾(versicolor)。
-
最后就是深度2的右子节点,是不满足的要求的46个样品,他也分别来自value[1]的1个和value[2]的45个,gini=0.043,预测为维吉尼亚鸢尾(virginica)。
最后我们就可以看到使用DecisionTreeClassifier()将iris数据集分成了三类。
哦,对了。那上面的gini系数是个啥呢?其实节点的gini是用来衡量节点的不纯度的,让我们看一下他的公式:
就拿上面第五步的gini举个例子,我们的p就是当前类实例占总样本实例的比例,比如value[1]类的,那么总的