文章首发于微信公众号:AlgorithmDeveloper,专注机器学习与Python,编程与算法,还有生活。
1.前言
「决策树」| Part2—Python实现之构建决策树中我们已经可以基于给定数据集训练出决策树模型,只不过是以字典方式表示决策树,决策树直观、易于理解的优点完全体现不出来。因此,这篇文章的目的就是将训练出的决策树模型以树状图形表示。
给定数据集:
字典形式决策树模型:
{'人品': {'好': '见 ', '差': {'富有': {'没钱': '不见', '有钱': {'外貌': {'漂亮': '见 ', '丑': '不见'}}}}}}
2.获取决策树的叶节点数及深度
为了使绘制出的决策树图形不因树的节点、深度的增减而变得畸形,因此利用决策树的叶子节点个数以及树的深度将x轴、y轴平均切分,从而使树状图平均分布在画布上。
#获取决策树叶节点个数
def getNumLeafs(tree):
numLeafs = 0
#获取第一个节点的分类特征
firstFeat = list(tree.keys())[0]
#得到firstFeat特征下的决策树(以字典方式表示)
secondDict = tree[firstFeat]
#遍历firstFeat下的每个节点
for key in secondDict.keys():
#如果节点类型为字典,说明该节点下仍然是一棵树,此时递归调用getNumLeafs
if type(secondDict[key]).__name__== 'dict':
numLeafs += getNumLeafs(secondDict[key])
#否则该节点为叶节点
else:
numLeafs += 1
return numLeafs
#获取决策树深度
def getTreeDepth(tree):
maxDepth = 0
#获取第一个节点分类特征
firstFeat = list(tree.keys())[0]
#得到firstFeat特征下的决策树(以字典方式表示)
secondDict = tree[firstFeat]
#遍历firstFeat下的每个节点,返回子树中的最大深度
for key in secondDict.keys():
#如果节点类型为字典,说明该节点下仍然是一棵树,此时递归调用getTreeDepth,获取该子树深度
if type(secondDict[key]).__name__ == 'dict':
thisDepth = 1 + getTreeDepth(secondDict[key])
else:
thisDepth = 1
if thisDepth > maxDepth:
maxDepth = thisDepth
return maxDepth
3.绘制决策树
3.1绘制节点
#绘制决策树
import matplotlib.pyplot as plt
def createPlot(tree):
#定义一块画布,背景为白色
fig = plt.figure(1, facecolor='white')
#清空画布
fig.clf()
#不显示x、y轴刻度
xyticks = dict(xticks=[],yticks=[])
#frameon:是否绘制坐标轴矩形
createPlot.pTree = plt.subplot(111, frameon=False, **xyticks)
#计算决策树叶子节点个数
plotTree.totalW = float(getNumLeafs(tree))
#计算决策树深度
plotTree.totalD = float(getTreeDepth(tree))
#最近绘制的叶子节点的x坐标
plotTree.xOff = -0.5/plotTree.totalW
#当前绘制的深度:y坐标
plotTree.yOff = 1.0
#(0.5,1.0)为根节点坐标
plotTree(tree,(0.5,1.0),'')
plt.show()
#定义决策节点以及叶子节点属性:boxstyle表示文本框类型,sawtooth:锯齿形;fc表示边框线粗细
decisionNode = dict(boxstyle="sawtooth", fc="0.5")
leafNode = dict(boxstyle="round4", fc="0.5")
#定义箭头属性
arrow_args = dict(arrowstyle="
#nodeText:要显示的文本;centerPt:文本中心点,即箭头所在的点;parentPt:指向文本的点;nodeType:节点属性
#ha='center',va='center':水平、垂直方向中心对齐;bbox:方框属性
#arrowprops:箭头属性
#xycoords,textcoords选择坐标系;axes fraction-->0,0是轴域左下角,1,1是右上角
def plotNode(nodeText, centerPt, parentPt, nodeType):
createPlot.pTree.annotate(nodeText, xy=parentPt, xycoords="axes fraction",
xytext=centerPt, textcoords='axes fraction',
va='center',ha='center',bbox=nodeType, arrowprops=arrow_args)
def plotMidText(centerPt,parentPt,midText):
xMid = (parentPt[0] - centerPt[0])/2.0 + centerPt[0]
yMid = (parentPt[1] - centerPt[1])/2.0 + centerPt[1]
createPlot.pTree.text(xMid, yMid, midtext)
plotNode函数一次绘制的是一个箭头与一个节点,plotMidText函数绘制的是直线中点上的文本。
3.2递归绘制决策树
递归绘制决策树的整体思路如下:
(1)绘制当前节点;
(2)如果当前节点的子节点不是叶子节点,则递归;
(3)如果当前节点的子节点是叶子节点,则绘制。
def plotTree(tree, parentPt, nodeTxt):
#计算叶子节点个数
numLeafs = getNumLeafs(tree)
#获取第一个节点特征
firstFeat = list(tree.keys())[0]
#计算当前节点的x坐标
centerPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
#绘制当前节点
plotMidText(centerPt,parentPt,nodeTxt)
plotNode(firstFeat,centerPt,parentPt,decisionNode)
secondDict = tree[firstFeat]
#计算绘制深度
plotTree.yOff -= 1.0/plotTree.totalD
for key in secondDict.keys():
#如果当前节点的子节点不是叶子节点,则递归
if type(secondDict[key]).__name__ == 'dict':
plotTree(secondDict[key],centerPt,str(key))
#如果当前节点的子节点是叶子节点,则绘制该叶节点
else:
#plotTree.xOff在绘制叶节点坐标的时候才会发生改变
plotTree.xOff += 1.0/plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff,plotTree.yOff),centerPt,leafNode)
plotMidText((plotTree.xOff,plotTree.yOff),centerPt,str(key))
plotTree.yOff += 1.0/plotTree.totalD
根据决策树的叶子节点数和深度来平均切分画布,并且x、y轴的总长度为1,如下图所示:
原谅我的画图水平
3.2.1在createPlot函数中:
plotTree.totalW :表示叶子节点个数,因此上图中每两个叶子节点之间的距离为:1/plotTree.totalW;
plotTree.totalD :表示决策树深度;
plotTree.xOff:表示最近绘制的叶子节点x坐标,在绘制叶节点时其值才会更新;其初始值为图中虚线圆圈位置,这样在以后确定叶子节点位置时可以直接加整数倍的1/plotTree.totalW;
plotTree.yOff = 1.0 :表示当前绘制的深度,其值初始化为根节点y坐标。
3.2.2在plotTree函数中:
#计算当前节点的x坐标
centerPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
在确定当前节点x坐标时,只需确定当前节点下的叶节点个数,其x坐标即为叶节点所占距离的一半:float(numLeafs)/2.0/plotTree.totalW;
由于plotTree.xOff初始值为-0.5/plotTree.totalW,因此当前节点x坐标还需加上0.5/plotTree.totalW。
4.决策树可视化
#决策树节点文本可以以中文显示
import matplotlib as mpl
mpl.rcParams["font.sans-serif"] = ["Microsoft YaHei"]
mpl.rcParams['axes.unicode_minus'] = False
#创建数据集
def createDataSet():
dataSet = [['有钱','好','漂亮','见 '],
['有钱','差','漂亮','见 '],
['有钱','差','丑','不见'],
['没钱','好','丑','见 '],
['没钱','差','漂亮','不见'],
['没钱','好','漂亮','见 ']]
labels = ['富有','人品','外貌']
return dataSet, labels
dataSet, dataLabels = createDataSet()
#创建决策树
myTree = createDecideTree(dataSet,dataLabels)
print(myTree)
#绘制决策树
createPlot(myTree)
字典形式表示决策树:
{'人品': {'好': '见 ', '差': {'富有': {'没钱': '不见', '有钱': {'外貌': {'漂亮': '见 ', '丑': '不见'}}}}}}
树状图形决策树:
5.使用决策树算法
在已知对方有钱,人品差,长得漂亮后,利用前面训练的决策树做出决策,见或不见?!
#使用决策树进行分类
def classify(tree,feat,featValue):
firstFeat = list(tree.keys())[0]
secondDict = tree[firstFeat]
featIndex = feat.index(firstFeat)
for key in secondDict.keys():
if featValue[featIndex] == key:
if type(secondDict[key]).__name__ == 'dict':
classLabel = classify(secondDict[key],feat,featValue)
else:
classLabel = secondDict[key]
return classLabel
feat = ['富有','人品','外貌']
featValue = ['有钱','差','漂亮']
print(classify(myTree,feat,featValue))
决策结果:
见
6.存储决策树模型
构建决策树消耗的时间还是很可观的,尤其在数据量大的时候,因此,当训练完决策树模型后有必要将其保存下来,以便后续使用。使用Python模块的pickle序列化对象可以解决这个问题,序列化对象可以在磁盘上保存对象,在需要时将其读取出来。
#保存决策树模型
import pickle
def saveTree(tree, fileName):
fw = open(fileName,'wb')
pickle.dump(tree, fw)
fw.close()
#加载决策树模型
def loadTree(fileName):
fr = open(fileName,'rb')
return pickle.load(fr)
saveTree((myTree),'myTree.txt')
print(loadTree('myTree.txt'))
{'人品': {'差': {'富有': {'有钱': {'外貌': {'丑': '不见', '漂亮': '见 '}}, '没钱': '不见'}}, '好': '见 '}}
Coding Your Ambition!