[Machine Learning in Action] 3. Decision Trees

This post walks through a concrete implementation of a decision-tree algorithm, covering the key steps: computing Shannon entropy, choosing the best feature to split on, and building the tree recursively, with full Python code throughout.


#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 18/4/10 5:35 PM
# @Author  : cicada@hole
# @File    : dtree.py
# @Desc    : Decision tree implementation (Machine Learning in Action, ch. 3)
# @Link    :


'''
Compute the Shannon entropy of a dataset:
1. Count the total number of instances, numEntries.
2. Build the labelCounts dict, where labelCounts[key] is the number of
   instances carrying each class label.
3. Plug the class frequencies into the Shannon entropy formula:
       H = -sum_k p_k * log2(p_k)

Note: the entropy is recomputed for every candidate split, which is how the
best splitting feature is chosen.
'''
from math import log
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)  # total number of instances in the dataset
    labelCounts = {}

    # Count how many times each class label appears:
    # key is the label, value is its count
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1

    shannonEnt = 0.0

    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries  # class frequency p_k
        shannonEnt -= prob * log(prob, 2)          # accumulate -p_k * log2(p_k)
    return shannonEnt
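
# Worked check (hand computation, not from the original post): the sample
# dataset below has 2 'yes' and 3 'no' labels out of 5 instances, so
#   H = -(2/5)*log2(2/5) - (3/5)*log2(3/5) ≈ 0.9710
# which is what calcShannonEnt(createDataSet()[0]) returns.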

def createDataSet():
    # Toy dataset from the book: two binary features and a 'yes'/'no' class
    # label (does the animal count as a fish?).
    dataSet = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels

'''
==== Split the dataset on a given feature ====
1. Inputs: the dataset, the feature index axis, and the value to match.
2. Walk the dataset; for every instance whose axis-th feature equals value,
   keep the instance with that feature column cut out.
   E.g. for [0,1,0,0,'yes'], matching the second feature (axis=1) against
   value 1 yields [0,0,0,'yes'].

    [[1, 1, 'maybe'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
    split on feature 0, value 1 -> [[1, 'maybe'], [1, 'yes'], [0, 'no']]
    split on feature 0, value 0 -> [[1, 'no'], [1, 'no']]
'''
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]           # everything before the feature
            reducedFeatVec.extend(featVec[axis+1:])   # plus everything after it
            retDataSet.append(reducedFeatVec)
    return retDataSet
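
# Quick usage sketch (values made up for illustration):
#   splitDataSet([[1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no']], 0, 1)
#   -> [[1, 'yes'], [0, 'no']]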


'''
==== Find the best feature by computing entropy for every candidate split ====
1. Count the features and iterate over them.
2. For each feature, collect its distinct values and iterate over those.
3. For each value, split off the matching subset, take the fraction of
   instances it holds as prob, and accumulate prob-weighted subset entropies
   into the conditional entropy.
4. The information gain is the base entropy minus the conditional entropy;
   keep the feature with the largest gain.
'''
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0])-1  # number of features (last column is the class label)
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0; bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]  # all values taken by the i-th feature
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:  # iterate over the distinct feature values
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet)/float(len(dataSet))  # subset size / total size
            newEntropy += prob * calcShannonEnt(subDataSet)  # conditional entropy H(D|A)
        infoGain = baseEntropy - newEntropy  # information gain (mutual information)
        if (infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
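
# Worked check on the sample dataset (hand computation; H(D) ≈ 0.9710):
#   feature 0: value 1 -> {yes, yes, no} (H ≈ 0.9183), value 0 -> {no, no} (H = 0)
#              H(D|A0) = 3/5 * 0.9183 ≈ 0.5510, gain ≈ 0.4200
#   feature 1: value 1 -> {yes, yes, no, no} (H = 1), value 0 -> {no} (H = 0)
#              H(D|A1) = 4/5 * 1.0 = 0.8000, gain ≈ 0.1710
# so chooseBestFeatureToSplit picks feature 0 ('no surfacing').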

'''
===== Majority vote: decide the class of a leaf node =====
'''
import operator
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote]=0
        classCount[vote] += 1
    # Sort classes by count, descending, and return the most common one.
    # (In the original these two lines sat inside the loop, returning after
    # counting only the first vote.)
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
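
# Quick check: majorityCnt(['yes', 'no', 'no']) -> 'no'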

'''
======== Recursively build the decision tree ========
1. Pick the best feature, look up its name, and make it the key of a nested dict.
2. For each value of that feature, recurse on the matching subset.
3. Stop when only one class remains, or when all features have been used up;
   in the latter case return the majority class.

labels holds the feature names.
'''
def createTree(dataSet, labels):

    # Recursion stopping conditions
    classList = [example[-1] for example in dataSet]  # all class labels
    if classList.count(classList[0]) == len(classList):  # only one class left: stop splitting
        return classList[0]
    if len(dataSet[0]) == 1:  # all features used up, classes still mixed
        return majorityCnt(classList)


    bestFeat = chooseBestFeatureToSplit(dataSet)  # index of the best feature
    bestFeatLabel = labels[bestFeat]              # its name
    myTree = {bestFeatLabel: {}}                  # nested dict keyed by the feature name
    del(labels[bestFeat])                         # this feature is now consumed

    featValues = [example[bestFeat] for example in dataSet]  # all values taken by the best feature
    uniqueVals = set(featValues)


    for value in uniqueVals:
        subLabels = labels[:]  # copy, so recursion does not mutate the caller's list
        # Recursive call on the subset for this feature value
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet,bestFeat,value),subLabels)
    return myTree
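
# On the sample dataset this yields:
#   {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}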


'''
===== Draw tree nodes with text annotations =====
'''

import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    # Draw one annotated node at centerPt, with an arrow from parentPt.
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)


def createPlot():
    fig = plt.figure(1, facecolor='white')
    fig.clf()  # clear the figure
    createPlot.ax1 = plt.subplot(121, frameon=False)  # first panel of a 1x2 grid
    createPlot.ax2 = plt.subplot(122, frameon=False)
    plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
    plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()


'''
===== Building the complete tree plot =====
1. Recursively get the number of leaf nodes and the depth of the tree
   (both are needed to size the drawing).
'''


def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]  # root feature name
    secondDict = myTree[firstStr]      # maps feature value -> subtree or class label
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # a nested dict is a subtree: recurse
            numLeafs += getNumLeafs(secondDict[key])
        else: numLeafs += 1                           # anything else is a leaf
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])  # recurse into the subtree
        else: thisDepth = 1                                # a leaf ends the path
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth

def retrieveTree(i):
    # Canned trees for testing the measuring/plotting code without rebuilding.
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}},1:'no'}}}}
    ]
    return listOfTrees[i]
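
# Quick check: getNumLeafs(retrieveTree(0)) -> 3, getTreeDepth(retrieveTree(0)) -> 2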


'''
====== Plot the decision tree ======
'''
def plotMidText(cntrPt, parentPt, txtString):
    # Label the edge between parent and child at its midpoint.
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)

def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    # Center this subtree horizontally over the span of its leaves.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW,
              plotTree.yOff)
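    # The body below is a sketch of the remaining steps, following the book's
    # layout convention; it assumes the caller initializes plotTree.totalW,
    # plotTree.totalD, plotTree.xOff and plotTree.yOff before the first call.
    plotMidText(cntrPt, parentPt, nodeTxt)               # label the incoming edge
    plotNode(firstStr, cntrPt, parentPt, decisionNode)   # draw the decision node
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD  # step down one level
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))  # recurse into the subtree
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),
                     cntrPt, leafNode)                   # draw the leaf
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD  # step back up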


def main():
    dataSet, labels = createDataSet()
    shan = calcShannonEnt(dataSet)
    print(shan)

    # Relabeling one instance as 'maybe' adds a third class and raises the entropy
    # dataSet[0][-1] = 'maybe'
    shan = calcShannonEnt(dataSet)
    print(shan)

    # Test splitting the dataset
    set1 = splitDataSet(dataSet, 0, 1)
    set2 = splitDataSet(dataSet, 0, 0)
    print(dataSet)
    print(set1)
    print(set2)

    bestFeat = chooseBestFeatureToSplit(dataSet)
    print(bestFeat)

    myTree = createTree(dataSet, labels)
    print(myTree)
    # -> {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
    # (with the 'maybe' relabeling: {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'maybe'}}}})


if __name__ == '__main__':
    # main()

    # createPlot()
    myTree = retrieveTree(0)
    treeNum = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    print(treeNum)
    print(depth)

