Tree Regression in Machine Learning

PS: the full code is linked at the end of this post.
1. Basic Concepts
General steps for tree regression:
Prepare the data: numeric data is required; nominal features should be mapped to binary values.
Analyze the data: visualize it; the tree is generated as a nested dictionary.
Train the algorithm: most of the cost goes into building the leaf nodes.
Test the algorithm: measure performance with the squared error.
Use the trained model to make predictions.


Pros: can model complex, nonlinear data.
Cons: the results are hard to interpret.
Applicable data types: numeric and nominal.
CART handles continuous variables with binary splitting: samples whose feature value is greater than the split value go down the left subtree, and the rest go down the right.

Pseudocode:

Use a class to represent tree nodes
Build the tree recursively
Find the best feature to split on:
     If the node cannot be split any further, save it as a leaf node
     Otherwise perform a binary split
     Call createTree() on the right subtree
     Call createTree() on the left subtree
2. A Simple Example

2.1 The CART Algorithm

    
# A tree node records the split feature, the split value, and the two branches.
# (The implementation below actually stores the tree in plain dictionaries;
# this class is only an illustration.)
class treeNode:
    def __init__(self, feat, val, right, left):
        self.featureToSplitOn = feat   # index of the feature to split on
        self.valueOfSplit = val        # value used for the split
        self.rightBranch = right       # branch for samples <= split value
        self.leftBranch = left         # branch for samples > split value
    
# CART algorithm implementation
from numpy import *

# Load the data set: parse a tab-delimited file of floats;
# the last column is assumed to be the target value
def loadDataSet(fileName):
    dataMat = []
    with open(fileName) as fr:
        for line in fr.readlines():
            curLine = line.strip().split('\t')
            fltLine = list(map(float, curLine))  # map all elements to float
            dataMat.append(fltLine)
    return dataMat
    
# Split the data set on a given feature and value:
# mat0 gets the rows where the feature value is greater than value,
# mat1 gets the remaining rows
def binSplitDataSet(dataSet, feature, value):
    mat0 = dataSet[list(nonzero(dataSet[:,feature] > value)[0]), :]
    mat1 = dataSet[list(nonzero(dataSet[:,feature] <= value)[0]), :]
    return mat0, mat1


# Quick check: split a 4x4 identity matrix on feature 0 at value 0.5
def testT():
    testMat = mat(eye(4))
    print(testMat)
    mat0, mat1 = binSplitDataSet(testMat, 0, 0.5)
    print(mat1, mat0)   # mat0 gets row 0 (its value 1.0 > 0.5); mat1 gets rows 1-3
testT()

2.2 Using CART for Regression

The impurity of a node is measured with the squared error.
To build the tree, we first need a function that finds the best split point. Its pseudocode:
     Initialize the current minimum error to positive infinity
     For each feature:
         For each value of that feature:
             Split the data set in two
             Compute the error of the split
             If the error is less than the current minimum error:
                 Record this split as the best split and update the minimum error
     Return the feature and value of the best split
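chooseBestSplit below uses regLeaf and regErr as its default leaf-builder and error function, but the original snippet does not show them. Here is a minimal sketch of the standard versions from Machine Learning in Action: a regression-tree leaf stores the mean of the target values, and the error is the total squared error (variance times sample count).

# Build a leaf node: for a regression tree this is the mean of the targets
def regLeaf(dataSet):
    return mean(dataSet[:,-1])

# Error of a data set: variance of the targets times the number of samples,
# i.e. the total squared error
def regErr(dataSet):
    return var(dataSet[:,-1]) * shape(dataSet)[0]
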
# Find the best feature and value to split the data on
# ops = (tolS, tolN): the minimum error reduction required for a split, and the minimum number of samples in a split
### These two parameters control when the function stops splitting
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    tolS = ops[0]; tolN = ops[1]
    #if all the target variables are the same value: quit and return value
    if len(set(dataSet[:,-1].T.tolist()[0])) == 1: #exit cond 1
        return None, leafType(dataSet)
    m,n = shape(dataSet)
    #the choice of the best feature is driven by Reduction in RSS error from mean
    S = errType(dataSet)
    bestS = inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n-1):
        # collect the unique values of this feature
        soleSet = []
        for j in range(shape(dataSet)[0]):
            temp = dataSet[j, featIndex]
            if temp not in soleSet:
                soleSet.append(temp)
        for splitVal in soleSet:
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS: 
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    #if the decrease (S-bestS) is less than a threshold don't do the split
    if (S - bestS) < tolS: 
        return None, leafType(dataSet) #exit cond 2
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):  #exit cond 3
        return None, leafType(dataSet)
    return bestIndex, bestValue    # returns the best feature to split on, and the value used for that split
    
# Build the tree
# Arguments: the data set, the leaf-building function, the error function,
# and ops, a tuple of the remaining parameters needed to build the tree
# ops = (tolS, tolN): the minimum error reduction required for a split and the minimum number of samples in a split
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):#assume dataSet is NumPy Mat so we can array filtering
    # find the best feature and value to split on
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)#choose the best split
    if feat is None:
        return val #if the splitting hit a stop condition return val
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree
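
As a quick sanity check, the tree can be built on one of the book's sample files (data/ex00.txt, a two-column tab-separated file, is assumed here):

# Usage sketch (assumes the book's data/ex00.txt is available)
myMat = mat(loadDataSet(r'./data/ex00.txt'))
print(createTree(myMat))
# prints a nested dict along the lines of
# {'spInd': 0, 'spVal': 0.48813, 'left': 1.018..., 'right': -0.044...}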

2.3 Pre-pruning and Post-pruning


# Pre-pruning: control the tree through the stopping parameters
def preCut():
    myData = loadDataSet(r'./data/ex2.txt')
    myMat = mat(myData)
    print("default stopping condition", createTree(myMat))
    print("modified stopping condition", createTree(myMat, ops=(10000,4)))
# preCut()
# The two runs show that the hyperparameters strongly affect the result,
# which makes the outcome hard to predict
    
# Post-pruning: no user-specified parameters required
# Pseudocode:
### Split the test data over the existing tree:
    ### If either subset is itself a tree, recurse the pruning process on that subset
    ### Compute the error after merging the two current leaf nodes
    ### Compute the error without merging
    ### Compare the two errors and decide whether to merge

# Check whether a node is a dictionary (i.e. a subtree);
# this tells us whether pruning still needs to recurse
def isTree(obj):
    return (type(obj).__name__ == 'dict')
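
The original post stops at isTree. For completeness, here is a sketch of the getMean and prune functions from Machine Learning in Action that the pseudocode above describes (recall that in binSplitDataSet the left branch holds the samples with feature value greater than the split value):

# Collapse a subtree into the mean of its two children (used when test data runs out)
def getMean(tree):
    if isTree(tree['right']): tree['right'] = getMean(tree['right'])
    if isTree(tree['left']): tree['left'] = getMean(tree['left'])
    return (tree['left'] + tree['right']) / 2.0

# Post-pruning: walk the tree with the test data and merge a pair of leaves
# whenever merging lowers the error on the test data
def prune(tree, testData):
    if shape(testData)[0] == 0: return getMean(tree)  # no test data for this branch
    if isTree(tree['right']) or isTree(tree['left']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']): tree['right'] = prune(tree['right'], rSet)
    # if both children are now leaves, test whether merging them helps
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        errorNoMerge = sum(power(lSet[:,-1] - tree['left'], 2)) + \
                       sum(power(rSet[:,-1] - tree['right'], 2))
        treeMean = (tree['left'] + tree['right']) / 2.0
        errorMerge = sum(power(testData[:,-1] - treeMean, 2))
        if errorMerge < errorNoMerge:
            return treeMean  # merging reduces test error: collapse to the mean
        else:
            return tree
    else:
        return tree

# Usage sketch: grow a large tree, then prune it against held-out data
# myTree = createTree(trainMat, ops=(0, 1))
# prune(myTree, testMat)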
3. Setting Parameters with a GUI

3.1 Result

(Screenshot: a Tkinter window showing the fitted curve plotted over the scattered data, entry boxes for tolN and tolS, a ReDraw button, and a Model Tree checkbox.)

3.2 Code


# Integrate the matplotlib plot into a tkinter window
from numpy import *

from tkinter import *

import matplotlib

matplotlib.use('TkAgg')
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure
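
# --- NOTE: the helpers below are not in the original post. The GUI code assumes
# the model-tree and forecasting functions from Machine Learning in Action;
# a minimal sketch of them follows. ---

# Fit an ordinary least-squares model to the data set (first column of X is the intercept)
def linearSolve(dataSet):
    m, n = shape(dataSet)
    X = mat(ones((m, n))); Y = mat(ones((m, 1)))
    X[:, 1:n] = dataSet[:, 0:n-1]; Y = dataSet[:, -1]
    xTx = X.T * X
    if linalg.det(xTx) == 0.0:
        raise NameError('This matrix is singular; try increasing the second value of ops')
    ws = xTx.I * (X.T * Y)
    return ws, X, Y

# Model-tree leaf: the fitted regression coefficients
def modelLeaf(dataSet):
    ws, X, Y = linearSolve(dataSet)
    return ws

# Model-tree error: squared error of the linear fit
def modelErr(dataSet):
    ws, X, Y = linearSolve(dataSet)
    yHat = X * ws
    return sum(power(Y - yHat, 2))

# Evaluate a regression-tree leaf (a constant)
def regTreeEval(model, inDat):
    return float(model)

# Evaluate a model-tree leaf (a linear model) on one input row
def modelTreeEval(model, inDat):
    n = shape(inDat)[1]
    X = mat(ones((1, n + 1)))
    X[:, 1:n+1] = inDat
    return float(X * model)

# Follow the tree down to a leaf for a single input row
def treeForeCast(tree, inData, modelEval=regTreeEval):
    if not isTree(tree): return modelEval(tree, inData)
    if inData[0, tree['spInd']] > tree['spVal']:
        if isTree(tree['left']): return treeForeCast(tree['left'], inData, modelEval)
        else: return modelEval(tree['left'], inData)
    else:
        if isTree(tree['right']): return treeForeCast(tree['right'], inData, modelEval)
        else: return modelEval(tree['right'], inData)

# Forecast a whole test set
def createForeCast(tree, testData, modelEval=regTreeEval):
    m = len(testData)
    yHat = mat(zeros((m, 1)))
    for i in range(m):
        yHat[i, 0] = treeForeCast(tree, mat(testData[i]), modelEval)
    return yHat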

def reDraw(tolS,tolN):
    reDraw.f.clf()        # clear the figure
    reDraw.a = reDraw.f.add_subplot(111)
    if chkBtnVar.get():
        if tolN < 2: tolN = 2
        myTree=createTree(reDraw.rawDat, modelLeaf,\
                                   modelErr, (tolS,tolN))
        yHat = createForeCast(myTree, reDraw.testDat, \
                                       modelTreeEval)
    else:
        myTree=createTree(reDraw.rawDat, ops=(tolS,tolN))
        yHat = createForeCast(myTree, reDraw.testDat)
    reDraw.a.scatter(reDraw.rawDat[:,0].tolist(), reDraw.rawDat[:,1].tolist(), s=5) #use scatter for data set
    reDraw.a.plot(reDraw.testDat, yHat, linewidth=2.0) #use plot for yHat
    reDraw.canvas.draw()
    
def getInputs():
    try: tolN = int(tolNentry.get())
    except ValueError:
        tolN = 10
        print("enter Integer for tolN")
        tolNentry.delete(0, END)
        tolNentry.insert(0, '10')
    try: tolS = float(tolSentry.get())
    except ValueError:
        tolS = 1.0
        print("enter Float for tolS")
        tolSentry.delete(0, END)
        tolSentry.insert(0, '1.0')
    return tolN,tolS

def drawNewTree():
    tolN,tolS = getInputs()#get values from Entry boxes
    reDraw(tolS,tolN)
    
root=Tk()

reDraw.f = Figure(figsize=(5,4), dpi=100) #create canvas
reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root)
reDraw.canvas.draw()
reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3)

Label(root, text="tolN").grid(row=1, column=0)
tolNentry = Entry(root)
tolNentry.grid(row=1, column=1)
tolNentry.insert(0,'10')
Label(root, text="tolS").grid(row=2, column=0)
tolSentry = Entry(root)
tolSentry.grid(row=2, column=1)
tolSentry.insert(0,'1.0')
Button(root, text="ReDraw", command=drawNewTree).grid(row=1, column=2, rowspan=3)
chkBtnVar = IntVar()
chkBtn = Checkbutton(root, text="Model Tree", variable = chkBtnVar)
chkBtn.grid(row=3, column=0, columnspan=2)

reDraw.rawDat = mat(loadDataSet(r'data/sine.txt'))
reDraw.testDat = arange(min(reDraw.rawDat[:,0]),max(reDraw.rawDat[:,0]),0.01)
reDraw(1.0, 10)
               
root.mainloop()
4. Summary
Data sets often contain complex interactions that make the relationship between the input data and the target variable nonlinear.
Such complex relationships can be handled by using a tree to split the predicted value into pieces, either piecewise constants or piecewise linear models.
If the leaf nodes hold piecewise constants the tree is a regression tree; if they hold linear regression models it is called a model tree.
The CART algorithm builds binary trees and can handle splits on both discrete and continuous features.
Pre-pruning: prune while the tree is being built; requires user-defined parameters.
Post-pruning: prune after the tree has been fully built.
5. References

[1] Machine Learning in Action (机器学习实战)

[2] Source code for the book

[3] Jupyter version

[4] Code for this post