Machine Learning in Action 读书笔记---第9章 树回归

本文深入探讨了CART算法在树回归中的应用,包括回归树与模型树的构建、剪枝策略以及Python实现。通过对比实验展示了模型树在解释性和预测效果上的优势,并提供了使用Tkinter创建交互式GUI的示例,以动态调整树的复杂度并观察预测结果。

Machine Learning in Action 读书笔记

第9章 树回归



一、树回归

1.树回归的优缺点

  • 优点:可以对复杂和非线性的数据建模
  • 缺点:结果不易理解
  • 适用数据类型:数值型和标称型数据

2.树回归的一般方法

  1. 收集数据:采用任意方法收集数据
  2. 准备数据:需要数值型的数据,标称型数据应该映射成二值型数据
  3. 分析数据:绘出数据的二维可视化显示结果,以字典方式生成树
  4. 训练算法:大部分时间都花费在叶节点树模型的构建上
  5. 测试算法:使用测试数据上的 R 2 R^2 R2值来分析模型的效果
  6. 使用算法:使用训练出的树做预测,预测结果还可以用来做很多事情

3.回归树与模型树

本章将构建两种树:

  • 回归树(regression tree):其每个叶节点包含单个值
  • 模型树(model tree):其每个叶节点包含一个线性方程

树的存储依然使用字典存储,该字典中包含以下4个元素:

  • 待切分的特征。
  • 待切分的特征值。
  • 右子树。当不再需要切分的时候,也可以是单个值
  • 左子树。

二、CART算法

CART(Classification And Regression Trees,分类回归树):是一种树构建算法,它使用二元切分来处理连续型变量。
二元切分法:每次把数据集切分成两份,如果数据的某个特征值等于切分所要求的值,那么这些数据就进入树的左子树,反之进入树的右子树。

第三章中的决策树的切分使用ID3(迭代二分器)算法,但是ID3算法不能直接处理连续型特征。

CART算法的实现:

def loadDataSet(fileName):
    dataMat = []
    fr = open(fileName)
    for line in fr.readlines():
        curLine = line.strip().split('\t')
        fltLine = list(map(float, curLine))
        dataMat.append(fltLine)
    return dataMat

def binSplitDataSet(dataSet, feature, value):  # 参数为:数据集,待切分特征,该特征的某个值
    # mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :][0]
    # mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :][0]
    mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :]  # 使用数据过滤方式切割数据集
    mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
    return mat0, mat1

# 创建树
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)): # 默认为回归树
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)  # 选择最好切分特征和特征值
    if feat == None:
        return val
    retTree = {}
    retTree['spInd'] = feat  # 待切分特征
    retTree['spVal'] = val   # 待切分特征值
    lSet, rSet = binSplitDataSet(dataSet, feat, val) # 按照找出的最好切分特征和特征值继续切分为左右子树
    retTree['left'] = createTree(lSet, leafType, errType, ops) # 左子树
    retTree['right'] = createTree(rSet, leafType, errType, ops) # 右子树
    return retTree

回归树的切分函数:

'''CART算法的辅助函数:回归树的切分函数'''

#创建叶节点的函数
def regLeaf(dataSet):
    return mean(dataSet[:, -1])

#总方差计算函数
def regErr(dataSet):
    return var(dataSet[:, -1]) * shape(dataSet)[0]  # var函数用于计算均方差(均方差 * 样本个数 = 总方差)

'''CART算法'''

# 该函数用于找到数据的最佳二元切分方式:1.用最佳方式切分数据,2.生成相应的叶节点
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)): # leafType是对创建叶节点的函数的引用,errType是对前面介绍的总方差计算函数的引用,ops由用户指定,用于控制函数的停止时机
    tolS = ops[0]; tolN = ops[1]  # tolS 是允许的误差下降值,tolN是切分的最少样本数
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1: # 前面用于统计不同剩余特征值的数目,如果为1就不用切分直接返回
        return None, leafType(dataSet)  # 如果所有值相当则退出 1,退出后直接创建叶节点
    m, n = shape(dataSet)  # 计算当前数据集的大小
    S = errType(dataSet)  # 计算当前数据集的误差,S将用于与新切分误差进行对比,来检查新切分能够降低误差
    bestS = inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n-1):
        for splitVal in set((dataSet[:, featIndex].T.A.tolist())[0]): # 这里进行了修改
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
                continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    if (S - bestS) < tolS:
        return None, leafType(dataSet)   # 如果误差减少不大则退出  2
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
        return None, leafType(dataSet)  # 如果切分出的数据集很小则退出 3,即子集大小小于用户定义的参数tolN
    return bestIndex, bestValue  # 上述3个提前终止条件都不满足,返回数据集切分的最好位置和最好值(最佳切分特征和阈值)

其中,上述代码中,有两个用户可控制参数tolS和tolN:

  • tolS:是允许的最小误差下降值,小于该值则不切分直接退出
  • tolN:是允许切分的最少样本数,小于定义的最小样本数则不切分直接退出

三、树剪枝

通过降低决策树的复杂度来避免过拟合的过程称为剪枝(pruning)。有两种形式的剪枝方法,预剪枝(prepruning)和后剪枝(postpruning)。

1.预剪枝

在函数chooseBestSplit()中的提前终止条件,实际上就是在进行一种所谓的预剪枝操作。其中的三个终止条件是:

  1. 如果切分数据集后效果提升不大,那么就不应该进行切分操作而直接创建叶节点
  2. 如果两个切分后的某个子集的大小小于用户定义的参数tolN,那么也不应切分
  3. 统计剩余特征值的数目,数目为1,就不需要切分而直接返回

2.后剪枝

使用后剪枝方法需要将数据集分为测试集和训练集。使用测试集来判断将这些叶节点合并是否能降低测试误差,如果是的话就合并。
函数prune()的伪代码:

基于已有的树切分测试数据:
	如果存在任一子集是一颗树,则在该子集递归剪枝过程
	计算将当前两个叶节点合并后的误差
	计算不合并的误差
	如果合并会降低误差的话,就将叶节点合并

python代码:

'''回归树剪枝函数'''

# 用于判断是不是叶子节点
def isTree(obj):
    return (type(obj).__name__ == 'dict')

#返回树平均值
def getMean(tree):
    if isTree(tree['right']):
        tree['right'] = getMean(tree['right'])
    if isTree(tree['left']):
        tree['left'] = getMean(tree['left'])
    return (tree['left'] + tree['right']) / 2.0

# 修剪函数
def prune(tree, testData):  # 待剪枝的树和剪枝所需的测试数据
    if shape(testData)[0] == 0:
        return getMean(tree)  # 没有测试数据就对树进行塌陷处理,即返回树平均值
    if (isTree(tree['right']) or isTree(tree['left'])):  # 如果树枝不是树,修建它
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']):
        tree['left'] = prune(tree['left'], lSet) #这部分是一个递归
    if isTree(tree['right']):
        tree['right'] = prune(tree['right'], rSet)#这部分是一个递归

    if not isTree(tree['left']) and not isTree(tree['right']):# 如果都是叶子节点,合并它们
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        errorNoMerge = sum(power(lSet[:, -1] - tree['left'], 2)) + \
                       sum(power(rSet[:, -1] - tree['right'], 2))
        treeMean = (tree['left'] + tree['right']) / 2.0 # 合并
        errorMerge = sum(power(testData[:, -1] - treeMean, 2))
        if errorMerge < errorNoMerge:
            print("merging")
            return treeMean
        else:
            return tree
    else:
        return tree

四、模型树

模型树的可解释性优于回归树。

'''模型树的叶节点生成函数'''

# 该函数用于将数据集格式化成目标变量Y和自变量X
def linearSolve(dataSet):
    m, n = shape(dataSet)
    X = mat(ones((m, n))); Y = mat(ones((m, 1)))
    X[:, 1:n] = dataSet[:, 0:n-1]; Y = dataSet[:, -1]
    xTx = X.T*X
    if linalg.det(xTx) == 0.0:
        raise NameError('This matrix is singular, cannot do inverse, \n\
                        try increasing the second value of ops')
    ws = xTx.I * (X.T * Y)
    return ws, X, Y

# 当数据不需要切分时,该函数负责生成叶节点的模型,最后返回回归系数
def modelLeaf(dataSet):
    ws, X, Y = linearSolve(dataSet)
    return ws

# 该函数用于在给定的数据集上计算误差
def modelErr(dataSet):
    ws, X, Y = linearSolve(dataSet)
    yHat = X * ws
    return sum(power(Y - yHat, 2))  # 返回预测值和真值之间的平方误差

五、回归树与模型树的比较

用树回归进行预测的代码:

'''用树回归进行预测的代码'''

# 对 回归树 叶节点进行预测
def regTreeEval(model, inDat): # 只使用一个输入参数,但为了和模型树保持一致,写两个
    return float(model)

# 对 模型树 叶节点进行预测
def modelTreeEval(model, inDat):
    n = shape(inDat)[1]
    X = mat(ones((1, n + 1)))
    X[:, 1:n + 1] = inDat
    return float(X * model)

# 自顶向下遍历整棵树,直到命中叶节点为止。一旦到达叶节点,它就会在输入数据上调用modelEval()函数,该函数的默认值是回归树
def treeForeCast(tree, inData, modelEval=regTreeEval):
    if not isTree(tree): return modelEval(tree, inData)
    if inData[tree['spInd']] > tree['spVal']:
        if isTree(tree['left']):
            return treeForeCast(tree['left'], inData, modelEval)
        else:
            return modelEval(tree['left'], inData)
    else:
        if isTree(tree['right']):
            return treeForeCast(tree['right'], inData, modelEval)
        else:
            return modelEval(tree['right'], inData)

# 多次调用treeForeCast函数,能够以向量形式返回一组测试值
def createForeCast(tree, testData, modelEval=regTreeEval):
    m = len(testData)
    yHat = mat(zeros((m, 1)))
    for i in range(m):
        yHat[i, 0] = treeForeCast(tree, mat(testData[i]), modelEval)
    return yHat

创建回归树和模型树,计算 R 2 R^2 R2来比较:

	trainMat = mat(loadDataSet('bikeSpeedVsIq_train.txt'))
    testMat = mat(loadDataSet('bikeSpeedVsIq_test.txt'))
    myTree = createTree(trainMat, ops=(1, 20))  # 创建一颗回归树
    yHat = createForeCast(myTree, testMat[:, 0])
    R2 = corrcoef(yHat, testMat[:, 1], rowvar=0)[0, 1]
    print(R2) #0.964085231822215
    myTree = createTree(trainMat, modelLeaf, modelErr, (1, 20)) # 创建一颗模型树
    yHat = createForeCast(myTree, testMat[:, 0], modelTreeEval)
    R2 = corrcoef(yHat, testMat[:, 1], rowvar=0)[0,1]
    print(R2)  #0.964085231822215  #计算相关系数

从上面可以看出,模型树更优一些。

六、使用python的Tkinter库创建GUI

1.用Tkinter创建GUI

Tkinter的GUI由一些小部件(widget)组成。比如文本框(Text Box)、按钮(Button)、标签(Label)和复选按钮(Check Button)等对象。当调用小部件的.grid()方法时,就等于把该小部件的位置告诉了布局管理器(Geometry Manager)。


用于构建树管理器界面的Tkinter小部件:

def reDraw(tolS, tolN):
    pass
def drawNewTree():
    pass
if __name__ == "__main__":
    '''构建树管理器界面的tkinter小部件'''
    root = Tk()
    # Label(root, text="Plot Place Holder").grid(row=0, columnspan=3)
    Label(root, text="tolN").grid(row=1, column=0)
    tolNentry = Entry(root)  # Entry为文本输入框
    tolNentry.grid(row=1, column=1)
    tolNentry.insert(0,'10')
    Label(root, text="tolS").grid(row=2, column=0)
    tolSentry = Entry(root)
    tolSentry.grid(row=2, column=1)
    tolSentry.insert(0, '1.0')
    Button(root, text="ReDraw", command=drawNewTree).grid(row=1, column=2, rowspan=3) # rowspan和columnspan的值告诉布局管理器是否允许一个小部件跨行或跨列
    chkBtnVar = IntVar()  #IntVar 为按钮整数值,为了读取checkbutton的状态创建的变量
    chkBtn = Checkbutton(root, text="Model Tree", variable=chkBtnVar) # Checkbutton 复选按钮
    chkBtn.grid(row=3, column=0, columnspan=2)
    reDraw.rawDat = mat(loadDataSet('sine.txt'))
    reDraw.testDat = arange(min(reDraw.rawDat[:, 0]), max(reDraw.rawDat[:, 0]), 0.01)
    reDraw(1.0, 10)

    root.mainloop()

绘制出的tkinter窗口如下:
在这里插入图片描述
从上面代码可知,reDraw()和drawNewTree()函数内部还没有功能,会在后面实现。

2.集成Matplotlib和Tkinter

通过修改matplotlib后端(仅在GUI上)达到在tkinter的GUI上绘图的目的。
matplotlib的构建程序包含一个前端,同时也创建了一个后端,通过改变后端可以将图像绘制在PNG、PDF、SVG等格式的文件上。下面将设置后端为TkAgg。

matplotlib和Tkinter的代码集成:

import matplotlib
matplotlib.use('TkAgg') # 设定matplotlib的后端为TkAgg
# 下面两个import声明将TkAgg和 matplotlib图链接起来
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure

# 用于绘制树
def reDraw(tolS, tolN):
    reDraw.f.clf()  # 清空之前的图像,使得前后两个图像不会重叠
    reDraw.a = reDraw.f.add_subplot(111)
    if chkBtnVar.get(): # 检查复选框是否被选中,确定基于tolS和tolN参数构建模型树还是回归树
        if tolN < 2:
            tolN = 2
        myTree = createTree(reDraw.rawDat, modelLeaf, modelErr, (tolS, tolN))
        yHat = createForeCast(myTree, reDraw.testDat, modelTreeEval)
    else:
        myTree = createTree(reDraw.rawDat, ops=(tolS, tolN))
        yHat = createForeCast(myTree, reDraw.testDat)
    reDraw.a.scatter(reDraw.rawDat[:, 0].tolist(), reDraw.rawDat[:, 1].tolist(), s=5)  # 这里加了tolist()
    reDraw.a.plot(reDraw.testDat, yHat, linewidth=2.0)
    reDraw.canvas.draw()

# getInputs函数试图理解用户的输入并防止程序崩溃
def getInputs():
    try:
        tolN = int(tolNentry.get()) # tolN期望的输入是整数,.get()方法用于得到用户输入的文本
    except:
        tolN = 10
        print("enter Integer for tolN")
        tolNentry.delete(0, END)
        tolNentry.insert(0, '10')
    try:
        tolS = float(tolSentry.get()) # tolS期望的输入是浮点数
    except:
        tolS = 1.0
        print("enter Float for tolS")
        tolSentry.delete(0, END)
        tolSentry.insert(0, '1.0')
    return tolN, tolS

# draNewTree函数实现两个功能:1.调用getInputs方法得到输入框的值。2.利用该值调用reDraw方法生成一个漂亮的图
# 每点击一次reDraw按钮就会调用一次这个函数
def drawNewTree():
    tolN, tolS = getInputs()  # get values from Entry boxes
    reDraw(tolS, tolN)

在主函数中:

if __name__ == "__main__":
    '''构建树管理器界面的tkinter小部件'''
    root = Tk()

    reDraw.f = Figure(figsize=(5, 4), dpi=100)
    reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root)
    reDraw.canvas.draw()
    reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3)

    # Label(root, text="Plot Place Holder").grid(row=0, columnspan=3)
    Label(root, text="tolN").grid(row=1, column=0)
    tolNentry = Entry(root)  # Entry为文本输入框
    tolNentry.grid(row=1, column=1)
    tolNentry.insert(0,'10')
    Label(root, text="tolS").grid(row=2, column=0)
    tolSentry = Entry(root)
    tolSentry.grid(row=2, column=1)
    tolSentry.insert(0, '1.0')
    Button(root, text="ReDraw", command=drawNewTree).grid(row=1, column=2, rowspan=3) # rowspan和columnspan的值告诉布局管理器是否允许一个小部件跨行或跨列
    chkBtnVar = IntVar()  #IntVar 为按钮整数值,为了读取checkbutton的状态创建的变量
    chkBtn = Checkbutton(root, text="Model Tree", variable=chkBtnVar) # Checkbutton 复选按钮
    chkBtn.grid(row=3, column=0, columnspan=2)
    reDraw.rawDat = mat(loadDataSet('sine.txt'))
    reDraw.testDat = arange(min(reDraw.rawDat[:, 0]), max(reDraw.rawDat[:, 0]), 0.01)
    reDraw(1.0, 10)

    root.mainloop()

绘出的窗口如下:
在这里插入图片描述
将模型树复选框勾选后,点击reDraw:
在这里插入图片描述
还可以通过改变tolN和tolS来修改树的复杂程度:
在这里插入图片描述
在这里插入图片描述

七、全部代码

'''regTrees.py'''
from numpy import *

def loadDataSet(fileName):
    dataMat = []
    fr = open(fileName)
    for line in fr.readlines():
        curLine = line.strip().split('\t')
        fltLine = list(map(float, curLine))
        dataMat.append(fltLine)
    return dataMat

def binSplitDataSet(dataSet, feature, value):  # 参数为:数据集,待切分特征,该特征的某个值
    # mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :][0]
    # mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :][0]
    mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :]  # 使用数据过滤方式切割数据集
    mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
    return mat0, mat1

'''CART算法的辅助函数:回归树的切分函数'''
#创建叶节点的函数
def regLeaf(dataSet):
    return mean(dataSet[:, -1])
#总方差计算函数
def regErr(dataSet):
    return var(dataSet[:, -1]) * shape(dataSet)[0]  # var函数用于计算均方差(均方差 * 样本个数 = 总方差)

'''CART算法'''
# 该函数用于找到数据的最佳二元切分方式:1.用最佳方式切分数据,2.生成相应的叶节点
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)): # leafType是对创建叶节点的函数的引用,errType是对前面介绍的总方差计算函数的引用,ops由用户指定,用于控制函数的停止时机
    tolS = ops[0]; tolN = ops[1]  # tolS 是允许的误差下降值,tolN是切分的最少样本数
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1: # 前面用于统计不同剩余特征值的数目,如果为1就不用切分直接返回
        return None, leafType(dataSet)  # 如果所有值相当则退出 1,退出后直接创建叶节点
    m, n = shape(dataSet)  # 计算当前数据集的大小
    S = errType(dataSet)  # 计算当前数据集的误差,S将用于与新切分误差进行对比,来检查新切分能够降低误差
    bestS = inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n-1):
        for splitVal in set((dataSet[:, featIndex].T.A.tolist())[0]): # 这里进行了修改
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
                continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    if (S - bestS) < tolS:
        return None, leafType(dataSet)   # 如果误差减少不大则退出  2
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
        return None, leafType(dataSet)  # 如果切分出的数据集很小则退出 3,即子集大小小于用户定义的参数tolN
    return bestIndex, bestValue  # 上述3个提前终止条件都不满足,返回数据集切分的最好位置和最好值(最佳切分特征和阈值)

# 创建树
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    if feat == None:
        return val
    retTree = {}
    retTree['spInd'] = feat  # 待切分特征
    retTree['spVal'] = val   # 待切分特征值
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops) # 左子树
    retTree['right'] = createTree(rSet, leafType, errType, ops) # 右子树
    return retTree

'''回归树剪枝函数'''
# 用于判断是不是叶子节点
def isTree(obj):
    return (type(obj).__name__ == 'dict')
#返回树平均值
def getMean(tree):
    if isTree(tree['right']):
        tree['right'] = getMean(tree['right'])
    if isTree(tree['left']):
        tree['left'] = getMean(tree['left'])
    return (tree['left'] + tree['right']) / 2.0

# 修剪函数
def prune(tree, testData):  # 待剪枝的树和剪枝所需的测试数据
    if shape(testData)[0] == 0:
        return getMean(tree)  # 没有测试数据就对树进行塌陷处理,即返回树平均值
    if (isTree(tree['right']) or isTree(tree['left'])):  # 如果树枝不是树,修建它
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']):
        tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']):
        tree['right'] = prune(tree['right'], rSet)

    if not isTree(tree['left']) and not isTree(tree['right']):# 如果都是叶子节点,合并它们
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        errorNoMerge = sum(power(lSet[:, -1] - tree['left'], 2)) + \
                       sum(power(rSet[:, -1] - tree['right'], 2))
        treeMean = (tree['left'] + tree['right']) / 2.0
        errorMerge = sum(power(testData[:, -1] - treeMean, 2))
        if errorMerge < errorNoMerge:
            print("merging")
            return treeMean
        else:
            return tree
    else:
        return tree

'''模型树的叶节点生成函数'''
# 该函数用于将数据集格式化成目标变量Y和自变量X
def linearSolve(dataSet):
    m, n = shape(dataSet)
    X = mat(ones((m, n))); Y = mat(ones((m, 1)))
    X[:, 1:n] = dataSet[:, 0:n-1]; Y = dataSet[:, -1]
    xTx = X.T*X
    if linalg.det(xTx) == 0.0:
        raise NameError('This matrix is singular, cannot do inverse, \n\
                        try increasing the second value of ops')
    ws = xTx.I * (X.T * Y)
    return ws, X, Y
# 当数据不需要切分时,该函数负责生成叶节点的模型,最后返回回归系数
def modelLeaf(dataSet):
    ws, X, Y = linearSolve(dataSet)
    return ws
# 该函数用于在给定的数据集上计算误差
def modelErr(dataSet):
    ws, X, Y = linearSolve(dataSet)
    yHat = X * ws
    return sum(power(Y - yHat, 2))  # 返回预测值和真值之间的平方误差

'''用树回归进行预测的代码'''
# 对 回归树 叶节点进行预测
def regTreeEval(model, inDat): # 只使用一个输入参数,但为了和模型树保持一致,写两个
    return float(model)

# 对 模型树 叶节点进行预测
def modelTreeEval(model, inDat):
    n = shape(inDat)[1]
    X = mat(ones((1, n + 1)))
    X[:, 1:n + 1] = inDat
    return float(X * model)

# 自顶向下遍历整棵树,直到命中叶节点为止。一旦到达叶节点,它就会在输入数据上调用modelEval()函数,该函数的默认值是回归树
def treeForeCast(tree, inData, modelEval=regTreeEval):
    if not isTree(tree): return modelEval(tree, inData)
    if inData[tree['spInd']] > tree['spVal']:
        if isTree(tree['left']):
            return treeForeCast(tree['left'], inData, modelEval)
        else:
            return modelEval(tree['left'], inData)
    else:
        if isTree(tree['right']):
            return treeForeCast(tree['right'], inData, modelEval)
        else:
            return modelEval(tree['right'], inData)

# 多次调用treeForeCast函数,能够以向量形式返回一组测试值
def createForeCast(tree, testData, modelEval=regTreeEval):
    m = len(testData)
    yHat = mat(zeros((m, 1)))
    for i in range(m):
        yHat[i, 0] = treeForeCast(tree, mat(testData[i]), modelEval)
    return yHat

if __name__ == '__main__':
    testMat = mat(eye(4))
    mat0, mat1 = binSplitDataSet(testMat, 1, 0.5) # 将testMat使用第一个特征,按特征值0.5进行划分
    # print(mat0)
    # print(mat1)

    myDat = loadDataSet('ex00.txt')
    myMat = mat(myDat)
    # print(myMat)
    myTree = createTree(myMat)
    # print(myTree) #{'spInd': 0, 'spVal': 0.48813, 'left': 1.0180967672413792, 'right': -0.04465028571428572}

    myDat1 = loadDataSet('ex0.txt')
    myMat1 = mat(myDat1)
    myTree1 = createTree(myMat1)
    # print(myTree1)

    myDat2 = loadDataSet('ex2.txt')
    myMat2 = mat(myDat1)
    myTree2 = createTree(myMat2)  # ops=(1,4)
    # print(myTree2)  # 有多个叶子节点,因为对误差的容忍值小
    myTree2_2 = createTree(myMat2, ops=(100,4))
    # print(myTree2_2)  #{'spInd': 1, 'spVal': 0.39435, 'left': 2.9675496160000003, 'right': 0.39728045333333334}

    myTree = createTree(myMat2, ops=(0, 1))
    myDatTest = loadDataSet('ex2test.txt')
    myMat2Test = mat(myDatTest)
    prune(myTree, myMat2Test)

    myMat2 = mat(loadDataSet('exp2.txt'))
    myTree = createTree(myMat2, modelLeaf, modelErr, (1, 10))
    # print(myTree)

    trainMat = mat(loadDataSet('bikeSpeedVsIq_train.txt'))
    testMat = mat(loadDataSet('bikeSpeedVsIq_test.txt'))
    myTree = createTree(trainMat, ops=(1, 20))  # 创建一颗回归树
    yHat = createForeCast(myTree, testMat[:, 0])
    R2 = corrcoef(yHat, testMat[:, 1], rowvar=0)[0, 1]
    # print(R2) #0.964085231822215
    myTree = createTree(trainMat, modelLeaf, modelErr, (1, 20)) # 创建一颗模型树
    yHat = createForeCast(myTree, testMat[:, 0], modelTreeEval)
    R2 = corrcoef(yHat, testMat[:, 1], rowvar=0)[0,1]
    # print(R2)  #0.964085231822215  #计算相关系数
'''treeExplore.py'''
from tkinter import *
from numpy import *
from regTrees import *

import matplotlib
matplotlib.use('TkAgg') # 设定matplotlib的后端为TkAgg
# 下面两个import声明将TkAgg和 matplotlib图链接起来
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure

# 用于绘制树
def reDraw(tolS, tolN):
    reDraw.f.clf()  # 清空之前的图像,使得前后两个图像不会重叠
    reDraw.a = reDraw.f.add_subplot(111)
    if chkBtnVar.get(): # 检查复选框是否被选中,确定基于tolS和tolN参数构建模型树还是回归树
        if tolN < 2:
            tolN = 2
        myTree = createTree(reDraw.rawDat, modelLeaf, modelErr, (tolS, tolN))
        yHat = createForeCast(myTree, reDraw.testDat, modelTreeEval)
    else:
        myTree = createTree(reDraw.rawDat, ops=(tolS, tolN))
        yHat = createForeCast(myTree, reDraw.testDat)
    reDraw.a.scatter(reDraw.rawDat[:, 0].tolist(), reDraw.rawDat[:, 1].tolist(), s=5)  # 这里加了tolist()
    reDraw.a.plot(reDraw.testDat, yHat, linewidth=2.0)
    reDraw.canvas.draw()

# getInputs函数试图理解用户的输入并防止程序崩溃
def getInputs():
    try:
        tolN = int(tolNentry.get()) # tolN期望的输入是整数,.get()方法用于得到用户输入的文本
    except:
        tolN = 10
        print("enter Integer for tolN")
        tolNentry.delete(0, END)
        tolNentry.insert(0, '10')
    try:
        tolS = float(tolSentry.get()) # tolS期望的输入是浮点数
    except:
        tolS = 1.0
        print("enter Float for tolS")
        tolSentry.delete(0, END)
        tolSentry.insert(0, '1.0')
    return tolN, tolS

# draNewTree函数实现两个功能:1.调用getInputs方法得到输入框的值。2.利用该值调用reDraw方法生成一个漂亮的图
# 每点击一次reDraw按钮就会调用一次这个函数
def drawNewTree():
    tolN, tolS = getInputs()  # get values from Entry boxes
    reDraw(tolS, tolN)

'''小试tkinter'''
# root = Tk()
# myLabel = Label(root, text="Hello World")
# myLabel.grid()
# root.mainloop()

# def reDraw(tolS, tolN):
#     pass
# def drawNewTree():
#     pass

if __name__ == "__main__":
    '''构建树管理器界面的tkinter小部件'''
    root = Tk()

    reDraw.f = Figure(figsize=(5, 4), dpi=100)
    reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root)
    reDraw.canvas.draw()
    reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3)

    # Label(root, text="Plot Place Holder").grid(row=0, columnspan=3)
    Label(root, text="tolN").grid(row=1, column=0)
    tolNentry = Entry(root)  # Entry为文本输入框
    tolNentry.grid(row=1, column=1)
    tolNentry.insert(0,'10')
    Label(root, text="tolS").grid(row=2, column=0)
    tolSentry = Entry(root)
    tolSentry.grid(row=2, column=1)
    tolSentry.insert(0, '1.0')
    Button(root, text="ReDraw", command=drawNewTree).grid(row=1, column=2, rowspan=3) # rowspan和columnspan的值告诉布局管理器是否允许一个小部件跨行或跨列
    chkBtnVar = IntVar()  #IntVar 为按钮整数值,为了读取checkbutton的状态创建的变量
    chkBtn = Checkbutton(root, text="Model Tree", variable=chkBtnVar) # Checkbutton 复选按钮
    chkBtn.grid(row=3, column=0, columnspan=2)
    reDraw.rawDat = mat(loadDataSet('sine.txt'))
    reDraw.testDat = arange(min(reDraw.rawDat[:, 0]), max(reDraw.rawDat[:, 0]), 0.01)
    reDraw(1.0, 10)

    root.mainloop()
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值