文章目录
决策树
简单了解决策树,如下图,正方形代表判断模块(decision block),椭圆形代表终止模块(terminating block),从判断模块引出的左右箭头称作分支(branch),它可以到达另一个判断模块或终止模块。
上节学习的k-近邻算法可以完成很多分类任务,但是最大的缺点是无法给出数据的内在含义,决策树的主要优势就在于数据形式非常容易理解。
决策树很多任务都是为了数据中所蕴含的知识信息,因此决策树可以使用不熟悉的数据集合,并从中提取一系列规则,机器学习算法最终使用这些机器从数据集中创造的规则。
1.决策树的构造
优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据。
缺点:可能会产生过度匹配问题。
适用数据类型:数值型和标称型
创建分支的伪代码函数createBranch()如下:
检测数据集中的每个子项是否属于同一分类
If so return 类标签;
Else
寻找划分数据集的最好特征
划分数据集
创建分支节点
for 每个划分的子集
调用函数createBranch并增加返回结果到分支节点中
return 分支节点
上面的伪代码createBranch是一个递归函数,在倒数第二行直接调用了自己。
决策树的一般流程:
- 收集数据:可以使用任何方法;
- 准备数据:数构造算法只适用于标称型数据,因此数值型数据必须离散化;
- 分析数据:可以使用任何方法,构造树完成之后,我们应该检查图形是否符合预期
- 训练算法:构造树的数据结构;
- 测试算法:使用经验树计算错误率;
- 使用算法:此步骤可以适用于任何监督学习算法,而适用决策树可以更好地理解数据的内在含义。
下表的数据包含5个海洋动物,特征包括:不浮出水面是否可以生存,以及是否有脚蹼。
可以将这些动物分为两类:鱼类和非鱼类
编号 | 不浮出水面是否可以生存 | 是否有脚蹼 | 属于鱼类 |
---|---|---|---|
1 | 是 | 是 | 是 |
2 | 是 | 是 | 是 |
3 | 是 | 否 | 否 |
4 | 否 | 是 | 否 |
5 | 否 | 是 | 否 |
1.1信息增益
划分数据集的大原则是:将无序的数据变得更加有序。
多种方法划分数据集,各有优缺点,可以通过计算信息增益的方式评判,而集合信息的度量方式称为香农熵或熵。
熵定义为信息的期望值,如果待分类的事务可能划分在多个分类之中,则符号xi的信息定义为:(其中p(xi)是选择该分类的概率)
l
(
x
i
)
=
−
l
o
g
2
p
(
x
i
)
l(x_i)=-log_2p(x_i)
l(xi)=−log2p(xi)
为了计算熵,需要计算所有类别所有可能值包含的信息期望值,通过下方公式:(其中n是分类的数目)
H
=
−
∑
i
=
1
n
p
(
x
i
)
l
o
g
2
p
(
x
i
)
H=-\sum_{i=1}^np(x_i)log_2p(x_i)
H=−i=1∑np(xi)log2p(xi)
下面给出Python计算信息熵的代码(在Pycharm中新建DT.py):
from math import log
def calcShannonEnt(dataSet):
numEntries = len(dataSet)
# 为所有可能分类创建字典
labelCounts = {}
for featVec in dataSet:
currentLabel = featVec[-1]
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
shannonEnt = 0.0
# 以2为底求对数
for key in labelCounts:
prob = float(labelCounts[key])/numEntries
shannonEnt -= prob * log(prob,2)
return shannonEnt
接下来可以写入数据函数createDataSet():
def createDataSet():
# 上表中的数据集
dataSet = [[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no']]
labels = ['no surfacing','flippers']
return dataSet, labels
然后再Pycharm中写main函数:
if __name__ == '__main__':
myDat, labels = createDataSet()
print(myDat)
print(calcShannonEnt(myDat))

熵越高,则混合的数据也越多,增加第三个名为maybe的分类,测试熵的变化:
# 在main函数中增加代码
myDat[0][-1] = 'maybe'
print(myDat)
print(calcShannonEnt(myDat))

1.2划分数据集
目前,已经度量数据集的无序程度(测量信息熵),接下来划分数据集:将对每个特征划分数据集的结果计算一次信息熵,然后判断按照哪个特征划分数据集是最好的划分方式。
在DT.py中添加splitDataSet()函数:
# 三个形参:待划分的数据集、划分数据集的特征和特征返回值
def splitDataSet(dataSet, axis, value):
retDataSet = [] # 创建列表
# 遍历数据集中每个元素,符合要求的值添加到列表中
for featVec in dataSet:
# 用if语句抽取出符合特征的数据
if featVec[axis] == value:
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet
然后在main函数中添加:
print(splitDataSet(myDat, 0, 1))
print(splitDataSet(myDat, 0, 0))

接下来将遍历整个数据集,循环计算香农熵和splitDataSet()函数,找到最好的特征划分方式。在DT.py中添加chooseBestFeatureToSplit()函数:
# 满足条件:数据必须是一种由列表元素组成的列表,而且所有的列表元素都具有相同的数据长度
# 条件2:数据的最后一列或每个实例的最后一个元素是当前实例的类别标签
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1
baseEntropy = calcShannonEnt(dataSet)
bestInfoGain = 0.0; bestFeature = -1
for i in range(numFeatures):
featList = [example[i] for example in dataSet]
# 创建唯一的分类标签列表
uniqueVals = set(featList)
newEntropy = 0.0
for value in uniqueVals:
# 计算每种划分方式的信息熵
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet)/float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy
if(infoGain > bestInfoGain):
# 计算最好的信息增益
bestInfoGain = infoGain
bestFeature = i
return bestFeature
然后在main函数中添加:
print(chooseBestFeatureToSplit(myDat))

从运行结果看:第0个特征是最好的用于划分数据集的特征,即可以按”不浮出水面是否可以生存“。
1.3递归构建决策树
目前已经学习了从数据集构造决策树算法所需要的子功能模块,其工作原理:得到原始数据集,然后基于最好的属性值划分数据集,由于特征值可能多于两个,因此可能存在两个分支的数据集划分。第一次划分之后,数据将被向下传递到树分支的下一个节点,在这个节点上,可以再次划分数据,很符合递归原则。
递归结束的条件是:程序遍历完所哟划分数据集的属性,或者每个分支下的所有实例都具有相同的分类。如果所有实例具有相同的分类,则得到一个叶子节点或者终止块。任何到达叶子节点的数据必然属于叶子节点的分类。

如果数据集已经处理了所有属性,但是类标签依然不是唯一的,此时需要决定如何定义该叶子节点,在这种情况下,通常会采用多数表决的方法决定该叶子节点的分类。
代码实现:
先在DT.py顶部增加一行代码:import operator, 然后添加majorityCnt()函数:
def majorityCnt(classList):
classCount = {} # 创建列表
# 字典对象存储了每个标签出现的频率
for vote in classList:
if vote not in classCount.keys():classCount[vote] = 0
classCount[vote] += 1
# 利用operator操作键值排序字典,并返回出现次数最多的分类名称
sortedClassCount = sorted(classCount.items(),
key = operator.itemgetter(1),reverse = True)
return sortedClassCount[0][0]
接下来继续创建树的函数:
# 输入两个参数:数据集和标签列表。标签列表包含了数据集中所有特征的标签,算法本身并不需要这个变量,但为了给出数据明确的含义,也输入这个参数。
def createTree(dataSet, labels):
# 创建了名位classList的列表变量
classList = [example[-1] for example in dataSet]
# 类别完全相同则停止继续划分
if classList.count(classList[0]) == len(classList):
return classList[0]
# 遍历完所有特征时返回次数最多的
if len(dataSet[0]) == 1:
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLabel = labels[bestFeat]
# 创建字典存储树的信息,这对于其后绘制树形图非常重要
myTree = {bestFeatLabel:{}}
del(labels[bestFeat])
# 得到列表包含的所有属性值
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
# 遍历当前选择特征包含的所有属性值,递归调用,得到的返回值被插入到字典变量myTree中,函数终止时,字典中将会嵌套很多代表叶子节点信息的字典数据。
for value in uniqueVals:
# 复制类标签,并存储在新列表subLabels
subLabels = labels[:]
myTree[bestFeatLabel][value] = createTree(splitDataSet\
(dataSet, bestFeat,value),subLabels)
return myTree
测试代码:
在main函数中添加:
print(createTree(myDat,labels))

变量myTree包含了很多代表树结构信息的嵌套字典,第一个关键字 no surfacing 是第一个划分数据集的特征名称,该关键字的值也是另一个数据字典。第二个关键字是 no surfacing 特征划分的数据集,这些关键字的值是 no surfacing 节点的子节点。这些值可能是类标签,有可能是另一个数据字典。如果值是类标签,则该子节点是叶子节点;如果值是另一个数据字典,则子节点是一个判断节点,这个格式结构不断重复构成整棵树。
2.在Python中使用Matplotlib注解绘制树形图
上节已经正确地从数据集中构造树,接下来绘制图形,方便正确理解数据信息。
主要就是绘制如下图的决策树:

2.1Matplotlib注解
Matplotlib提供了一个注解工具annotations,非常有用,它可以在数据图形上添加文本注解。由于数据上面直接存在文本描述非常丑陋,因此工具内嵌支持带箭头的划线工具。恰好可以使用该注解功能绘制树形图。
使用文本注解绘制树节点的实现:
在pycharm中新建DTPlotter.py, 写入下列代码:
import matplotlib.pyplot as plt
# 全局设置中文字体,为了输出中文
from pylab import mpl
mpl.rcParams['font.sans-serif'] = ['SimHei'] # 微软雅黑
# 定义文本框和箭头格式
decisionNone = dict(boxstyle = "sawtooth", fc = "0.8")
leafNode = dict(boxstyle = "round4", fc = "0.8")
arrow_args = dict(arrowstyle = "<-")
# 绘制带箭头的注解
def plotNode(nodeTxt, centerrPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy = parentPt,
xycoords = 'axes fraction',
xytext = centerrPt,
textcoords = 'axes fraction',
va = "center",
ha = "center",
bbox = nodeType,
arrowprops = arrow_args)
def createPlot():
# 创建一个新图形
fig = plt.figure(1, facecolor = 'white')
fig.clf() # 清空绘图区
createPlot.ax1 = plt.subplot(111, frameon = False)
# 绘制两个代表不同类型的树节点
plotNode('决策节点', (0.5, 0.1), (0.1, 0.5), decisionNone)
plotNode('叶节点', (0.8, 0.1), (0.3, 0.8), leafNode)
plt.show() # 输出显示图形
if __name__ == '__main__':
print(createPlot()) # 使用文本注解绘制树节点
在pycharm中ctrl+shift+F10,运行DTPlotter.py,输出图形:

2.2构建注解树
绘制一棵完整的树需要一些技巧,虽然有了x, y坐标,但如何放置所有的树节点却是个问题。因此必须知道有多少个叶节点,以便可以正确确定x轴的长度;还需要知道树有多少层,以便可以正确确定y轴的高度。故继续定义两个函数getNumLeafs()和getTreeDepth(),来获取叶节点的数目和树的层数。
继续在DTPlotter.py, 写入下列代码:
def getNumLeafs(myTree):
numLeafs = 0
# 第一个关键字是第一次划分数据集的类别标签
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
# 测试节点的数据类型是否为字典
# 如果子节点是字典类型,则该节点也是一个判断节点,需要递归调用getNunLeafs()函数
# 遍历整棵树,累计叶子节点的个数,并返回该数值
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
numLeafs += getNumLeafs(secondDict[key])
else:
numLeafs += 1
return numLeafs
# 和getNumLeafs()函数有点相似
def getTreeDepth(myTree):
maxDepth = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
thisDepth = 1 + getTreeDepth(secondDict[key])
else:
thisDepth = 1
if thisDepth > maxDepth: maxDepth = thisDepth
return maxDepth
# 预先存储树信息,避免每次测试代码都从数据中创建树的麻烦
def retrieveTree(i):
listOfTrees = [ {'no surfacing': {0: 'no', 1: {'flippers':
{0: 'no', 1: 'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers':
{0: {'head':{0:'no',1:'yes'}}, 1: 'no'}}}} ]
return listOfTrees[i]
而在main函数中添加:( 同时注释掉print(createPlot()) )
if __name__ == '__main__':
# print(createPlot()) # 使用文本注解绘制树节点
print(retrieveTree(1))
myTree = retrieveTree(0)
print(getNumLeafs(myTree))
print(getTreeDepth(myTree))
输出结果:

函数retrieveTree()主要用于测试,返回预定义的树结构,调用getNumLeafs()函数返回3,等于树0的叶子节点数;调用getTreeDepths()函数也能够正确返回数的层数。
但输出没有绘制一棵完整的树,尽管已经定义了createPlot()函数,但需要更新这部分代码,把树信息传进去这个函数。
更新createPlot()函数,并新增plotMidText()函数和plotTree()函数:
# 作用是计算tree的中间位置
# cntrpt起始位置,parentpt终止位置,txtstrin文本标签信息
def plotMidText(cntrPt, parentPt, txtString):
# 找到x和y的中间位置
xMid = (parentPt[0] - cntrPt[0]/2.0 + cntrPt[0])
yMid = (parentPt[1] - cntrPt[1]/2.0 + cntrPt[1])
createPlot.ax1.text(xMid, yMid, txtString)
def plotTree(myTree, parentPt, nodeTxt):
numLeafs = getNumLeafs(myTree)
depth = getTreeDepth(myTree)
firstStr = list(myTree.keys())[0]
# 计算子节点的坐标
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
plotMidText(cntrPt, parentPt, nodeTxt) # 绘制线上的文字
plotNode(firstStr, cntrPt, parentPt, decisionNone) # 绘制节点
secondDict = myTree[firstStr]
# 每绘制一次图,将y的坐标减少1.0/plottree.totald,间接保证y坐标上深度的
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
plotTree(secondDict[key], cntrPt, str(key))
else:
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
def createPlot(inTree):
# 创建一个新图形
fig = plt.figure(1, facecolor = 'white')
fig.clf() # 清空绘图区
axprops = dict(xticks = [], yticks = [])
# subplot为定义了一个绘图,111表示figure中的图有1行1列,即1个,最后的1代表第一个图
# frameon表示是否绘制坐标轴矩形
createPlot.ax1 = plt.subplot(111, frameon = False, **axprops)
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5/plotTree.totalW
plotTree.yOff = 1.0
plotTree(inTree, (0.5, 1.0), '')
plt.show()
# xOff和yOff用来记录当前要画的叶子节点的位置
# cntrPt记录当前要画的树的树根的结点位置
# x轴和y轴的范围都是[0.0~1.0],输出的图形是按比例绘制树形图,不担心变形,不建议用像素为单位绘制图形
接下来更新main函数:
if __name__ == '__main__':
# print(createPlot()) # 使用文本注解绘制树节点
# print(retrieveTree(1))
# myTree = retrieveTree(0)
# print(getNumLeafs(myTree))
# print(getTreeDepth(myTree))
myTree = retrieveTree(1)
print(createPlot(myTree))
输出结果:

3.测试和存储分类器
接下来在真实数据上使用决策树分类算法,验证决策树是否可以正确预测患者应该使用的隐形眼镜类型。
3.1测试算法:使用决策树执行分类
在执行数据分类时,需要决策树以及用于构造树的标签向量。然后,程序比较测试数据与决策树上的数值,递归执行该过程直到进入叶子节点;最后将测试数据定义为叶子节点所属的类型。故继续在DT.py中添加classify()函数:
# 存储带有特征的数据会面临一个问题:程序无法确定特征在数据集中的位置
# 使用index方法查找当前列表中第一个匹配firstStr变量的元素
def classify(inputTree, featLabels, testVec):
firstStr = list(inputTree.keys())[0]
secondDict = inputTree[firstStr]
# 将标签字符串转换为索引
featIndex = featLabels.index(firstStr)
for key in secondDict.keys():
if testVec[featIndex] == key:
if type(secondDict[key]).__name__=='dict':
classLabel = classify(secondDict[key], featLabels, testVec)
else:
classLabel = secondDict[key]
return classLabel # 返回遍历后当前节点的分类标签
接下来在main函数中调用函数:
但是发现了尴尬的问题,我之前在DT.py中写了main函数,在DTPlotter.py中也写了main函数,但测试算法时,我需要调用DT.py中的函数,也需要调用DTPlotter.py的函数,与其把DT.py导入到DTPlotter.py中,我在想还不如干脆想c++一样单独建一个main函数,导入DT.py和DTPlotter.py(类比c++导入头文件),试了果然可以。(果然编程能力也是在实战中进步的。我还真是个憨憨,尴尬)
接下来单独新建main.py:
import DT
import DTPlotter
if __name__ == '__main__':
myDat, labels = DT.createDataSet()
print(labels)
myTree = DTPlotter.retrieveTree(0)
print(myTree)
print(DT.classify(myTree, labels, [1,0]))
print(DT.classify(myTree, labels, [1,1]))
运行,输出结果:
输出结果与上节输出结果比较:第一节点名为:no surfacing,它有两个子节点:一个是名字为0的叶子节点,类标签为no;另一个是名为flippers的判断节点,此处进入递归调用,flippers节点有两个子节点。
3.2使用算法:决策树的存储
构造决策树是很耗时的任务,即使处理很小的数据集,如前面的样本数据,也要花费几秒的时间,如果数据集很大,将会耗费很多计算时间。所以为了节省计算时间,最好能够在每次执行分类时调用已经构造好的决策树。为了解决这个问题,需要使用Python模块pickle序列化对象,序列化对象可以在磁盘上保存对象,并在需要的时候读取出来,任何时候都可以执行序列化操作,字典对象也不列外。
在DT.py中添加:
# 使用pickle模块存储决策树
def storeTree(inputTree, filename):
import pickle
fw = open(filename, 'wb') # python2是 w,python3是 wb
pickle.dump(inputTree, fw)
fw.close()
def grabTree(filename):
import pickle
fr = open(filename,'rb') # 对应的打开时要用 rb
return pickle.load(fr)
然后在main.py中添加:
DT.storeTree(myTree, 'classifierStorage.txt')
DT.grabTree('classifierStorage.txt')
输出结果:
而且生成了classifierStorage.txt文件,这样,就不用每次对数据分类时重新学习一遍,这也是决策树的优点之一。
4.示例:使用决策树预测隐形眼镜类型
案例:眼科医生是如何判断患者需要佩戴的镜片类型;一旦理解了决策树的工作原理,也可以帮助人们判断需要佩戴的镜片类型。
示例:使用决策树预测隐形眼镜类型
- 收集数据:提供的文本文件
- 准备数据:解析tab键分隔的数据行
- 分析数据:快速检查数据,确保正确地解析数据内容,使用createPlot()函数绘制最终的树形图。
- 训练算法:使用createTree()函数
- 测试算法:编写测试函数验证决策树可以正确分类给定的数据实例
- 使用算法:存储树的数据结构,以便下次使用时无需重新构建树
隐形眼镜数据集是非常著名的数据集,它包含很多患者眼部状况的观察条件以及医生推荐的隐形眼镜类型。隐形眼镜类型包括硬材质、软材质以及不适合佩戴隐形眼镜。(数据来源于:UCI机器学习存储库 https://archive.ics.uci.edu/ml/index.php )
继续在main.py中添加代码:
fr=open('lenses.txt')
lenses=[inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels=['age','prescript','astigmatic','tearRate']
lensesTree = DT.createTree(lenses,lensesLabels)
print(lensesTree)
DTPlotter.createPlot(lensesTree)
输出结果:(树形图构造好像不对,尴尬)
有空仔细查看哪个函数写错了。
这次使用的算法称为ID3,它是一个好的算法但并不完美,之后会学另一个决策树构造算法CART。ID3算法无法直接处理数值型数据,尽管可以通过量化的方法将数值型数据转化为标称型数值,但如果存在太多的特征划分,ID3算法仍然会面临其它问题。
5.总结
决策树分类器就像带有终止块的流程图,终止块表示分类结果。开始处理数据集时,开始处理数据集时,需要先测量集合中数据的不一致性,也就是熵,然后寻找最优方案划分数据集,直到数据集中的所有数据属于同一分类。ID3算法可以用于划分标称型数据集。构建决策树时,通常采用递归的方法将数据集转化为决策树。一般不构造新的数据结构,而是使用Python语言内嵌的数据结构字典存储树节点信息。
使用Matplotlib的注解功能,可以将存储的树结构转化为容易理解的图形。Python语言的pickle模块可用于存储决策树的结构。隐形眼镜的例子表明决策树可能会产生过多的数据集划分,从而产生过度匹配数据集的问题。可以通过裁剪决策树,合并相邻的无法产生大量信息增益的叶节点,消除过度匹配问题。
还有其它决策树的构造算法,最流行的是C4.5和CART。
参考文献:
- 《机器学习实战》-k近邻算法
革命尚未成功,同志仍需努力!