Machine Learning in Action 读书笔记
第9章 树回归
文章目录
一、树回归
1.树回归的优缺点
- 优点:可以对复杂和非线性的数据建模
- 缺点:结果不易理解
- 适用数据类型:数值型和标称型数据
2.树回归的一般方法
- 收集数据:采用任意方法收集数据
- 准备数据:需要数值型的数据,标称型数据应该映射成二值型数据
- 分析数据:绘出数据的二维可视化显示结果,以字典方式生成树
- 训练算法:大部分时间都花费在叶节点树模型的构建上
- 测试算法:使用测试数据上的 R 2 R^2 R2值来分析模型的效果
- 使用算法:使用训练出的树做预测,预测结果还可以用来做很多事情
3.回归树与模型树
本章将构建两种树:
- 回归树(regression tree):其每个叶节点包含单个值
- 模型树(model tree):其每个叶节点包含一个线性方程
树的存储依然使用字典存储,该字典中包含以下4个元素:
- 待切分的特征。
- 待切分的特征值。
- 右子树。当不再需要切分的时候,也可以是单个值
- 左子树。
二、CART算法
CART(Classification And Regression Trees,分类回归树):是一种树构建算法,它使用二元切分来处理连续型变量。
二元切分法:每次把数据集切分成两份,如果数据的某个特征值等于切分所要求的值,那么这些数据就进入树的左子树,反之进入树的右子树。
第三章中的决策树的切分使用ID3(迭代二分器)算法,但是ID3算法不能直接处理连续型特征。
CART算法的实现:
def loadDataSet(fileName):
dataMat = []
fr = open(fileName)
for line in fr.readlines():
curLine = line.strip().split('\t')
fltLine = list(map(float, curLine))
dataMat.append(fltLine)
return dataMat
def binSplitDataSet(dataSet, feature, value): # 参数为:数据集,待切分特征,该特征的某个值
# mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :][0]
# mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :][0]
mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :] # 使用数据过滤方式切割数据集
mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
return mat0, mat1
# 创建树
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)): # 默认为回归树
feat, val = chooseBestSplit(dataSet, leafType, errType, ops) # 选择最好切分特征和特征值
if feat == None:
return val
retTree = {}
retTree['spInd'] = feat # 待切分特征
retTree['spVal'] = val # 待切分特征值
lSet, rSet = binSplitDataSet(dataSet, feat, val) # 按照找出的最好切分特征和特征值继续切分为左右子树
retTree['left'] = createTree(lSet, leafType, errType, ops) # 左子树
retTree['right'] = createTree(rSet, leafType, errType, ops) # 右子树
return retTree
回归树的切分函数:
'''CART算法的辅助函数:回归树的切分函数'''
#创建叶节点的函数
def regLeaf(dataSet):
return mean(dataSet[:, -1])
#总方差计算函数
def regErr(dataSet):
return var(dataSet[:, -1]) * shape(dataSet)[0] # var函数用于计算均方差(均方差 * 样本个数 = 总方差)
'''CART算法'''
# 该函数用于找到数据的最佳二元切分方式:1.用最佳方式切分数据,2.生成相应的叶节点
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)): # leafType是对创建叶节点的函数的引用,errType是对前面介绍的总方差计算函数的引用,ops由用户指定,用于控制函数的停止时机
tolS = ops[0]; tolN = ops[1] # tolS 是允许的误差下降值,tolN是切分的最少样本数
if len(set(dataSet[:, -1].T.tolist()[0])) == 1: # 前面用于统计不同剩余特征值的数目,如果为1就不用切分直接返回
return None, leafType(dataSet) # 如果所有值相当则退出 1,退出后直接创建叶节点
m, n = shape(dataSet) # 计算当前数据集的大小
S = errType(dataSet) # 计算当前数据集的误差,S将用于与新切分误差进行对比,来检查新切分能够降低误差
bestS = inf; bestIndex = 0; bestValue = 0
for featIndex in range(n-1):
for splitVal in set((dataSet[:, featIndex].T.A.tolist())[0]): # 这里进行了修改
mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
continue
newS = errType(mat0) + errType(mat1)
if newS < bestS:
bestIndex = featIndex
bestValue = splitVal
bestS = newS
if (S - bestS) < tolS:
return None, leafType(dataSet) # 如果误差减少不大则退出 2
mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
return None, leafType(dataSet) # 如果切分出的数据集很小则退出 3,即子集大小小于用户定义的参数tolN
return bestIndex, bestValue # 上述3个提前终止条件都不满足,返回数据集切分的最好位置和最好值(最佳切分特征和阈值)
其中,上述代码中,有两个用户可控制参数tolS和tolN:
- tolS:是允许的最小误差下降值,小于该值则不切分直接退出
- tolN:是允许切分的最少样本数,小于定义的最小样本数则不切分直接退出
三、树剪枝
通过降低决策树的复杂度来避免过拟合的过程称为剪枝(pruning)。有两种形式的剪枝方法,预剪枝(prepruning)和后剪枝(postpruning)。
1.预剪枝
在函数chooseBestSplit()中的提前终止条件,实际上就是在进行一种所谓的预剪枝操作。其中的三个终止条件是:
- 如果切分数据集后效果提升不大,那么就不应该进行切分操作而直接创建叶节点
- 如果两个切分后的某个子集的大小小于用户定义的参数tolN,那么也不应切分
- 统计剩余特征值的数目,数目为1,就不需要切分而直接返回
2.后剪枝
使用后剪枝方法需要将数据集分为测试集和训练集。使用测试集来判断将这些叶节点合并是否能降低测试误差,如果是的话就合并。
函数prune()的伪代码:
基于已有的树切分测试数据:
如果存在任一子集是一颗树,则在该子集递归剪枝过程
计算将当前两个叶节点合并后的误差
计算不合并的误差
如果合并会降低误差的话,就将叶节点合并
python代码:
'''回归树剪枝函数'''
# 用于判断是不是叶子节点
def isTree(obj):
return (type(obj).__name__ == 'dict')
#返回树平均值
def getMean(tree):
if isTree(tree['right']):
tree['right'] = getMean(tree['right'])
if isTree(tree['left']):
tree['left'] = getMean(tree['left'])
return (tree['left'] + tree['right']) / 2.0
# 修剪函数
def prune(tree, testData): # 待剪枝的树和剪枝所需的测试数据
if shape(testData)[0] == 0:
return getMean(tree) # 没有测试数据就对树进行塌陷处理,即返回树平均值
if (isTree(tree['right']) or isTree(tree['left'])): # 如果树枝不是树,修建它
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
if isTree(tree['left']):
tree['left'] = prune(tree['left'], lSet) #这部分是一个递归
if isTree(tree['right']):
tree['right'] = prune(tree['right'], rSet)#这部分是一个递归
if not isTree(tree['left']) and not isTree(tree['right']):# 如果都是叶子节点,合并它们
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
errorNoMerge = sum(power(lSet[:, -1] - tree['left'], 2)) + \
sum(power(rSet[:, -1] - tree['right'], 2))
treeMean = (tree['left'] + tree['right']) / 2.0 # 合并
errorMerge = sum(power(testData[:, -1] - treeMean, 2))
if errorMerge < errorNoMerge:
print("merging")
return treeMean
else:
return tree
else:
return tree
四、模型树
模型树的可解释性优于回归树。
'''模型树的叶节点生成函数'''
# 该函数用于将数据集格式化成目标变量Y和自变量X
def linearSolve(dataSet):
m, n = shape(dataSet)
X = mat(ones((m, n))); Y = mat(ones((m, 1)))
X[:, 1:n] = dataSet[:, 0:n-1]; Y = dataSet[:, -1]
xTx = X.T*X
if linalg.det(xTx) == 0.0:
raise NameError('This matrix is singular, cannot do inverse, \n\
try increasing the second value of ops')
ws = xTx.I * (X.T * Y)
return ws, X, Y
# 当数据不需要切分时,该函数负责生成叶节点的模型,最后返回回归系数
def modelLeaf(dataSet):
ws, X, Y = linearSolve(dataSet)
return ws
# 该函数用于在给定的数据集上计算误差
def modelErr(dataSet):
ws, X, Y = linearSolve(dataSet)
yHat = X * ws
return sum(power(Y - yHat, 2)) # 返回预测值和真值之间的平方误差
五、回归树与模型树的比较
用树回归进行预测的代码:
'''用树回归进行预测的代码'''
# 对 回归树 叶节点进行预测
def regTreeEval(model, inDat): # 只使用一个输入参数,但为了和模型树保持一致,写两个
return float(model)
# 对 模型树 叶节点进行预测
def modelTreeEval(model, inDat):
n = shape(inDat)[1]
X = mat(ones((1, n + 1)))
X[:, 1:n + 1] = inDat
return float(X * model)
# 自顶向下遍历整棵树,直到命中叶节点为止。一旦到达叶节点,它就会在输入数据上调用modelEval()函数,该函数的默认值是回归树
def treeForeCast(tree, inData, modelEval=regTreeEval):
if not isTree(tree): return modelEval(tree, inData)
if inData[tree['spInd']] > tree['spVal']:
if isTree(tree['left']):
return treeForeCast(tree['left'], inData, modelEval)
else:
return modelEval(tree['left'], inData)
else:
if isTree(tree['right']):
return treeForeCast(tree['right'], inData, modelEval)
else:
return modelEval(tree['right'], inData)
# 多次调用treeForeCast函数,能够以向量形式返回一组测试值
def createForeCast(tree, testData, modelEval=regTreeEval):
m = len(testData)
yHat = mat(zeros((m, 1)))
for i in range(m):
yHat[i, 0] = treeForeCast(tree, mat(testData[i]), modelEval)
return yHat
创建回归树和模型树,计算 R 2 R^2 R2来比较:
trainMat = mat(loadDataSet('bikeSpeedVsIq_train.txt'))
testMat = mat(loadDataSet('bikeSpeedVsIq_test.txt'))
myTree = createTree(trainMat, ops=(1, 20)) # 创建一颗回归树
yHat = createForeCast(myTree, testMat[:, 0])
R2 = corrcoef(yHat, testMat[:, 1], rowvar=0)[0, 1]
print(R2) #0.964085231822215
myTree = createTree(trainMat, modelLeaf, modelErr, (1, 20)) # 创建一颗模型树
yHat = createForeCast(myTree, testMat[:, 0], modelTreeEval)
R2 = corrcoef(yHat, testMat[:, 1], rowvar=0)[0,1]
print(R2) #0.964085231822215 #计算相关系数
从上面可以看出,模型树更优一些。
六、使用python的Tkinter库创建GUI
1.用Tkinter创建GUI
Tkinter的GUI由一些小部件(widget)组成。比如文本框(Text Box)、按钮(Button)、标签(Label)和复选按钮(Check Button)等对象。当调用小部件的.grid()方法时,就等于把该小部件的位置告诉了布局管理器(Geometry Manager)。
用于构建树管理器界面的Tkinter小部件:
def reDraw(tolS, tolN):
pass
def drawNewTree():
pass
if __name__ == "__main__":
'''构建树管理器界面的tkinter小部件'''
root = Tk()
# Label(root, text="Plot Place Holder").grid(row=0, columnspan=3)
Label(root, text="tolN").grid(row=1, column=0)
tolNentry = Entry(root) # Entry为文本输入框
tolNentry.grid(row=1, column=1)
tolNentry.insert(0,'10')
Label(root, text="tolS").grid(row=2, column=0)
tolSentry = Entry(root)
tolSentry.grid(row=2, column=1)
tolSentry.insert(0, '1.0')
Button(root, text="ReDraw", command=drawNewTree).grid(row=1, column=2, rowspan=3) # rowspan和columnspan的值告诉布局管理器是否允许一个小部件跨行或跨列
chkBtnVar = IntVar() #IntVar 为按钮整数值,为了读取checkbutton的状态创建的变量
chkBtn = Checkbutton(root, text="Model Tree", variable=chkBtnVar) # Checkbutton 复选按钮
chkBtn.grid(row=3, column=0, columnspan=2)
reDraw.rawDat = mat(loadDataSet('sine.txt'))
reDraw.testDat = arange(min(reDraw.rawDat[:, 0]), max(reDraw.rawDat[:, 0]), 0.01)
reDraw(1.0, 10)
root.mainloop()
绘制出的tkinter窗口如下:

从上面代码可知,reDraw()和drawNewTree()函数内部还没有功能,会在后面实现。
2.集成Matplotlib和Tkinter
通过修改matplotlib后端(仅在GUI上)达到在tkinter的GUI上绘图的目的。
matplotlib的构建程序包含一个前端,同时也创建了一个后端,通过改变后端可以将图像绘制在PNG、PDF、SVG等格式的文件上。下面将设置后端为TkAgg。
matplotlib和Tkinter的代码集成:
import matplotlib
matplotlib.use('TkAgg') # 设定matplotlib的后端为TkAgg
# 下面两个import声明将TkAgg和 matplotlib图链接起来
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure
# 用于绘制树
def reDraw(tolS, tolN):
reDraw.f.clf() # 清空之前的图像,使得前后两个图像不会重叠
reDraw.a = reDraw.f.add_subplot(111)
if chkBtnVar.get(): # 检查复选框是否被选中,确定基于tolS和tolN参数构建模型树还是回归树
if tolN < 2:
tolN = 2
myTree = createTree(reDraw.rawDat, modelLeaf, modelErr, (tolS, tolN))
yHat = createForeCast(myTree, reDraw.testDat, modelTreeEval)
else:
myTree = createTree(reDraw.rawDat, ops=(tolS, tolN))
yHat = createForeCast(myTree, reDraw.testDat)
reDraw.a.scatter(reDraw.rawDat[:, 0].tolist(), reDraw.rawDat[:, 1].tolist(), s=5) # 这里加了tolist()
reDraw.a.plot(reDraw.testDat, yHat, linewidth=2.0)
reDraw.canvas.draw()
# getInputs函数试图理解用户的输入并防止程序崩溃
def getInputs():
try:
tolN = int(tolNentry.get()) # tolN期望的输入是整数,.get()方法用于得到用户输入的文本
except:
tolN = 10
print("enter Integer for tolN")
tolNentry.delete(0, END)
tolNentry.insert(0, '10')
try:
tolS = float(tolSentry.get()) # tolS期望的输入是浮点数
except:
tolS = 1.0
print("enter Float for tolS")
tolSentry.delete(0, END)
tolSentry.insert(0, '1.0')
return tolN, tolS
# draNewTree函数实现两个功能:1.调用getInputs方法得到输入框的值。2.利用该值调用reDraw方法生成一个漂亮的图
# 每点击一次reDraw按钮就会调用一次这个函数
def drawNewTree():
tolN, tolS = getInputs() # get values from Entry boxes
reDraw(tolS, tolN)
在主函数中:
if __name__ == "__main__":
'''构建树管理器界面的tkinter小部件'''
root = Tk()
reDraw.f = Figure(figsize=(5, 4), dpi=100)
reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root)
reDraw.canvas.draw()
reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3)
# Label(root, text="Plot Place Holder").grid(row=0, columnspan=3)
Label(root, text="tolN").grid(row=1, column=0)
tolNentry = Entry(root) # Entry为文本输入框
tolNentry.grid(row=1, column=1)
tolNentry.insert(0,'10')
Label(root, text="tolS").grid(row=2, column=0)
tolSentry = Entry(root)
tolSentry.grid(row=2, column=1)
tolSentry.insert(0, '1.0')
Button(root, text="ReDraw", command=drawNewTree).grid(row=1, column=2, rowspan=3) # rowspan和columnspan的值告诉布局管理器是否允许一个小部件跨行或跨列
chkBtnVar = IntVar() #IntVar 为按钮整数值,为了读取checkbutton的状态创建的变量
chkBtn = Checkbutton(root, text="Model Tree", variable=chkBtnVar) # Checkbutton 复选按钮
chkBtn.grid(row=3, column=0, columnspan=2)
reDraw.rawDat = mat(loadDataSet('sine.txt'))
reDraw.testDat = arange(min(reDraw.rawDat[:, 0]), max(reDraw.rawDat[:, 0]), 0.01)
reDraw(1.0, 10)
root.mainloop()
绘出的窗口如下:

将模型树复选框勾选后,点击reDraw:

还可以通过改变tolN和tolS来修改树的复杂程度:


七、全部代码
'''regTrees.py'''
from numpy import *
def loadDataSet(fileName):
dataMat = []
fr = open(fileName)
for line in fr.readlines():
curLine = line.strip().split('\t')
fltLine = list(map(float, curLine))
dataMat.append(fltLine)
return dataMat
def binSplitDataSet(dataSet, feature, value): # 参数为:数据集,待切分特征,该特征的某个值
# mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :][0]
# mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :][0]
mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :] # 使用数据过滤方式切割数据集
mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
return mat0, mat1
'''CART算法的辅助函数:回归树的切分函数'''
#创建叶节点的函数
def regLeaf(dataSet):
return mean(dataSet[:, -1])
#总方差计算函数
def regErr(dataSet):
return var(dataSet[:, -1]) * shape(dataSet)[0] # var函数用于计算均方差(均方差 * 样本个数 = 总方差)
'''CART算法'''
# 该函数用于找到数据的最佳二元切分方式:1.用最佳方式切分数据,2.生成相应的叶节点
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)): # leafType是对创建叶节点的函数的引用,errType是对前面介绍的总方差计算函数的引用,ops由用户指定,用于控制函数的停止时机
tolS = ops[0]; tolN = ops[1] # tolS 是允许的误差下降值,tolN是切分的最少样本数
if len(set(dataSet[:, -1].T.tolist()[0])) == 1: # 前面用于统计不同剩余特征值的数目,如果为1就不用切分直接返回
return None, leafType(dataSet) # 如果所有值相当则退出 1,退出后直接创建叶节点
m, n = shape(dataSet) # 计算当前数据集的大小
S = errType(dataSet) # 计算当前数据集的误差,S将用于与新切分误差进行对比,来检查新切分能够降低误差
bestS = inf; bestIndex = 0; bestValue = 0
for featIndex in range(n-1):
for splitVal in set((dataSet[:, featIndex].T.A.tolist())[0]): # 这里进行了修改
mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
continue
newS = errType(mat0) + errType(mat1)
if newS < bestS:
bestIndex = featIndex
bestValue = splitVal
bestS = newS
if (S - bestS) < tolS:
return None, leafType(dataSet) # 如果误差减少不大则退出 2
mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
return None, leafType(dataSet) # 如果切分出的数据集很小则退出 3,即子集大小小于用户定义的参数tolN
return bestIndex, bestValue # 上述3个提前终止条件都不满足,返回数据集切分的最好位置和最好值(最佳切分特征和阈值)
# 创建树
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
if feat == None:
return val
retTree = {}
retTree['spInd'] = feat # 待切分特征
retTree['spVal'] = val # 待切分特征值
lSet, rSet = binSplitDataSet(dataSet, feat, val)
retTree['left'] = createTree(lSet, leafType, errType, ops) # 左子树
retTree['right'] = createTree(rSet, leafType, errType, ops) # 右子树
return retTree
'''回归树剪枝函数'''
# 用于判断是不是叶子节点
def isTree(obj):
return (type(obj).__name__ == 'dict')
#返回树平均值
def getMean(tree):
if isTree(tree['right']):
tree['right'] = getMean(tree['right'])
if isTree(tree['left']):
tree['left'] = getMean(tree['left'])
return (tree['left'] + tree['right']) / 2.0
# 修剪函数
def prune(tree, testData): # 待剪枝的树和剪枝所需的测试数据
if shape(testData)[0] == 0:
return getMean(tree) # 没有测试数据就对树进行塌陷处理,即返回树平均值
if (isTree(tree['right']) or isTree(tree['left'])): # 如果树枝不是树,修建它
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
if isTree(tree['left']):
tree['left'] = prune(tree['left'], lSet)
if isTree(tree['right']):
tree['right'] = prune(tree['right'], rSet)
if not isTree(tree['left']) and not isTree(tree['right']):# 如果都是叶子节点,合并它们
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
errorNoMerge = sum(power(lSet[:, -1] - tree['left'], 2)) + \
sum(power(rSet[:, -1] - tree['right'], 2))
treeMean = (tree['left'] + tree['right']) / 2.0
errorMerge = sum(power(testData[:, -1] - treeMean, 2))
if errorMerge < errorNoMerge:
print("merging")
return treeMean
else:
return tree
else:
return tree
'''模型树的叶节点生成函数'''
# 该函数用于将数据集格式化成目标变量Y和自变量X
def linearSolve(dataSet):
m, n = shape(dataSet)
X = mat(ones((m, n))); Y = mat(ones((m, 1)))
X[:, 1:n] = dataSet[:, 0:n-1]; Y = dataSet[:, -1]
xTx = X.T*X
if linalg.det(xTx) == 0.0:
raise NameError('This matrix is singular, cannot do inverse, \n\
try increasing the second value of ops')
ws = xTx.I * (X.T * Y)
return ws, X, Y
# 当数据不需要切分时,该函数负责生成叶节点的模型,最后返回回归系数
def modelLeaf(dataSet):
ws, X, Y = linearSolve(dataSet)
return ws
# 该函数用于在给定的数据集上计算误差
def modelErr(dataSet):
ws, X, Y = linearSolve(dataSet)
yHat = X * ws
return sum(power(Y - yHat, 2)) # 返回预测值和真值之间的平方误差
'''用树回归进行预测的代码'''
# 对 回归树 叶节点进行预测
def regTreeEval(model, inDat): # 只使用一个输入参数,但为了和模型树保持一致,写两个
return float(model)
# 对 模型树 叶节点进行预测
def modelTreeEval(model, inDat):
n = shape(inDat)[1]
X = mat(ones((1, n + 1)))
X[:, 1:n + 1] = inDat
return float(X * model)
# 自顶向下遍历整棵树,直到命中叶节点为止。一旦到达叶节点,它就会在输入数据上调用modelEval()函数,该函数的默认值是回归树
def treeForeCast(tree, inData, modelEval=regTreeEval):
if not isTree(tree): return modelEval(tree, inData)
if inData[tree['spInd']] > tree['spVal']:
if isTree(tree['left']):
return treeForeCast(tree['left'], inData, modelEval)
else:
return modelEval(tree['left'], inData)
else:
if isTree(tree['right']):
return treeForeCast(tree['right'], inData, modelEval)
else:
return modelEval(tree['right'], inData)
# 多次调用treeForeCast函数,能够以向量形式返回一组测试值
def createForeCast(tree, testData, modelEval=regTreeEval):
m = len(testData)
yHat = mat(zeros((m, 1)))
for i in range(m):
yHat[i, 0] = treeForeCast(tree, mat(testData[i]), modelEval)
return yHat
if __name__ == '__main__':
testMat = mat(eye(4))
mat0, mat1 = binSplitDataSet(testMat, 1, 0.5) # 将testMat使用第一个特征,按特征值0.5进行划分
# print(mat0)
# print(mat1)
myDat = loadDataSet('ex00.txt')
myMat = mat(myDat)
# print(myMat)
myTree = createTree(myMat)
# print(myTree) #{'spInd': 0, 'spVal': 0.48813, 'left': 1.0180967672413792, 'right': -0.04465028571428572}
myDat1 = loadDataSet('ex0.txt')
myMat1 = mat(myDat1)
myTree1 = createTree(myMat1)
# print(myTree1)
myDat2 = loadDataSet('ex2.txt')
myMat2 = mat(myDat1)
myTree2 = createTree(myMat2) # ops=(1,4)
# print(myTree2) # 有多个叶子节点,因为对误差的容忍值小
myTree2_2 = createTree(myMat2, ops=(100,4))
# print(myTree2_2) #{'spInd': 1, 'spVal': 0.39435, 'left': 2.9675496160000003, 'right': 0.39728045333333334}
myTree = createTree(myMat2, ops=(0, 1))
myDatTest = loadDataSet('ex2test.txt')
myMat2Test = mat(myDatTest)
prune(myTree, myMat2Test)
myMat2 = mat(loadDataSet('exp2.txt'))
myTree = createTree(myMat2, modelLeaf, modelErr, (1, 10))
# print(myTree)
trainMat = mat(loadDataSet('bikeSpeedVsIq_train.txt'))
testMat = mat(loadDataSet('bikeSpeedVsIq_test.txt'))
myTree = createTree(trainMat, ops=(1, 20)) # 创建一颗回归树
yHat = createForeCast(myTree, testMat[:, 0])
R2 = corrcoef(yHat, testMat[:, 1], rowvar=0)[0, 1]
# print(R2) #0.964085231822215
myTree = createTree(trainMat, modelLeaf, modelErr, (1, 20)) # 创建一颗模型树
yHat = createForeCast(myTree, testMat[:, 0], modelTreeEval)
R2 = corrcoef(yHat, testMat[:, 1], rowvar=0)[0,1]
# print(R2) #0.964085231822215 #计算相关系数
'''treeExplore.py'''
from tkinter import *
from numpy import *
from regTrees import *
import matplotlib
matplotlib.use('TkAgg') # 设定matplotlib的后端为TkAgg
# 下面两个import声明将TkAgg和 matplotlib图链接起来
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure
# 用于绘制树
def reDraw(tolS, tolN):
reDraw.f.clf() # 清空之前的图像,使得前后两个图像不会重叠
reDraw.a = reDraw.f.add_subplot(111)
if chkBtnVar.get(): # 检查复选框是否被选中,确定基于tolS和tolN参数构建模型树还是回归树
if tolN < 2:
tolN = 2
myTree = createTree(reDraw.rawDat, modelLeaf, modelErr, (tolS, tolN))
yHat = createForeCast(myTree, reDraw.testDat, modelTreeEval)
else:
myTree = createTree(reDraw.rawDat, ops=(tolS, tolN))
yHat = createForeCast(myTree, reDraw.testDat)
reDraw.a.scatter(reDraw.rawDat[:, 0].tolist(), reDraw.rawDat[:, 1].tolist(), s=5) # 这里加了tolist()
reDraw.a.plot(reDraw.testDat, yHat, linewidth=2.0)
reDraw.canvas.draw()
# getInputs函数试图理解用户的输入并防止程序崩溃
def getInputs():
try:
tolN = int(tolNentry.get()) # tolN期望的输入是整数,.get()方法用于得到用户输入的文本
except:
tolN = 10
print("enter Integer for tolN")
tolNentry.delete(0, END)
tolNentry.insert(0, '10')
try:
tolS = float(tolSentry.get()) # tolS期望的输入是浮点数
except:
tolS = 1.0
print("enter Float for tolS")
tolSentry.delete(0, END)
tolSentry.insert(0, '1.0')
return tolN, tolS
# draNewTree函数实现两个功能:1.调用getInputs方法得到输入框的值。2.利用该值调用reDraw方法生成一个漂亮的图
# 每点击一次reDraw按钮就会调用一次这个函数
def drawNewTree():
tolN, tolS = getInputs() # get values from Entry boxes
reDraw(tolS, tolN)
'''小试tkinter'''
# root = Tk()
# myLabel = Label(root, text="Hello World")
# myLabel.grid()
# root.mainloop()
# def reDraw(tolS, tolN):
# pass
# def drawNewTree():
# pass
if __name__ == "__main__":
'''构建树管理器界面的tkinter小部件'''
root = Tk()
reDraw.f = Figure(figsize=(5, 4), dpi=100)
reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root)
reDraw.canvas.draw()
reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3)
# Label(root, text="Plot Place Holder").grid(row=0, columnspan=3)
Label(root, text="tolN").grid(row=1, column=0)
tolNentry = Entry(root) # Entry为文本输入框
tolNentry.grid(row=1, column=1)
tolNentry.insert(0,'10')
Label(root, text="tolS").grid(row=2, column=0)
tolSentry = Entry(root)
tolSentry.grid(row=2, column=1)
tolSentry.insert(0, '1.0')
Button(root, text="ReDraw", command=drawNewTree).grid(row=1, column=2, rowspan=3) # rowspan和columnspan的值告诉布局管理器是否允许一个小部件跨行或跨列
chkBtnVar = IntVar() #IntVar 为按钮整数值,为了读取checkbutton的状态创建的变量
chkBtn = Checkbutton(root, text="Model Tree", variable=chkBtnVar) # Checkbutton 复选按钮
chkBtn.grid(row=3, column=0, columnspan=2)
reDraw.rawDat = mat(loadDataSet('sine.txt'))
reDraw.testDat = arange(min(reDraw.rawDat[:, 0]), max(reDraw.rawDat[:, 0]), 0.01)
reDraw(1.0, 10)
root.mainloop()
本文深入探讨了CART算法在树回归中的应用,包括回归树与模型树的构建、剪枝策略以及Python实现。通过对比实验展示了模型树在解释性和预测效果上的优势,并提供了使用Tkinter创建交互式GUI的示例,以动态调整树的复杂度并观察预测结果。
525

被折叠的 条评论
为什么被折叠?



