PS: the full code is linked at the end of the article.
1. Key Concepts
General workflow for tree regression:
Prepare the data: numeric values are required; nominal features should be mapped to binary values
Analyze the data: visualize it; the tree is built as a dictionary
Train the algorithm: most of the work goes into building the leaf nodes
Test the algorithm: use the squared error on test data to evaluate the model
Use the algorithm: make predictions with the trained model
Pros: can model complex, nonlinear data
Cons: the results are not easy to interpret
Applicable data types: numeric and nominal values
CART handles continuous variables with binary splits: a sample whose feature value is greater than the split value goes to the left subtree, otherwise it goes to the right. For example, splitting at 0.5, a sample with feature value 0.8 goes to the left subtree and one with 0.3 to the right.
The pseudocode is as follows:

Use a class to represent tree nodes
Build the tree recursively:
    Find the best feature to split on:
        If the node cannot be split any further, store it as a leaf node
        Otherwise, perform the binary split
        Call createTree() on the right subtree
        Call createTree() on the left subtree
2. A Simple Example
2.1 The CART Algorithm
class treeNode():
    def __init__(self, feat, val, right, left):
        self.featureToSplitOn = feat  # index of the feature used for the split
        self.valueOfSplit = val       # value to split at
        self.rightBranch = right      # right subtree (or leaf)
        self.leftBranch = left        # left subtree (or leaf)
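Although treeNode shows an object-based representation, the rest of this post stores the tree as a nested Python dictionary, which is what createTree below returns. A hypothetical tree with a single split might look like this:

# Hypothetical example of the dictionary representation used below:
# samples with feature 0 > 0.5 go to the left branch, the rest to the right
sampleTree = {
    'spInd': 0,      # index of the feature to split on
    'spVal': 0.5,    # threshold used for the binary split
    'left': 1.02,    # leaf: mean target value of the left subset
    'right': -0.04,  # leaf: mean target value of the right subset
}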
# CART algorithm implementation
from numpy import *

# Load a tab-delimited data set and parse every field as a float
def loadDataSet(fileName):
    dataMat = []  # assume the last column is the target value
    fr = open(fileName)
    for line in fr.readlines():
        curLine = line.strip().split('\t')
        fltLine = list(map(float, curLine))  # map all elements to float
        dataMat.append(fltLine)
    return dataMat

# Split the data set on the given feature and value
def binSplitDataSet(dataSet, feature, value):
    mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :]
    mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
    return mat0, mat1
# Quick test on a 4x4 identity matrix
def testT():
    testMat = mat(eye(4))
    print(testMat)
    mat0, mat1 = binSplitDataSet(testMat, 0, 0.5)
    print(mat1, mat0)

testT()
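Since only the first row of the identity matrix has a value greater than 0.5 in column 0, mat0 should contain just that row, while mat1 holds the remaining three rows.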
2.2 Using CART for Regression
The impurity of a node is measured with the squared error of its target values.
Building the tree:
The pseudocode for the function that finds the best split point:

Initialize the current minimum error to +∞
For every feature:
    For every value of that feature:
        Split the data set in two
        Compute the error of the split
        If the split error is less than the current minimum error:
            Record this split as the best one and update the minimum error
Return the feature and threshold of the best split
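The split-finding function below takes a leaf-construction function and an error function as arguments. Its defaults, regLeaf and regErr, are the book's regression-tree versions: the leaf model is simply the mean of the target column, and the error is the total squared error, i.e. the variance of the target times the number of samples. Since they are not shown elsewhere in this post, here is a minimal sketch following the book's source:

# Leaf model for a regression tree: the mean of the target values
def regLeaf(dataSet):
    return mean(dataSet[:, -1])

# Error measure: total squared error = variance of the target * sample count
def regErr(dataSet):
    return var(dataSet[:, -1]) * shape(dataSet)[0]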
# Find the best feature and value to split on
# ops = (tolS, tolN): tolS is the minimum error reduction required for a
# split, tolN the minimum number of samples on each side; together these
# two parameters control when the function stops splitting
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    tolS = ops[0]; tolN = ops[1]
    # exit condition 1: all target values are identical
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)
    m, n = shape(dataSet)
    # the best feature is the one giving the largest reduction in total
    # squared error relative to the unsplit data set
    S = errType(dataSet)
    bestS = inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n - 1):
        # try every distinct value of this feature as a split point
        for splitVal in set(dataSet[:, featIndex].T.tolist()[0]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    # exit condition 2: the error reduction is below the threshold
    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    # exit condition 3: one side of the split has too few samples
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
        return None, leafType(dataSet)
    return bestIndex, bestValue  # best feature index and the split value
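Note that tolS and tolN already amount to a form of prepruning: splitting stops early whenever the error reduction or the subset size falls below these thresholds, which is exactly the behavior tuned in section 2.3.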
# Build the tree
# Arguments: the data set (assumed to be a NumPy matrix so we can use array
# filtering), a function that builds leaf nodes, an error function, and ops,
# a tuple of the remaining construction parameters
# ops = (tolS, tolN): minimum error reduction and minimum samples per split
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    # find the best feature and value to split on
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    if feat is None:
        return val  # splitting hit a stop condition, so return a leaf value
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree
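A typical call, assuming a tab-delimited data file such as the book's ex00.txt (the path is illustrative):

myDat = loadDataSet(r'./data/ex00.txt')  # hypothetical path to the book's sample data
myMat = mat(myDat)
print(createTree(myMat))  # prints the nested dictionary describing the tree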
2.3 Prepruning and Postpruning
# Prepruning: control tree size through the stopping parameters
def preCut():
    myData = loadDataSet(r'./data/ex2.txt')
    myMat = mat(myData)
    print("default stopping conditions:", createTree(myMat))
    print("modified stopping conditions:", createTree(myMat, ops=(10000, 4)))
# preCut()
# As the two outputs show, these hyperparameters have a large effect on the
# result, and the outcome is hard to predict in advance
# Postpruning: no parameters need to be specified
# Pseudocode:
### Split the test data according to the trained tree:
###     If either subset is itself a tree, recurse the pruning on that subset
###     Compute the error after merging the two leaf nodes
###     Compute the error without merging
###     Compare the two errors and decide whether to merge

# Check whether a node is a dictionary (i.e. a subtree), which tells us
# whether pruning may still be needed below it
def isTree(obj):
    return (type(obj).__name__ == 'dict')
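The pseudocode above translates into a prune function. The sketch below follows the book's version: getMean recursively collapses a subtree to the mean of its two branches, and prune compares the test-set error with and without merging a pair of leaves:

# collapse a subtree into a single value (the mean of its two branches)
def getMean(tree):
    if isTree(tree['right']): tree['right'] = getMean(tree['right'])
    if isTree(tree['left']): tree['left'] = getMean(tree['left'])
    return (tree['left'] + tree['right']) / 2.0

def prune(tree, testData):
    # no test data falls into this subtree: collapse it
    if shape(testData)[0] == 0: return getMean(tree)
    if isTree(tree['right']) or isTree(tree['left']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']): tree['right'] = prune(tree['right'], rSet)
    # both children are leaves: decide whether merging them lowers the error
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        errorNoMerge = sum(power(lSet[:, -1] - tree['left'], 2)) + \
                       sum(power(rSet[:, -1] - tree['right'], 2))
        treeMean = (tree['left'] + tree['right']) / 2.0
        errorMerge = sum(power(testData[:, -1] - treeMean, 2))
        if errorMerge < errorNoMerge:
            return treeMean  # merging reduces the test error
        else:
            return tree
    else:
        return tree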
3. Setting Parameters with a GUI
3.1 Running Result
3.2 Code
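The GUI code below also depends on the model-tree helpers modelLeaf, modelErr, and modelTreeEval, and on createForeCast, all of which live in the book's full source but are not shown in this post. A condensed sketch of those helpers, following the book:

# fit a linear model to the data set (used by the model-tree leaves)
def linearSolve(dataSet):
    m, n = shape(dataSet)
    X = mat(ones((m, n))); Y = mat(ones((m, 1)))
    X[:, 1:n] = dataSet[:, 0:n-1]; Y = dataSet[:, -1]
    xTx = X.T * X
    if linalg.det(xTx) == 0.0:
        raise NameError('singular matrix; try increasing the second value of ops')
    ws = xTx.I * (X.T * Y)
    return ws, X, Y

def modelLeaf(dataSet):           # leaf model: linear regression coefficients
    ws, X, Y = linearSolve(dataSet)
    return ws

def modelErr(dataSet):            # error: squared error of the linear fit
    ws, X, Y = linearSolve(dataSet)
    yHat = X * ws
    return sum(power(Y - yHat, 2))

def regTreeEval(model, inDat):    # evaluate a regression-tree leaf (a constant)
    return float(model)

def modelTreeEval(model, inDat):  # evaluate a model-tree leaf (a linear model)
    n = shape(inDat)[1]
    X = mat(ones((1, n + 1)))
    X[:, 1:n + 1] = inDat
    return float(X * model)

# walk the tree for one input sample and evaluate the leaf it lands in
def treeForeCast(tree, inData, modelEval=regTreeEval):
    if not isTree(tree): return modelEval(tree, inData)
    if inData[tree['spInd']] > tree['spVal']:
        if isTree(tree['left']): return treeForeCast(tree['left'], inData, modelEval)
        else: return modelEval(tree['left'], inData)
    else:
        if isTree(tree['right']): return treeForeCast(tree['right'], inData, modelEval)
        else: return modelEval(tree['right'], inData)

# predict for every row of testData
def createForeCast(tree, testData, modelEval=regTreeEval):
    m = len(testData)
    yHat = mat(zeros((m, 1)))
    for i in range(m):
        yHat[i, 0] = treeForeCast(tree, mat(testData[i]), modelEval)
    return yHat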
# Integrate matplotlib with tkinter for an interactive display
from numpy import *
from tkinter import *
import matplotlib
matplotlib.use('TkAgg')
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure

def reDraw(tolS, tolN):
    reDraw.f.clf()  # clear the figure
    reDraw.a = reDraw.f.add_subplot(111)
    if chkBtnVar.get():  # the "Model Tree" checkbox is ticked
        if tolN < 2: tolN = 2
        myTree = createTree(reDraw.rawDat, modelLeaf, modelErr, (tolS, tolN))
        yHat = createForeCast(myTree, reDraw.testDat, modelTreeEval)
    else:  # plain regression tree
        myTree = createTree(reDraw.rawDat, ops=(tolS, tolN))
        yHat = createForeCast(myTree, reDraw.testDat)
    reDraw.a.scatter(reDraw.rawDat[:, 0].tolist(), reDraw.rawDat[:, 1].tolist(), s=5)  # scatter for the data set
    reDraw.a.plot(reDraw.testDat, yHat, linewidth=2.0)  # plot for yHat
    reDraw.canvas.draw()
def getInputs():
    try:
        tolN = int(tolNentry.get())
    except:
        tolN = 10
        print("enter Integer for tolN")
        tolNentry.delete(0, END)
        tolNentry.insert(0, '10')
    try:
        tolS = float(tolSentry.get())
    except:
        tolS = 1.0
        print("enter Float for tolS")
        tolSentry.delete(0, END)
        tolSentry.insert(0, '1.0')
    return tolN, tolS
def drawNewTree():
    tolN, tolS = getInputs()  # get values from the entry boxes
    reDraw(tolS, tolN)

root = Tk()

reDraw.f = Figure(figsize=(5, 4), dpi=100)  # create the drawing canvas
reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root)
reDraw.canvas.draw()
reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3)

Label(root, text="tolN").grid(row=1, column=0)
tolNentry = Entry(root)
tolNentry.grid(row=1, column=1)
tolNentry.insert(0, '10')
Label(root, text="tolS").grid(row=2, column=0)
tolSentry = Entry(root)
tolSentry.grid(row=2, column=1)
tolSentry.insert(0, '1.0')
Button(root, text="ReDraw", command=drawNewTree).grid(row=1, column=2, rowspan=3)
chkBtnVar = IntVar()
chkBtn = Checkbutton(root, text="Model Tree", variable=chkBtnVar)
chkBtn.grid(row=3, column=0, columnspan=2)

reDraw.rawDat = mat(loadDataSet(r'data/sine.txt'))
reDraw.testDat = arange(min(reDraw.rawDat[:, 0]), max(reDraw.rawDat[:, 0]), 0.01)
reDraw(1.0, 10)
root.mainloop()
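One design note: the code stores the figure, canvas, and data as attributes on the reDraw function itself (reDraw.f, reDraw.rawDat, and so on), a lightweight way to share state between the Tkinter callbacks without introducing globals or a class.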
4. Summary
Data sets often contain complex interactions that make the relationship between the inputs and the target variable nonlinear.
Such relationships can be modeled by using a tree to predict in pieces, either piecewise constants or piecewise linear models.
If the leaf nodes hold piecewise constants the tree is a regression tree; if they hold linear regression equations it is called a model tree.
The CART algorithm builds binary trees and can handle splits on both discrete and continuous data.
Prepruning: prune during tree construction; requires user-defined parameters.
Postpruning: prune after the tree has been fully built.
5. References
[1] Machine Learning in Action
[2] Source code of the book
[3] Jupyter notebook version
[4] Code for this section