《机器学习实战》第三章——决策树(ID3)

理论

寻找最具影响力的特征先进行判别,像一棵树一样的判断分支再判断分支,知道最后判别出属于哪个类别

优点
  • 计算复杂度不高
  • 输出结果易于理解,可以看出内在含义
  • 对缺失值不敏感
  • 可以处理不相关特征数据
缺点
  • 易产生过拟合问题
适用于
  • 离散型数据
  • 连续型数据需要离散化
总结:
划分数据集-按照信息增益

信息增益最高的特征就是最好的选择

信息增益:划分数据集之前和之后发生的信息变化

度量信息的单位称为熵(entropy),即信息的期望值,用于度量集合的无序程度(另一方法:基尼不纯度)

信息熵公式:
H = − ∑ i = 1 n p ( x i ) log ⁡ 2 p ( x i ) H = -\sum_{i=1}^{n}p(x_{i})\log_{2}p(x_{i}) H=i=1np(xi)log2p(xi)

其中 p ( x i ) p(x_{i}) p(xi) 是选择该分类的概率, n n n是分类的数目

例子
dataSet = 有两个分类yes和no ,共5 个样本,其中yes 2个,no 3个
H = -(0.4log(0.4,2)+0.6log(0.6,2)) = 0.971

from math import log

def calcShannonEnt(dataSet):
    '''
    计算信息熵
    '''
    numEntries = len(dataSet)
    labelCounts = {}  # 放各类别概率
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labeCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
        # labelCounts[currentLabel] =labelCounts.get(currentLabel, 0) + 1
    shannonEnt = 0.0
    for key in labeCounts:
        prob = float(labelCounts[key])/numEntries  # 每个类别出现的概率
        shannonEnt -= prob*log(prob, 2)
    return shannonEnt
信息增益

特征A对训练集数据D的信息增益 g ( D , A ) g(D, A) g(D,A),定义为集合D的信息熵 H ( D ) H(D) H(D)与特征A给定条件下D的信息熵 H ( D ∣ A ) H(D|A) H(DA)之差,即公式为:
g ( D , A ) = H ( D ) − H ( D ∣ A ) g(D,A) = H(D) - H(D|A) g(D,A)=H(D)H(DA)

信息熵和信息增益的计算

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

按照给定特征划分数据集
def splitDataSet(dataSet, axis, value):
    '''
    dataSet : 数据集
    axis : 特征的索引
    value : 特征的值
    按照给定特征划分数据集
    取出每一个axis位置上为value的样本
    '''
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            # list.extend和list.append 的功能类似,anppend是以整个为一个元素添加,extend是可以以其中每个元素为元素添加
            retDataSet.append(reducedFeatVec)
    return retDataSet
选取“最优”特征,划分数据集
def chooseBestFeatureToSplit(dataSet):
    '''
    选取“最优”特征,划分数据集
    dataSet : 训练数据集,二维列表构成,每一行的最后一个元素为标签
    返回最优特征的索引
    '''
    numFeatures = len(dataSet[0]) - 1   # 获取特征数量
    baseEntropy = calcShannonEnt(dataSet)   # 计算所有样本信息熵
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]  # 提取所有样本的某一特征
        uniqueVals = set(featList)   # 变成集合,去重
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i , value)   # 划分样本
            prob = len(subDataSet)/float(len(dataSet))  # 计算概率
            newEntropy += prob * calcShannonEnt(subDataSet)  # H(D|A)
        infoGain = baseEntropy - newEntropy 
        # 最大信息增益的特征为最好
        if (infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i
        return bestFeature

创建树

def majorityCnt(classList):
    '''
    返回数据集中概率最高的标签
    '''
    classCount = {}
    for vote in  classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount += 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

def createTree(dataSet, labels):
    '''
    递归构建决策树
    dataSet : 数据集
    labels : 标签列表
    
    '''
    classList = [example[-1] for example in dataSet]
    # 分割后的数据集内只有一种类别时,结束递归,获取该类别
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # 分割后的数据集中只有一个特征时,结束递归,获取该数据集中概率最高的目标向量
    if len(dataSet[0]) == 1 :
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)  # 选择最优特征
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel:{}}   # 注意,这里创建字典的时候需要嵌套一个
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]   # 该最优标签的值
    uniqueVals = set(featValues)   # 标签的所有可能性
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
    
    return myTree
 
# 调用
mydata = [[1, 1, 'yes'],
         [1, 1, 'yes'],
         [1, 0, 'no'],
         [0, 1, 'no'],
         [0, 1, 'no']]
labels= ['no surfacing', 'flippers']
myTree = createTree(mydata, labels)
myTree

运行结果:

{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

变量myTree中包含了很多代表树结构信息的嵌套字典,从左边开始是第一个划分数据集的特征,他的值也是一个字典,当值为类标签时,就说明到了叶子节点

Matpolotlib 绘制树形图

import matplotlib.pyplot as plt
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")   # 箭头指向 <- 指向节点 ;-> 指向说明
def plotNode(noteTxt, centerPt,parentPt, nodeType):
    '''
    绘制子节点和父节点
    '''
    createPlot.axl.annotate(noteTxt, xy=parentPt, xycoords="axes fraction", xytext=centerPt, textcoords="axes fraction", va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
    
def getNumLeafs(myTree):
    '''
    测量叶子节点数
    '''
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == "dict":
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):
    '''
    测量树的深度
    '''
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == "dict":
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
    if thisDepth > maxDepth : 
        maxDepth = thisDepth
    return maxDepth

def plotMidText(cntrPt, parentPt, txtString):
    '''
    在节点间填充文本信息
    '''
    xMid = (parentPt[0] - cntrPt[0])/2.0 +cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1])/2.0 +cntrPt[1]
    createPlot.axl.text(xMid, yMid, txtString)
    
def plotTree(myTree, parentPt, nodeTxt):
    
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xoff+(1.0 + float(numLeafs)) / 2.0/plotTree.totalW, plotTree.yoff)
    plotMidText(cntrPt, parentPt, nodeTxt)   # 线之间的文字信息
    plotNode(firstStr, cntrPt, parentPt, decisionNode)   # 节点和箭头
    secondDict = myTree[firstStr]
    plotTree.yoff = plotTree.yoff - 1.0/plotTree.totalD  # y坐标下移一格
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == "dict":
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xoff = plotTree.xoff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xoff, plotTree.yoff), cntrPt, leafNode)
            plotMidText((plotTree.xoff, plotTree.yoff), cntrPt, str(key))
    plotTree.yoff = plotTree.yoff + 1.0/plotTree.totalD

def createPlot(inTree):
    '''
    绘制的准备工作,画布、轴
    '''
    fig = plt.figure(1, facecolor="white")
    fig.clf()  # Clear figure清除所有轴,但是窗口打开,这样它可以被重复使用。
    axprops = dict(xticks=[], yticks=[])
    createPlot.axl = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))   # 全局变量存放宽度,整个树的叶子节点数
    plotTree.totalD = float(getTreeDepth(inTree))   # 深度
    plotTree.xoff = -0.5/plotTree.totalW    # 假设有n个叶子结点,那么应该将宽度分为2*n份,起始位置为左移半格
    plotTree.yoff = 1.0
    plotTree(inTree, (0.5,1.0), '')   # 设整个图为1*1, 那么根节点应该在0.5,1的位置
    plt.show()


# 调用 
myTree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
createPlot(myTree)

运行结果:
在这里插入图片描述初始的plotTree.xoff:
在这里插入图片描述

def classify(inputTree, featLabels, testVec):
    '''
    使用决策树分类
    '''
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    for key in secondDict.keys():
        if testVec[featIndex]  == key:
            if type(secondDict[key]).__name__ == "dict":
                classLable = classify(secondDict[key], featLabels, testVec)
            else :
                classLable = secondDict[key]
    return classLable

# 调用
myTree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
labels = ['no surfacing','flippers']
classify(myTree, labels, [1,0])
classify(myTree, labels, [1,1])

运行结果

'no'
'yes'

存储、读取决策树

def storeTree(inputTree, filename):
    '''
    保存决策树
    '''
    import pickle
    fw = open(filename,'w')
    pickle.dump(inputTree, fw)
    fw.close()
    
def grabTree(filename):
    '''
    读取决策树
    '''
    import pickle
    fr = open(filename)
    return pickle.load(fr)

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值