机器学习实战|决策树

本文深入探讨了决策树的构建过程,包括信息增益、数据集划分、递归构建决策树等关键步骤。同时,介绍了使用matplotlib库进行决策树可视化的方法,以及如何使用决策树进行分类和存储。此外,还展示了决策树在预测隐形眼镜类型等实际问题中的应用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1 决策树的构造

优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据
缺点:可能产生过度匹配问题
适用数据类型:数值型和标称型

解决的首要问题:当前数据集上哪个特征在划分数据分类时起决定性作用
创建分支的伪代码函数createBranch():

检测数据集中的每个子项是否属于同一分类;
if so return 类标签;
Else
	寻找划分数据集的最好特征
	划分数据集
	创建分支节点
		for 每个划分的子集
			调用函数createBranch并增加返回结果到分支节点中
	return 分支节点

决策树的一般流程

  1. 收集数据
  2. 准备数据:只适用于标称型数据,数值型数据必须离散化
  3. 分析数据:检查图形是否符合预期
  4. 训练算法:构造树的数据结构
  5. 测试算法:使用经验树计算错误率
  6. 使用算法:可以适用于任何监督学习算法,使用决策树可以更好地理解数据的内在含义

1.1 信息增益

信息增益:划分数据集之后信息发生的变化
通过计算每个特征值划分数据集获得的信息增益,获得信息增益最高的特征就是最好的选择
符号xi的信息定义:
I(xi)=-log2p(xi)
计算熵:
H=-∑i=1n\sum_{i=1}^{n}i=1np(xi)log2p(xi)
定义calcShannonEnt函数计算给定数据集的香农熵:

def calcShannonEnt(dataSet):
    numEntries = len(dataSet)  # 计算数据集中实例的总数
    labelCounts = {}  # 新建字典,记录每个分类下的数据个数
    for featVec in dataSet:  # 为所有可能的分类创建字典
        currentLabel = featVec[-1]  # 将dataSet每一个元素的最后一个元素选择出来
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0  # 当没有该键时,使用字典的自动添加添加值为0的项
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries  # 取概率
        shannonEnt -= prob * log(prob, 2)  # log(x,2)表示以2为底求x的对数
    return shannonEnt

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
熵越高,则混合的数据也越多
在这里插入图片描述在这里插入图片描述
得到熵后,就可以按照获得最大信息增益的方法划分数据集

注:另一个度量集合无序程度的方法是基尼不纯度:从一个数据集中随机选取子项,度量其被错误分类到其他分组里的概率

1.2 划分数据集

划分数据集:

def splitDataSet(dataSet, axis, value):  # 按照给定特征划分数据集。dataSet是待划分的数据集,axis是划分数据集的特征,value是需要返回的特征值
    retDataSet = []  # 创建新的list对象(为了不修改原始数据集,数据集这个列表的各个元素也是列表)
    for featVec in dataSet:  # 将符合特征的数据抽取出来
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])  # extend用于在列表末尾一次性追加另一个序列的多个值
            retDataSet.append(reducedFeatVec)  # append用于在列表末尾添加新的对象
    return retDataSet

在这里插入图片描述
在这里插入图片描述
选择最好的数据集划分方式:

def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1  # 计算所有特征数
    baseEntropy = calcShannonEnt(dataSet)  # 计算原始香农熵,保存最初的无序度量值,用于与划分完之后的数据集计算的熵值进行比较
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]  # 遍历所有样本的第i个特征的取值情况(使用列表推导创建新的列表)
        uniqueVals = set(featList)  # 第i条特征的取值(去重)   set函数用于创建一个无序不重复元素集,可进行关系测试,删除重复数据,还可计算交集、差集、并集等
        newEntropy = 0.0
        for value in uniqueVals:  # 计算每种划分方式的信息熵。对每个特征划分一次数据集,然后计算数据集的新熵值,并对所有唯一特征值得到的熵求和
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if (infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

trees.chooseBestFeatureToSplit(myDat)
输出:0

1.3 递归构建决策树

需要注意的点:由于特征值可能多于两个,因此可能存在大于两个分支的数据集划分
递归结束的条件:程序遍历完所有划分数据集的属性,或者每个分支下的所有实例都具有相同的分类,则得到一个叶子结点或终止块
需要考虑的特殊情况:当数据集已经处理了所有的属性,但类标签依然不是唯一的,此时需要决定如何定义该叶子节点。通常采用多数表决的方法决定该叶子结点的分类。

def majorityCnt(classList):
    classCount = {}#创建键值为classList中唯一值的数据字典,存储classList中每个类标签出现的频率,然后利用operator操作键值排序字典,返回出现次数最多的分类名称
    for vote in classList:
        if vote not in classCount.keys(): classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
def createTree(dataSet, labels):  # 参数为数据集和标签列表 标签列表包含了数据集中所有特征的标签,为了给出数据明确的含义将其作为输入参数提供
    classList = [example[-1] for example in dataSet]  # 取标签值
    # 第一个停止条件:所有的类标签完全相同,则直接返回该类标签
    if classList.count(classList[0]) == len(classList):  # count函数用于统计某个元素在列表中出现的次数
        return classList[0]  # 列表完全相同则停止继续划分
    # 第二个停止条件:使用完了所有特征,仍然不能将数据集划分成仅包含唯一类别的分组,因此挑选出现次数最多的类别作为返回值
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)  # 遍历完所有特征时返回出现次数最多的
    bestFeat = chooseBestFeatureToSplit(dataSet)  # 当前数据集选取的最好特征
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}  # 用字典存储树的结构
    del (labels[bestFeat])  # 删除取出过的标签,避免重复计算
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)  # 得到列表包含的所有属性值,利用set去重
    for value in uniqueVals:
        subLabels = labels[:]  # 复制所有的子标签,因为是引用类型,以避免改变原始标签数据
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)  # 递归构建树
    return myTree

myDat, labels = trees.createDataSet()
myTree = trees.createTree(myDat, labels)
print(myTree)

在这里插入图片描述
可以看出myTree包含了很多树结构信息的嵌套字典
第一个关键字是第一个划分数据集的特征名称,该关键字的值也是另一个数据字典;第二个关键字是no surfacing特征划分的数据集,这些关键字的值是no surfacing节点的子节点:值是类标签时,说明该子节点是叶子节点;值是另一个数据字典时,则子节点是一个判断节点

2 使用matplotlib注解绘制数图形

matplotlib提供的注解工具:在数据图形上添加文本注释

2.1 使用文本注解绘制树节点

treePlotter.py:

import matplotlib.pyplot as plt

#定义树节点格式的常量
decisionNode = dict(boxstyle="sawtooth", fc="0.8")  # 决策节点的属性。boxstyle为文本框的类型,sawtooth为锯齿形,fc为边框线粗细
leafNode = dict(boxstyle="round4", fc="0.8")  # 决策树叶子结点的属性
arrow_args = dict(arrowstyle="<-")  # 剪头的属性


def plotNode(nodeTxt, centerPt, parentPt, nodeType):  # 执行绘图功能。绘图区域由全局变量createPlot.ax1定义
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
                
 # plt.annotate(str, xy=data_point_position, xytext=annotate_position,
#va="center",  ha="center", xycoords="axes fraction",
 #              textcoords="axes fraction", bbox=annotate_box_type, arrowprops=arrow_style)
 # str是给数据点添加注释的内容,支持输入一个字符串
 # xy=是要添加注释的数据点的位置
 # xytext=是注释内容的位置
 # bbox=是注释框的风格和颜色深度,fc越小,注释框的颜色越深,支持输入一个字典
 # va="center",  ha="center"表示注释的坐标以注释框的正中心为准,而不是注释框的左下角(v代表垂直方向,h代表水平方向)
 # xycoords和textcoords可以指定数据点的坐标系和注释内容的坐标系,通常只需指定xycoords即可,textcoords默认和xycoords相同
 # arrowprops可以指定箭头的风格支持,输入一个字典
 # plt.annotate()的详细参数可用__doc__查看,如:print(plt.annotate.__doc__)

def createPlot():  # 代码核心。首先创建一个新图形并清空绘图区,然后在绘图区上绘制两个代表不同类型的树节点,后面用这两个结点绘制树图形
    fig = plt.figure(1, facecolor='white')  # 1表示图形编好/名称
    fig.clf()  # 表示清除所有轴
    createPlot.ax1 = plt.subplot(111, frameon=False)  # 为对象添加属性  frameon=true时图示被绘制在一个patch实体上;=false则图示直接被绘制在图形上
    plotNode('a dicision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
    plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()

在这里插入图片描述

2.2 构造注解树

def getNumLeafs(myTree):  # 获取叶节点的数目
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():  # 测试节点的数据类型是否为字典
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):  # 获取树的层数
    maxDepth = 0
    firstStr = list(myTree.keys())[0]  # keys()返回一个字典的所有键
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


def retrieveTree(i):  # 输出预先存储的树信息,避免每次测试代码时都要从数据中创建树的麻烦
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]
    return listOfTrees[i]

Terminal:

print(treePlotter.retrieveTree(1))
myTree = treePlotter.retrieveTree(0)
print(treePlotter.getNumLeafs(myTree))
print(treePlotter.getTreeDepth(myTree))

在这里插入图片描述

def plotMidText(cntrPt, parentPt, txtString):  # 用于计算父节点和子节点的中间位置,并在此添加简单的文本标签信息
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)


def plotTree(myTree, parentPt, nodeTxt):  # 完成绘制树形的大多数工作
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    # 按照图形比例绘图,根据叶子结点的数目划分图形的宽度,从而计算得到当前结点的中心位置
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD  # 按比例减少全局变量plotTree.yOff
    # 由于是自顶向下绘制图形,因此需要依次递减y坐标值
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # 当节点不是叶子节点时递归调用plotTree
            plotTree(secondDict[key], cntrPt, str(key))
        else:  # 当节点时叶子节点时在图形上画出叶子节点
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD  # 在绘制了所有子节点之后,增加全局变量Y的偏移


def createPlot(inTree):  # 创建绘图区,计算树形图的全局尺寸,并递归调用函数plotTree()
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # .ax1相当于对函数对象添加属性
    plotTree.totalW = float(getNumLeafs(inTree))  # 全局变量plotTree.totalW存储树的宽度
    plotTree.totalD = float(getTreeDepth(inTree))  # 全局变量plotTree.totalD存储树的深度
    plotTree.xOff = -0.5 / plotTree.totalW  # 全局变量plotTree.xOff和plotTree.yOff用于追踪已经绘制的结点位置,以及放置下一个节点的恰当位置
    plotTree.yOff = 1.0;
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

myTree = treePlotter.retrieveTree(0)
treePlotter.createPlot(myTree)
在这里插入图片描述

myTree = treePlotter.retrieveTree(0)
myTree[‘no surfacing’][3] = ‘maybe’
treePlotter.createPlot(myTree)

在这里插入图片描述

3 测试和存储分类器

3.1 使用决策树执行分类

def classify(inputTree, featLabels, testVec):
    # 在存储带有特征的数据时,程序无法确定特征在数据集中的位置,因此使用特征标签列表解决该问题。使用index方法查找当前列表中第一个匹配firstStr变量的元素,然后代码递归遍历整棵树,比较testVec变量中的值域树节点的值,如果到达叶子节点则返回节点的分类标签
    firstStr = list(inputTree.keys())[0]
    secondDist = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)  # 找到根特征在featLabels的位置,将标签字符串转换为索引
    for key in secondDist.keys():
        if testVec[featIndex] == key:
            if type(secondDist[key]).__name__ == 'dict':
                classLabel = classify(secondDist[key], featLabels, testVec)
            else:
                classLabel = secondDist[key]
    return classLabel

myDat, labels = trees.createDataSet()
print(labels)
myTree = treePlotter.retrieveTree(0)
print(myTree)
print(trees.classify(myTree, labels, [1, 0]))
print(trees.classify(myTree, labels, [1, 1]))

在这里插入图片描述

3.2 使用算法:决策树的存储

每次使用分类器时重新构造决策树是很耗时的任务。为了解决该问题需要使用python模块pickle序列化对象,可以在磁盘上保存对象,并在需要时读取出来。任何对象都可以执行序列化操作。

def storeTree(inputTree, filename):
    import pickle
    fw = open(filename, 'wb+')
    pickle.dump(inputTree, fw)  # pickle.dump(obj, file, [,protocol])将对象obj保存到file中.proctol为序列化使用的协议版本
    fw.close()


def grabTree(filename):
    import pickle
    fr = open(filename,'rb+')
    return pickle.load(fr)  # 用于反序列化对象,将文件中的数据解析为一个python对象

myDat, labels = trees.createDataSet()
print(labels)
myTree = treePlotter.retrieveTree(0)
trees.storeTree(myTree, ‘classifierStorage.txt’)
trees.grabTree(‘classifierStorage.txt’)

在这里插入图片描述

4 示例:使用决策树预测隐形眼镜类型

在这里插入图片描述

fr = open('lenses.txt')
    lenses = [inst.strip().split('\t') for inst in fr.readlines()]
    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
    lensesTree = trees.createTree(lenses, lensesLabels)
    print(lensesTree)
    treePlotter.createPlot(lensesTree)

在这里插入图片描述
沿着决策树的不同分支即可得到不同患者需要佩戴的隐形眼镜类型

附:全部代码

treePlotter.py:

import matplotlib.pyplot as plt

# 定义树节点格式的常量
decisionNode = dict(boxstyle="sawtooth", fc="0.8")  # 决策节点的属性。boxstyle为文本框的类型,sawtooth为锯齿形,fc为边框线粗细
leafNode = dict(boxstyle="round4", fc="0.8")  # 决策树叶子结点的属性
arrow_args = dict(arrowstyle="<-")  # 剪头的属性


def plotNode(nodeTxt, centerPt, parentPt, nodeType):  # 执行绘图功能。绘图区域由全局变量createPlot.ax1定义
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)


# plt.annotate(str, xy=data_point_position, xytext=annotate_position,
#              va="center",  ha="center", xycoords="axes fraction",
#              textcoords="axes fraction", bbox=annotate_box_type, arrowprops=arrow_style)
# str是给数据点添加注释的内容,支持输入一个字符串
# xy=是要添加注释的数据点的位置
# xytext=是注释内容的位置
# bbox=是注释框的风格和颜色深度,fc越小,注释框的颜色越深,支持输入一个字典
# va="center",  ha="center"表示注释的坐标以注释框的正中心为准,而不是注释框的左下角(v代表垂直方向,h代表水平方向)
# xycoords和textcoords可以指定数据点的坐标系和注释内容的坐标系,通常只需指定xycoords即可,textcoords默认和xycoords相同
# arrowprops可以指定箭头的风格支持,输入一个字典
# plt.annotate()的详细参数可用__doc__查看,如:print(plt.annotate.__doc__)


def createPlot():  # 代码核心。首先创建一个新图形并清空绘图区,然后在绘图区上绘制两个代表不同类型的树节点,后面用这两个结点绘制树图形
    fig = plt.figure(1, facecolor='white')  # 1表示图形编好/名称
    fig.clf()  # 表示清除所有轴
    createPlot.ax1 = plt.subplot(111, frameon=False)  # 为对象添加属性  frameon=true时图示被绘制在一个patch实体上;=false则图示直接被绘制在图形上
    plotNode('a dicision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
    plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()


def getNumLeafs(myTree):  # 获取叶节点的数目
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():  # 测试节点的数据类型是否为字典
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):  # 获取树的层数
    maxDepth = 0
    firstStr = list(myTree.keys())[0]  # keys()返回一个字典的所有键
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


def retrieveTree(i):  # 输出预先存储的树信息,避免每次测试代码时都要从数据中创建树的麻烦
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flipppers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flipppers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]
    return listOfTrees[i]


def plotMidText(cntrPt, parentPt, txtString):  # 用于计算父节点和子节点的中间位置,并在此添加简单的文本标签信息
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)


def plotTree(myTree, parentPt, nodeTxt):  # 完成绘制树形的大多数工作
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    # 按照图形比例绘图,根据叶子结点的数目划分图形的宽度,从而计算得到当前结点的中心位置
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD  # 按比例减少全局变量plotTree.yOff
    # 由于是自顶向下绘制图形,因此需要依次递减y坐标值
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # 当节点不是叶子节点时递归调用plotTree
            plotTree(secondDict[key], cntrPt, str(key))
        else:  # 当节点时叶子节点时在图形上画出叶子节点
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD  # 在绘制了所有子节点之后,增加全局变量Y的偏移


def createPlot(inTree):  # 创建绘图区,计算树形图的全局尺寸,并递归调用函数plotTree()
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # .ax1相当于对函数对象添加属性
    plotTree.totalW = float(getNumLeafs(inTree))  # 全局变量plotTree.totalW存储树的宽度
    plotTree.totalD = float(getTreeDepth(inTree))  # 全局变量plotTree.totalD存储树的深度
    plotTree.xOff = -0.5 / plotTree.totalW  # 全局变量plotTree.xOff和plotTree.yOff用于追踪已经绘制的结点位置,以及放置下一个节点的恰当位置
    plotTree.yOff = 1.0;
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

trees.py:

from math import log
import operator


def createDataSet():
    dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
    labels = ['no surfacing', 'flipppers']
    return dataSet, labels


def calcShannonEnt(dataSet):
    numEntries = len(dataSet)  # 计算数据集中实例的总数
    labelCounts = {}  # 新建字典,记录每个分类下的数据个数
    for featVec in dataSet:  # 为所有可能的分类创建字典
        currentLabel = featVec[-1]  # 将dataSet每一个元素的最后一个元素选择出来
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0  # 当没有该键时,使用字典的自动添加添加值为0的项
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries  # 取概率
        shannonEnt -= prob * log(prob, 2)  # log(x,2)表示以2为底求x的对数
    return shannonEnt


def splitDataSet(dataSet, axis, value):  # 按照给定特征划分数据集。dataSet是待划分的数据集,axis是划分数据集的特征,value是需要返回的特征值
    retDataSet = []  # 创建新的list对象(为了不修改原始数据集,数据集这个列表的各个元素也是列表)
    for featVec in dataSet:  # 将符合特征的数据抽取出来
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])  # extend用于在列表末尾一次性追加另一个序列的多个值
            retDataSet.append(reducedFeatVec)  # append用于在列表末尾添加新的对象
    return retDataSet


def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1  # 计算所有特征数
    baseEntropy = calcShannonEnt(dataSet)  # 计算原始香农熵,保存最初的无序度量值,用于与划分完之后的数据集计算的熵值进行比较
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]  # 遍历所有样本的第i个特征的取值情况(使用列表推导创建新的列表)
        uniqueVals = set(featList)  # 第i条特征的取值(去重)   set函数用于创建一个无序不重复元素集,可进行关系测试,删除重复数据,还可计算交集、差集、并集等
        newEntropy = 0.0
        for value in uniqueVals:  # 计算每种划分方式的信息熵。对每个特征划分一次数据集,然后计算数据集的新熵值,并对所有唯一特征值得到的熵求和
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if (infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


def majorityCnt(classList):
    classCount = {}  # 创建键值为classList中唯一值的数据字典,存储classList中每个类标签出现的频率,然后利用operator操作键值排序字典,返回出现次数最多的分类名称
    for vote in classList:
        if vote not in classCount.keys(): classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def createTree(dataSet, labels):  # 参数为数据集和标签列表 标签列表包含了数据集中所有特征的标签,为了给出数据明确的含义将其作为输入参数提供
    classList = [example[-1] for example in dataSet]  # 取标签值
    # 第一个停止条件:所有的类标签完全相同,则直接返回该类标签
    if classList.count(classList[0]) == len(classList):  # count函数用于统计某个元素在列表中出现的次数
        return classList[0]  # 列表完全相同则停止继续划分
    # 第二个停止条件:使用完了所有特征,仍然不能将数据集划分成仅包含唯一类别的分组,因此挑选出现次数最多的类别作为返回值
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)  # 遍历完所有特征时返回出现次数最多的
    bestFeat = chooseBestFeatureToSplit(dataSet)  # 当前数据集选取的最好特征
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}  # 用字典存储树的结构
    del (labels[bestFeat])  # 删除取出过的标签,避免重复计算
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)  # 得到列表包含的所有属性值,利用set去重
    for value in uniqueVals:
        subLabels = labels[:]  # 复制所有的子标签,因为是引用类型,以避免改变原始标签数据
        # 在python中函数参数是列表类型时,参数是按照引用方式传递。为了保证每次调用函数createTree()时不改变原始列表的内容,使用新变量subLabels代替原始列表
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)  # 递归构建树
    return myTree


def classify(inputTree, featLabels, testVec):
    # 在存储带有特征的数据时,程序无法确定特征在数据集中的位置,因此使用特征标签列表解决该问题。使用index方法查找当前列表中第一个匹配firstStr变量的元素,然后代码递归遍历整棵树,比较testVec变量中的值域树节点的值,如果到达叶子节点则返回节点的分类标签
    firstStr = list(inputTree.keys())[0]
    secondDist = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)  # 找到根特征在featLabels的位置,将标签字符串转换为索引
    for key in secondDist.keys():
        if testVec[featIndex] == key:
            if type(secondDist[key]).__name__ == 'dict':
                classLabel = classify(secondDist[key], featLabels, testVec)
            else:
                classLabel = secondDist[key]
    return classLabel


def storeTree(inputTree, filename):
    import pickle
    fw = open(filename, 'wb+')
    pickle.dump(inputTree, fw)  # pickle.dump(obj, file, [,protocol])将对象obj保存到file中.proctol为序列化使用的协议版本
    fw.close()


def grabTree(filename):
    import pickle
    fr = open(filename,'rb+')
    return pickle.load(fr)  # 用于反序列化对象,将文件中的数据解析为一个python对象

main.py:

 fr = open('lenses.txt')
    lenses = [inst.strip().split('\t') for inst in fr.readlines()]
    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
    lensesTree = trees.createTree(lenses, lensesLabels)
    print(lensesTree)
    treePlotter.createPlot(lensesTree)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值