决策树:ID3算法

本文介绍了如何利用信息熵(香农熵)和信息增益来构建决策树。通过示例展示了如何计算数据集的熵,划分数据集,并选择最佳划分方式。接着,详细阐述了创建决策树的函数代码、绘制树形图的方法,以及如何利用决策树进行分类和存储。文章最后提及了实际应用案例和对未来学习的展望。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

组织杂乱无章的数据的一种方法就是使用信息论度量信息
在划分数据集前后信息发生的变化成为信息增益
集合信息的度量方式称之为香农熵
也就是说可以通过香农熵的变化来体现信息增益
计算所有类别的信息期望值(熵) 公式为
这里写图片描述
单个数据的信息期望值为这里写图片描述

样本表

是否必须生活在水里是否有脚蹼是否为鱼类(类别标签)
11yes
11yes
10no
01no
01no
IF 数据集中的每个元素都属于同一分类:
    RETURN 类别标签
ELIF 数据集的长度为1
    RETURN 出现次数最多的类别
ELSE
    寻找划分数据集的最好特征
        FOR 每个特征
            FOR 该特征中不重复的值
                计算每个特征值对应的数据熵
                累加特征值的熵
            得到最小的熵所在的特征
        RETURN 最小的熵的特征
    划分数据集
        将最优特征作为当前的根节点
        在特征列表中去掉最优特征
        FOR 每个划分的子集
            调用自身(迭代)并返回结果到分支节点中
    RETURN 分支节点
  1. 计算给定数据集的香农熵
from math import log
import numpy as np

def calShannonEnt(dataSet):
    numEntries = len(dataSet) #获取数据集的数据条数,也就是行数
    labelCounts = {} #创建字典,盛放类别名称和出现的次数
    for featVec in dataSet: #遍历数据集
        currentLabel = featVec[-1] #获取每条数据的类别名称
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0 #初始化单条数据的香农熵
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries #获取该类别占总数的比例
        shannonEnt -= prob * log(prob, 2)  #用代码书写计算香农熵的公式
    return shannonEnt
#测试数据
def createDataSet(): 
    dataSet = [[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no']]
    labels = ['no surfacing', 'filippers']
    return dataSet, labels

测试
dataSet, labels = createDataSet()
calShannonEnt(dataSet)

结果为 0.9709505944546686

改变数据

dataSet[0][-1] = 'may be'
calShannonEnt(dataSet)

结果为 1.3709505944546687
结论:
混合的数据越多,熵越高。

2.划分数据集:按照给定的特征划分数据集

#判断对应下标axis的数组元素是否为value,将所有满足的数据经过一定的调整放入一个列表中展示出来
def splitDataSet(dataSet, axis, value): #数据集,下表,值
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value: #判断数组下标的元素是否为输入的value
            reducedFeatVec = featVec[:axis] #获取数组下标前的元素
            reducedFeatVec.extend(featVec[axis+1 :])#将数组下标后的元素组成的列表添加到之前的列表中
            retDataSet.append(reducedFeatVec) #将获取到的列表放到另一个列表中
    return retDataSet

测试splitDataSet(dataSet, 2, 'no')
结果[[1, 0], [0, 1], [0, 1]]

3. 选择最好的数据集划分方式

def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1 #获取数据的长度并减一(最后一列为类别标签,剩下的都是特征)
    baseEntropy = calShannonEnt(dataSet) #获取数据集的熵
    bestInfoGain = 0.0 #最优的信息增益
    bestFeature = -1  #最优的特征所在的下标
    for i in range(numFeatures): #遍历数据的特征值
        featList = [example[i] for example in dataSet] #遍历数据集,获取每组数据的第i个特征值,组成列表,注意外边的中括号
        uniqueVals = set(featList) #将列表转换为集合,达到去重的效果
        newEntropy = 0.0 #初始化香农熵
        for value in uniqueVals: #遍历去重后的数据集中的第i列的值,获取第i列的香农熵的总和
            subDataSet = splitDataSet(dataSet, i, value) #调用‘按照给定特征划分数据集’的函数,其中i为value在列表中的下标,返回结果为符合条件的列表
            prob = len(subDataSet) / float(len(dataSet)) #符合条件的列表的长度 除以 总数据集的长度,获得符合条件的数据的占比
            newEntropy += prob * calShannonEnt(subDataSet) #对符合条件的列表进行香农熵计算,结果乘以符合条件的数据占比。并累加得到新的香农熵
        infoGain = baseEntropy - newEntropy #用总数据集的香农熵减去 新获得的香农熵 得到信息增益,信息增益即是熵的减少

        if infoGain > bestInfoGain: #如果信息增益大于最优信息增益,将信息增益的值赋给最优信息增益,最优特征的值变为i,即本次循环的下标
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

测试 dataSet = [[1,1,'yes'],[1,0,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no']]
chooseBestFeatureToSplit(dataSet)

结果 0
测试 dataSet = [[1, 1, 1, 1, 'yes'], [1, 1, 2, 2, 'yes'], [1, 1, 3, 3, 'no'], [0, 0, 4, 4, 'no'], [0, 0, 5, 5, 'no']]
chooseBestFeatureToSplit(dataSet)

结果 2
结论:最好的数据集划分方式,即时通过最少的时间(步数)来查找到每个值,因此熵最大的特征就最好的数据集划分方式

4. 获取出现次数最多的分类名称

def majorityCnt(classList): #classList为分类名称的列表
    classCount = {} #创建键值为classList中唯一值的数据字典,字典对象存储了classList中每个类标签出现的频率
    for vote in classList:
        if vote not in classCount.keys():
            classCount[bote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse = True) #用operator操作键值排序字典
    return sortedClassCount[0][0] #返回出现次数最多的分类名称

5. 创建树的函数代码
递归结束的条件:程序遍历完所有划分数据集的属性,或者每个分支下的所有实例都具有相同的分类
决策树的函数代码类似于json格式,键为节点名称,值为节点下的总信息

def createTree(dataSet, labels): #labels为数据集中所有的特征的标签
    classList = [example[-1] for example in dataSet] #获取所有的分类名称
    if classList.count(classList[0]) == len(classList): #类别完全相同则停止继续划分
        return classList[0]
    if len(dataSet[0]) == 1: #遍历完所有特征时返回出现次数最多的类别
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet) #返回划分数据集的最优特征的索引
    bestFeatLabel = labels[bestFeat] #返回划分数据集的最优特征
    myTree = {bestFeatLabel:{}} #创建决策树‘对象’
    del(labels[bestFeat]) #在类别的列表中 删除使用过的特征
    featValues = [example[bestFeat] for example in dataSet] #得到刚刚获取的最优特征对应的值并存放在列表中
    uniqueVals = set(featValues) #去重
    for value in uniqueVals:
        subLabels = labels[:] #复制labels,保证每次调用myTree()时,不改变原始列表的内容
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

测试
dataSet, labels = createDataSet()
myTree = createTree(dataSet,labels)
print(myTree)

结果 {'no surfacing': {0: 'no', 1: {'filippers': {0: 'no', 1: 'yes'}}}}
6. 绘制简单的树形图

import matplotlib.pyplot as plt
import matplotlib
myfont = matplotlib.font_manager.FontProperties(fname='C:\Windows\Fonts\msyh.ttc')
#解决负号'-'显示为方块的问题 
matplotlib.rcParams['axes.unicode_minus']=False
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.axl.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',xytext=centerPt, textcoords='axes fraction', va='center', ha='center',bbox=nodeType,arrowprops=arrow_args,fontproperties=myfont)
def createPlot():
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    createPlot.axl = plt.subplot(111, frameon=False)
    plotNode('决策树节点',(0.5,0.1),(0.1,0.5),decisionNode)
    plotNode('叶节点',(0.8,0.1),(0.3,0.8),leafNode)
    plt.show()

运行 createPlot()
结果这里写图片描述

7. 获取叶节点的数目(树的宽度)和树的层数(高度)

def getNumLeafs(myTree):
    numLeafs = 0
    dict_firstStr = list(myTree.keys())
    firstStr =dict_firstStr[0] #根节点的名称
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key]) #递归的正确用法
        else:
            numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDir = myTree[firstStr]
    for key in secondDir.keys():
        if type(secondDir[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDir[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth: 
            maxDepth = thisDepth
    return maxDepth

8. 绘制需要的树形图

#计算子节点和父节点的中间位置,并添加简单的文本标签信息
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.axl.text(xMid, yMid, txtString)
def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree) #计算树的宽度
    depth = getTreeDepth(myTree) #计算树的高度
    firstStr = list(myTree.keys())[0] #获取字典myTree的键值,并获取第一个值
    #plotTree.xOff和plotTree.yOff追踪已经绘制的节点位置 plotTree.totalD和plotTree.totalW全局变量记录树的宽度和深度
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff) #下个子节点的中心店(0.5, 1.0) 
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr] #根节点下的所有节点信息
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD #y方向偏移 (自上向下)
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict': #类型为dict 说明还有子节点,递归
            plotTree(secondDict[key], cntrPt, str(key))
        else:  #类型不是dict,说明只剩下叶子节点
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW #X偏移
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD #绘制完子节点后,增加全局变量Y的偏移

# 绘制图区,计算树图形的全局尺寸,并调用递归函数plotTree
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.axl = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

测试 dataSet, labels = createDataSet()
myTree = createTree(dataSet,labels)
createPlot(myTree)

结果
这里写图片描述

9. 使用决策树的分类函数

输入获得的决策树代码,决策树的标签向量和对应的特征值,得到理想的类别

def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0] #返回根节点
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr) #index()查找当前列表中第一个匹配firstStr变量的元素 的索引
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel

测试
def retrieveTree(i):
listOfTrees = [{'no surfacing': {0: 'no', 1: {'filippers': {0: 'no', 1: 'yes'}}}},{'no surfacing': {0: 'no', 1: {'filippers': {0: 'no', 1: 'yes'}}}},]
return listOfTrees[i]
myDat, labels = createDataSet()
myTree = retrieveTree(0)
classify(myTree, labels, [1,1])

结果 yes
其中 labels ['no surfacing', 'filippers']
myTree {'no surfacing': {0: 'no', 1: {'filippers': {0: 'no', 1: 'yes'}}}}

10. 使用pickle模块存储决策树
python的模块pickle可以序列化对象,序列化的对象可以保存到硬盘上,并在需要的时候读取出来

def storeTree(inputTree, filename):
    import pickle
    fw = open(filename, 'wb') #以二进制存储 b
    pickle.dump(inputTree, fw)
def grabTree(filename):
    import pickle
    fr = open(filename,'rb')
    return pickle.load(fr)

测试 storeTree(myTree, 'classifierStorage.txt')
grabTree('classifierStorage.txt')

结果{'no surfacing': {0: 'no', 1: {'filippers': {0: 'no', 1: 'yes'}}}}

11. 实际应用:使用决策树预测隐形眼镜类型

fr = open('lenses.txt')
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels = ['ages','prescript','astigmatic','tearRate']
lensesTree = creatTree(lenses,lensesLabels)
print(lensesTree)
createPlot(lensesTree)

结果
这里写图片描述

总结
对matplotlib的使用上还很模糊,需要认真看一些资料。有些代码部分还没有完全搞懂,有时间多看多想,完善代码逻辑。

最后
希望大家多多指正!!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值