1 决策树的构造
优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据
缺点:可能产生过度匹配问题
适用数据类型:数值型和标称型
解决的首要问题:当前数据集上哪个特征在划分数据分类时起决定性作用
创建分支的伪代码函数createBranch():
检测数据集中的每个子项是否属于同一分类;
if so return 类标签;
Else
寻找划分数据集的最好特征
划分数据集
创建分支节点
for 每个划分的子集
调用函数createBranch并增加返回结果到分支节点中
return 分支节点
决策树的一般流程
- 收集数据
- 准备数据:只适用于标称型数据,数值型数据必须离散化
- 分析数据:检查图形是否符合预期
- 训练算法:构造树的数据结构
- 测试算法:使用经验树计算错误率
- 使用算法:可以适用于任何监督学习算法,使用决策树可以更好地理解数据的内在含义
1.1 信息增益
信息增益:划分数据集之后信息发生的变化
通过计算每个特征值划分数据集获得的信息增益,获得信息增益最高的特征就是最好的选择
符号xi的信息定义:
I(xi)=-log2p(xi)
计算熵:
H=-∑i=1n\sum_{i=1}^{n}∑i=1np(xi)log2p(xi)
定义calcShannonEnt函数计算给定数据集的香农熵:
def calcShannonEnt(dataSet):
numEntries = len(dataSet) # 计算数据集中实例的总数
labelCounts = {} # 新建字典,记录每个分类下的数据个数
for featVec in dataSet: # 为所有可能的分类创建字典
currentLabel = featVec[-1] # 将dataSet每一个元素的最后一个元素选择出来
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0 # 当没有该键时,使用字典的自动添加添加值为0的项
labelCounts[currentLabel] += 1
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key]) / numEntries # 取概率
shannonEnt -= prob * log(prob, 2) # log(x,2)表示以2为底求x的对数
return shannonEnt
熵越高,则混合的数据也越多
得到熵后,就可以按照获得最大信息增益的方法划分数据集
注:另一个度量集合无序程度的方法是基尼不纯度:从一个数据集中随机选取子项,度量其被错误分类到其他分组里的概率
1.2 划分数据集
划分数据集:
def splitDataSet(dataSet, axis, value): # 按照给定特征划分数据集。dataSet是待划分的数据集,axis是划分数据集的特征,value是需要返回的特征值
retDataSet = [] # 创建新的list对象(为了不修改原始数据集,数据集这个列表的各个元素也是列表)
for featVec in dataSet: # 将符合特征的数据抽取出来
if featVec[axis] == value:
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis + 1:]) # extend用于在列表末尾一次性追加另一个序列的多个值
retDataSet.append(reducedFeatVec) # append用于在列表末尾添加新的对象
return retDataSet
选择最好的数据集划分方式:
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1 # 计算所有特征数
baseEntropy = calcShannonEnt(dataSet) # 计算原始香农熵,保存最初的无序度量值,用于与划分完之后的数据集计算的熵值进行比较
bestInfoGain = 0.0
bestFeature = -1
for i in range(numFeatures):
featList = [example[i] for example in dataSet] # 遍历所有样本的第i个特征的取值情况(使用列表推导创建新的列表)
uniqueVals = set(featList) # 第i条特征的取值(去重) set函数用于创建一个无序不重复元素集,可进行关系测试,删除重复数据,还可计算交集、差集、并集等
newEntropy = 0.0
for value in uniqueVals: # 计算每种划分方式的信息熵。对每个特征划分一次数据集,然后计算数据集的新熵值,并对所有唯一特征值得到的熵求和
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet) / float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy
if (infoGain > bestInfoGain):
bestInfoGain = infoGain
bestFeature = i
return bestFeature
trees.chooseBestFeatureToSplit(myDat)
输出:0
1.3 递归构建决策树
需要注意的点:由于特征值可能多于两个,因此可能存在大于两个分支的数据集划分
递归结束的条件:程序遍历完所有划分数据集的属性,或者每个分支下的所有实例都具有相同的分类,则得到一个叶子结点或终止块
需要考虑的特殊情况:当数据集已经处理了所有的属性,但类标签依然不是唯一的,此时需要决定如何定义该叶子节点。通常采用多数表决的方法决定该叶子结点的分类。
def majorityCnt(classList):
classCount = {}#创建键值为classList中唯一值的数据字典,存储classList中每个类标签出现的频率,然后利用operator操作键值排序字典,返回出现次数最多的分类名称
for vote in classList:
if vote not in classCount.keys(): classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
def createTree(dataSet, labels): # 参数为数据集和标签列表 标签列表包含了数据集中所有特征的标签,为了给出数据明确的含义将其作为输入参数提供
classList = [example[-1] for example in dataSet] # 取标签值
# 第一个停止条件:所有的类标签完全相同,则直接返回该类标签
if classList.count(classList[0]) == len(classList): # count函数用于统计某个元素在列表中出现的次数
return classList[0] # 列表完全相同则停止继续划分
# 第二个停止条件:使用完了所有特征,仍然不能将数据集划分成仅包含唯一类别的分组,因此挑选出现次数最多的类别作为返回值
if len(dataSet[0]) == 1:
return majorityCnt(classList) # 遍历完所有特征时返回出现次数最多的
bestFeat = chooseBestFeatureToSplit(dataSet) # 当前数据集选取的最好特征
bestFeatLabel = labels[bestFeat]
myTree = {bestFeatLabel: {}} # 用字典存储树的结构
del (labels[bestFeat]) # 删除取出过的标签,避免重复计算
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues) # 得到列表包含的所有属性值,利用set去重
for value in uniqueVals:
subLabels = labels[:] # 复制所有的子标签,因为是引用类型,以避免改变原始标签数据
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels) # 递归构建树
return myTree
myDat, labels = trees.createDataSet()
myTree = trees.createTree(myDat, labels)
print(myTree)
可以看出myTree包含了很多树结构信息的嵌套字典
第一个关键字是第一个划分数据集的特征名称,该关键字的值也是另一个数据字典;第二个关键字是no surfacing特征划分的数据集,这些关键字的值是no surfacing节点的子节点:值是类标签时,说明该子节点是叶子节点;值是另一个数据字典时,则子节点是一个判断节点
2 使用matplotlib注解绘制数图形
matplotlib提供的注解工具:在数据图形上添加文本注释
2.1 使用文本注解绘制树节点
treePlotter.py:
import matplotlib.pyplot as plt
#定义树节点格式的常量
decisionNode = dict(boxstyle="sawtooth", fc="0.8") # 决策节点的属性。boxstyle为文本框的类型,sawtooth为锯齿形,fc为边框线粗细
leafNode = dict(boxstyle="round4", fc="0.8") # 决策树叶子结点的属性
arrow_args = dict(arrowstyle="<-") # 剪头的属性
def plotNode(nodeTxt, centerPt, parentPt, nodeType): # 执行绘图功能。绘图区域由全局变量createPlot.ax1定义
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
# plt.annotate(str, xy=data_point_position, xytext=annotate_position,
#va="center", ha="center", xycoords="axes fraction",
# textcoords="axes fraction", bbox=annotate_box_type, arrowprops=arrow_style)
# str是给数据点添加注释的内容,支持输入一个字符串
# xy=是要添加注释的数据点的位置
# xytext=是注释内容的位置
# bbox=是注释框的风格和颜色深度,fc越小,注释框的颜色越深,支持输入一个字典
# va="center", ha="center"表示注释的坐标以注释框的正中心为准,而不是注释框的左下角(v代表垂直方向,h代表水平方向)
# xycoords和textcoords可以指定数据点的坐标系和注释内容的坐标系,通常只需指定xycoords即可,textcoords默认和xycoords相同
# arrowprops可以指定箭头的风格支持,输入一个字典
# plt.annotate()的详细参数可用__doc__查看,如:print(plt.annotate.__doc__)
def createPlot(): # 代码核心。首先创建一个新图形并清空绘图区,然后在绘图区上绘制两个代表不同类型的树节点,后面用这两个结点绘制树图形
fig = plt.figure(1, facecolor='white') # 1表示图形编好/名称
fig.clf() # 表示清除所有轴
createPlot.ax1 = plt.subplot(111, frameon=False) # 为对象添加属性 frameon=true时图示被绘制在一个patch实体上;=false则图示直接被绘制在图形上
plotNode('a dicision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
plt.show()
2.2 构造注解树
def getNumLeafs(myTree): # 获取叶节点的数目
numLeafs = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys(): # 测试节点的数据类型是否为字典
if type(secondDict[key]).__name__ == 'dict':
numLeafs += getNumLeafs(secondDict[key])
else:
numLeafs += 1
return numLeafs
def getTreeDepth(myTree): # 获取树的层数
maxDepth = 0
firstStr = list(myTree.keys())[0] # keys()返回一个字典的所有键
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
thisDepth = 1 + getTreeDepth(secondDict[key])
else:
thisDepth = 1
if thisDepth > maxDepth:
maxDepth = thisDepth
return maxDepth
def retrieveTree(i): # 输出预先存储的树信息,避免每次测试代码时都要从数据中创建树的麻烦
listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]
return listOfTrees[i]
Terminal:
print(treePlotter.retrieveTree(1))
myTree = treePlotter.retrieveTree(0)
print(treePlotter.getNumLeafs(myTree))
print(treePlotter.getTreeDepth(myTree))
def plotMidText(cntrPt, parentPt, txtString): # 用于计算父节点和子节点的中间位置,并在此添加简单的文本标签信息
xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString)
def plotTree(myTree, parentPt, nodeTxt): # 完成绘制树形的大多数工作
numLeafs = getNumLeafs(myTree)
depth = getTreeDepth(myTree)
firstStr = list(myTree.keys())[0]
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
# 按照图形比例绘图,根据叶子结点的数目划分图形的宽度,从而计算得到当前结点的中心位置
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD # 按比例减少全局变量plotTree.yOff
# 由于是自顶向下绘制图形,因此需要依次递减y坐标值
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict': # 当节点不是叶子节点时递归调用plotTree
plotTree(secondDict[key], cntrPt, str(key))
else: # 当节点时叶子节点时在图形上画出叶子节点
plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD # 在绘制了所有子节点之后,增加全局变量Y的偏移
def createPlot(inTree): # 创建绘图区,计算树形图的全局尺寸,并递归调用函数plotTree()
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) # .ax1相当于对函数对象添加属性
plotTree.totalW = float(getNumLeafs(inTree)) # 全局变量plotTree.totalW存储树的宽度
plotTree.totalD = float(getTreeDepth(inTree)) # 全局变量plotTree.totalD存储树的深度
plotTree.xOff = -0.5 / plotTree.totalW # 全局变量plotTree.xOff和plotTree.yOff用于追踪已经绘制的结点位置,以及放置下一个节点的恰当位置
plotTree.yOff = 1.0;
plotTree(inTree, (0.5, 1.0), '')
plt.show()
myTree = treePlotter.retrieveTree(0)
treePlotter.createPlot(myTree)
myTree = treePlotter.retrieveTree(0)
myTree[‘no surfacing’][3] = ‘maybe’
treePlotter.createPlot(myTree)
3 测试和存储分类器
3.1 使用决策树执行分类
def classify(inputTree, featLabels, testVec):
# 在存储带有特征的数据时,程序无法确定特征在数据集中的位置,因此使用特征标签列表解决该问题。使用index方法查找当前列表中第一个匹配firstStr变量的元素,然后代码递归遍历整棵树,比较testVec变量中的值域树节点的值,如果到达叶子节点则返回节点的分类标签
firstStr = list(inputTree.keys())[0]
secondDist = inputTree[firstStr]
featIndex = featLabels.index(firstStr) # 找到根特征在featLabels的位置,将标签字符串转换为索引
for key in secondDist.keys():
if testVec[featIndex] == key:
if type(secondDist[key]).__name__ == 'dict':
classLabel = classify(secondDist[key], featLabels, testVec)
else:
classLabel = secondDist[key]
return classLabel
myDat, labels = trees.createDataSet()
print(labels)
myTree = treePlotter.retrieveTree(0)
print(myTree)
print(trees.classify(myTree, labels, [1, 0]))
print(trees.classify(myTree, labels, [1, 1]))
3.2 使用算法:决策树的存储
每次使用分类器时重新构造决策树是很耗时的任务。为了解决该问题需要使用python模块pickle序列化对象,可以在磁盘上保存对象,并在需要时读取出来。任何对象都可以执行序列化操作。
def storeTree(inputTree, filename):
import pickle
fw = open(filename, 'wb+')
pickle.dump(inputTree, fw) # pickle.dump(obj, file, [,protocol])将对象obj保存到file中.proctol为序列化使用的协议版本
fw.close()
def grabTree(filename):
import pickle
fr = open(filename,'rb+')
return pickle.load(fr) # 用于反序列化对象,将文件中的数据解析为一个python对象
myDat, labels = trees.createDataSet()
print(labels)
myTree = treePlotter.retrieveTree(0)
trees.storeTree(myTree, ‘classifierStorage.txt’)
trees.grabTree(‘classifierStorage.txt’)
4 示例:使用决策树预测隐形眼镜类型
fr = open('lenses.txt')
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = trees.createTree(lenses, lensesLabels)
print(lensesTree)
treePlotter.createPlot(lensesTree)
沿着决策树的不同分支即可得到不同患者需要佩戴的隐形眼镜类型
附:全部代码
treePlotter.py:
import matplotlib.pyplot as plt
# 定义树节点格式的常量
decisionNode = dict(boxstyle="sawtooth", fc="0.8") # 决策节点的属性。boxstyle为文本框的类型,sawtooth为锯齿形,fc为边框线粗细
leafNode = dict(boxstyle="round4", fc="0.8") # 决策树叶子结点的属性
arrow_args = dict(arrowstyle="<-") # 剪头的属性
def plotNode(nodeTxt, centerPt, parentPt, nodeType): # 执行绘图功能。绘图区域由全局变量createPlot.ax1定义
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
# plt.annotate(str, xy=data_point_position, xytext=annotate_position,
# va="center", ha="center", xycoords="axes fraction",
# textcoords="axes fraction", bbox=annotate_box_type, arrowprops=arrow_style)
# str是给数据点添加注释的内容,支持输入一个字符串
# xy=是要添加注释的数据点的位置
# xytext=是注释内容的位置
# bbox=是注释框的风格和颜色深度,fc越小,注释框的颜色越深,支持输入一个字典
# va="center", ha="center"表示注释的坐标以注释框的正中心为准,而不是注释框的左下角(v代表垂直方向,h代表水平方向)
# xycoords和textcoords可以指定数据点的坐标系和注释内容的坐标系,通常只需指定xycoords即可,textcoords默认和xycoords相同
# arrowprops可以指定箭头的风格支持,输入一个字典
# plt.annotate()的详细参数可用__doc__查看,如:print(plt.annotate.__doc__)
def createPlot(): # 代码核心。首先创建一个新图形并清空绘图区,然后在绘图区上绘制两个代表不同类型的树节点,后面用这两个结点绘制树图形
fig = plt.figure(1, facecolor='white') # 1表示图形编好/名称
fig.clf() # 表示清除所有轴
createPlot.ax1 = plt.subplot(111, frameon=False) # 为对象添加属性 frameon=true时图示被绘制在一个patch实体上;=false则图示直接被绘制在图形上
plotNode('a dicision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
plt.show()
def getNumLeafs(myTree): # 获取叶节点的数目
numLeafs = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys(): # 测试节点的数据类型是否为字典
if type(secondDict[key]).__name__ == 'dict':
numLeafs += getNumLeafs(secondDict[key])
else:
numLeafs += 1
return numLeafs
def getTreeDepth(myTree): # 获取树的层数
maxDepth = 0
firstStr = list(myTree.keys())[0] # keys()返回一个字典的所有键
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
thisDepth = 1 + getTreeDepth(secondDict[key])
else:
thisDepth = 1
if thisDepth > maxDepth:
maxDepth = thisDepth
return maxDepth
def retrieveTree(i): # 输出预先存储的树信息,避免每次测试代码时都要从数据中创建树的麻烦
listOfTrees = [{'no surfacing': {0: 'no', 1: {'flipppers': {0: 'no', 1: 'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flipppers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]
return listOfTrees[i]
def plotMidText(cntrPt, parentPt, txtString): # 用于计算父节点和子节点的中间位置,并在此添加简单的文本标签信息
xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString)
def plotTree(myTree, parentPt, nodeTxt): # 完成绘制树形的大多数工作
numLeafs = getNumLeafs(myTree)
depth = getTreeDepth(myTree)
firstStr = list(myTree.keys())[0]
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
# 按照图形比例绘图,根据叶子结点的数目划分图形的宽度,从而计算得到当前结点的中心位置
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD # 按比例减少全局变量plotTree.yOff
# 由于是自顶向下绘制图形,因此需要依次递减y坐标值
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict': # 当节点不是叶子节点时递归调用plotTree
plotTree(secondDict[key], cntrPt, str(key))
else: # 当节点时叶子节点时在图形上画出叶子节点
plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD # 在绘制了所有子节点之后,增加全局变量Y的偏移
def createPlot(inTree): # 创建绘图区,计算树形图的全局尺寸,并递归调用函数plotTree()
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) # .ax1相当于对函数对象添加属性
plotTree.totalW = float(getNumLeafs(inTree)) # 全局变量plotTree.totalW存储树的宽度
plotTree.totalD = float(getTreeDepth(inTree)) # 全局变量plotTree.totalD存储树的深度
plotTree.xOff = -0.5 / plotTree.totalW # 全局变量plotTree.xOff和plotTree.yOff用于追踪已经绘制的结点位置,以及放置下一个节点的恰当位置
plotTree.yOff = 1.0;
plotTree(inTree, (0.5, 1.0), '')
plt.show()
trees.py:
from math import log
import operator
def createDataSet():
dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
labels = ['no surfacing', 'flipppers']
return dataSet, labels
def calcShannonEnt(dataSet):
numEntries = len(dataSet) # 计算数据集中实例的总数
labelCounts = {} # 新建字典,记录每个分类下的数据个数
for featVec in dataSet: # 为所有可能的分类创建字典
currentLabel = featVec[-1] # 将dataSet每一个元素的最后一个元素选择出来
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0 # 当没有该键时,使用字典的自动添加添加值为0的项
labelCounts[currentLabel] += 1
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key]) / numEntries # 取概率
shannonEnt -= prob * log(prob, 2) # log(x,2)表示以2为底求x的对数
return shannonEnt
def splitDataSet(dataSet, axis, value): # 按照给定特征划分数据集。dataSet是待划分的数据集,axis是划分数据集的特征,value是需要返回的特征值
retDataSet = [] # 创建新的list对象(为了不修改原始数据集,数据集这个列表的各个元素也是列表)
for featVec in dataSet: # 将符合特征的数据抽取出来
if featVec[axis] == value:
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis + 1:]) # extend用于在列表末尾一次性追加另一个序列的多个值
retDataSet.append(reducedFeatVec) # append用于在列表末尾添加新的对象
return retDataSet
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1 # 计算所有特征数
baseEntropy = calcShannonEnt(dataSet) # 计算原始香农熵,保存最初的无序度量值,用于与划分完之后的数据集计算的熵值进行比较
bestInfoGain = 0.0
bestFeature = -1
for i in range(numFeatures):
featList = [example[i] for example in dataSet] # 遍历所有样本的第i个特征的取值情况(使用列表推导创建新的列表)
uniqueVals = set(featList) # 第i条特征的取值(去重) set函数用于创建一个无序不重复元素集,可进行关系测试,删除重复数据,还可计算交集、差集、并集等
newEntropy = 0.0
for value in uniqueVals: # 计算每种划分方式的信息熵。对每个特征划分一次数据集,然后计算数据集的新熵值,并对所有唯一特征值得到的熵求和
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet) / float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy
if (infoGain > bestInfoGain):
bestInfoGain = infoGain
bestFeature = i
return bestFeature
def majorityCnt(classList):
classCount = {} # 创建键值为classList中唯一值的数据字典,存储classList中每个类标签出现的频率,然后利用operator操作键值排序字典,返回出现次数最多的分类名称
for vote in classList:
if vote not in classCount.keys(): classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
def createTree(dataSet, labels): # 参数为数据集和标签列表 标签列表包含了数据集中所有特征的标签,为了给出数据明确的含义将其作为输入参数提供
classList = [example[-1] for example in dataSet] # 取标签值
# 第一个停止条件:所有的类标签完全相同,则直接返回该类标签
if classList.count(classList[0]) == len(classList): # count函数用于统计某个元素在列表中出现的次数
return classList[0] # 列表完全相同则停止继续划分
# 第二个停止条件:使用完了所有特征,仍然不能将数据集划分成仅包含唯一类别的分组,因此挑选出现次数最多的类别作为返回值
if len(dataSet[0]) == 1:
return majorityCnt(classList) # 遍历完所有特征时返回出现次数最多的
bestFeat = chooseBestFeatureToSplit(dataSet) # 当前数据集选取的最好特征
bestFeatLabel = labels[bestFeat]
myTree = {bestFeatLabel: {}} # 用字典存储树的结构
del (labels[bestFeat]) # 删除取出过的标签,避免重复计算
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues) # 得到列表包含的所有属性值,利用set去重
for value in uniqueVals:
subLabels = labels[:] # 复制所有的子标签,因为是引用类型,以避免改变原始标签数据
# 在python中函数参数是列表类型时,参数是按照引用方式传递。为了保证每次调用函数createTree()时不改变原始列表的内容,使用新变量subLabels代替原始列表
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels) # 递归构建树
return myTree
def classify(inputTree, featLabels, testVec):
# 在存储带有特征的数据时,程序无法确定特征在数据集中的位置,因此使用特征标签列表解决该问题。使用index方法查找当前列表中第一个匹配firstStr变量的元素,然后代码递归遍历整棵树,比较testVec变量中的值域树节点的值,如果到达叶子节点则返回节点的分类标签
firstStr = list(inputTree.keys())[0]
secondDist = inputTree[firstStr]
featIndex = featLabels.index(firstStr) # 找到根特征在featLabels的位置,将标签字符串转换为索引
for key in secondDist.keys():
if testVec[featIndex] == key:
if type(secondDist[key]).__name__ == 'dict':
classLabel = classify(secondDist[key], featLabels, testVec)
else:
classLabel = secondDist[key]
return classLabel
def storeTree(inputTree, filename):
import pickle
fw = open(filename, 'wb+')
pickle.dump(inputTree, fw) # pickle.dump(obj, file, [,protocol])将对象obj保存到file中.proctol为序列化使用的协议版本
fw.close()
def grabTree(filename):
import pickle
fr = open(filename,'rb+')
return pickle.load(fr) # 用于反序列化对象,将文件中的数据解析为一个python对象
main.py:
fr = open('lenses.txt')
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = trees.createTree(lenses, lensesLabels)
print(lensesTree)
treePlotter.createPlot(lensesTree)