Decision Trees and Visualization

A decision tree is one of the most widely used data mining algorithms; it reaches its conclusion through a sequence of binary (0-1) splits of the data.

  1. When splitting a dataset, we follow one guiding principle: make the disordered data more ordered, which is achieved through information gain. Information gain is the decrease in entropy, i.e. the decrease in the disorder of the data (a worked example follows this list).
  2. The second code section shows how to build the decision tree as nested dictionaries.
  3. A decision tree's greatest strength is that it is intuitive, but the dictionary printed by our code is hard to read, so we use the Matplotlib plotting library and its annotation tool (annotate) to visualize the tree.
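
A quick worked check of the entropy formula (using the five-sample dataset defined in createDataSet below): with 2 'yes' and 3 'no' labels, H = -(2/5)*log2(2/5) - (3/5)*log2(3/5) ≈ 0.971. A minimal sketch:

from math import log
# entropy of a 2-'yes' / 3-'no' label distribution
p_yes, p_no = 2/5, 3/5
H = -(p_yes * log(p_yes, 2) + p_no * log(p_no, 2))
print(H)   # ≈ 0.9710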

1. Core code for constructing the decision tree

from math import log

# Compute the Shannon entropy of a given dataset
# Expected information (entropy): H = -(p(x1)*log2 p(x1) + p(x2)*log2 p(x2) + ...)
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:    # featVec is one row of the dataset; its last element is the class label
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt -= prob * log(prob,2)
    return shannonEnt

def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels

# test
# myData,labels = createDataSet()
# print(myData)
# print(calcShannonEnt(myData))
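# With the sample data above (2 'yes', 3 'no'), calcShannonEnt(myData)
# prints approximately 0.9710.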

# Split the dataset (axis is the feature index, value is a value of that feature)
# Collect into reDataSet every row whose feature `axis` equals `value`, with column axis removed
def splitDataSet(dataSet, axis, value):   # axis is the feature (column) index
    reDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])  # note the difference between extend and append
            reDataSet.append(reducedFeatVec)
    return reDataSet

# test
# myData,labels = createDataSet()
# print(myData)
# print(splitDataSet(myData,0,1))
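
# A quick illustration of the extend/append difference noted above:
#   a = [1, 2]; a.extend([3, 4])  ->  a == [1, 2, 3, 4]
#   b = [1, 2]; b.append([3, 4])  ->  b == [1, 2, [3, 4]]
# splitDataSet uses extend so each reduced row stays a flat list.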

# Choose the best feature to split the dataset on
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0])-1   # number of features, excluding the class label
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0; bestFeature = -1   # initialize the best information gain and the best feature
    for i in range(numFeatures):
        # collect the i-th feature value of every row (example) in dataSet
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0    # entropy after splitting on feature i
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet)/float(len(dataSet))
            # accumulate the weighted entropy of each subset; the total is the entropy of splitting on feature i
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        # information gain is the decrease in entropy, i.e. the decrease in data disorder
        if (infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

# test
# myData,labels = createDataSet()
# print(myData)
# print(chooseBestFeatureToSplit(myData))
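# For the sample data this prints 0: splitting on feature 0 ('no surfacing')
# yields information gain ≈ 0.420, versus ≈ 0.171 for feature 1 ('flippers').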

2. Building the tree with a dictionary structure

Testing the dictionary structure

Tree = {'bestFeatLabel': {1: {}}}
Tree['bestFeatLabel'][1][0] = {'flippers': 1}
print(Tree)
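# Nested indexing walks down the structure, e.g. Tree['bestFeatLabel'][1][0]['flippers'] -> 1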

Output:

{'bestFeatLabel': {1: {0: {'flippers': 1}}}}

Creating the tree

import operator
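
# Majority-vote helper used by createTree when the features are exhausted;
# this is the standard implementation (as in Machine Learning in Action):
# it returns the class label that occurs most often in classList.
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
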
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    # if every example belongs to the same class, stop splitting and return that class
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # if only the label column is left, return the majority class in classList
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    # pick the index bestFeat of the best feature to split on, and its label bestFeatLabel
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    # build the tree recursively as nested dictionaries
    myTree = {bestFeatLabel: {}}
    # remove the label that has just been added to the tree
    del labels[bestFeat]
    # collect the unique values of the chosen feature
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    # grow one subtree per feature value
    for value in uniqueVals:
        subLabels = labels[:]   # copy so recursive calls don't mutate this level's list
        # {bestFeatLabel: {value: subtree}}
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

# myData,labels = createDataSet()
# print(createTree(myData,labels))

Running the commented-out test above on the sample dataset prints:

{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

3. Visualization with Matplotlib

# Use text annotations to draw tree nodes
import matplotlib.pyplot as plt
from DecisionTree import *   # the tree-building code above, assumed to be saved as DecisionTree.py
decisionNode = dict(boxstyle='sawtooth', fc='0.8')
leafNode = dict(boxstyle='round4', fc='0.8')
arrow_args = dict(arrowstyle='<-')

# draw nodeTxt in a box of style nodeType at centerPt, with an arrow drawn from parentPt
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',\
                            xytext = centerPt,textcoords = 'axes fraction',\
                            va ='center',ha = 'center',bbox = nodeType,arrowprops = arrow_args)

# smoke test for plotNode: draws one decision node and one leaf
# (this createPlot is superseded by the full version defined below)
def createPlot():
    fig = plt.figure(1,facecolor='white')
    fig.clf()
    createPlot.ax1 = plt.subplot(111,frameon = False)
    plotNode('node',(0.5,0.1),(0.1,0.5),decisionNode)
    plotNode('leaf',(0.8,0.1),(0.3,0.8),leafNode)
    plt.show()

# createPlot()


# Count the leaf nodes of a tree
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        # a dict-valued child is an internal node; anything else is a leaf
        if type(secondDict[key]).__name__=='dict':
            numLeafs +=getNumLeafs(secondDict[key])
        else: numLeafs +=1
    return numLeafs
# Compute the depth of a tree
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else: thisDepth=1
        if thisDepth>maxDepth:maxDepth=thisDepth
    return maxDepth
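
# Sanity check on the tree grown from the sample dataset:
# getNumLeafs(myTree) -> 3, getTreeDepth(myTree) -> 2.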

# label the midpoint of the edge between a child (cntrPt) and its parent (parentPt)
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 +cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 +cntrPt[1]
    createPlot.ax1.text( xMid , yMid , txtString)

def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    # center this node above its leaves; x/y offsets are kept as attributes on plotTree
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW,\
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # descend one level before drawing the children
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            # internal node: recurse into the subtree
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            # leaf: advance the x offset, draw the leaf and its edge label
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # move back up one level after the children are drawn
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    # normalize the layout by total leaf count (width) and tree depth (height)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

myData,labels = createDataSet()
myTree = createTree(myData,labels)
createPlot(myTree)
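
# Note: createTree deletes entries from the labels list passed to it,
# so call createDataSet() again for a fresh labels list before reusing it.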

Dataset: the five-sample dataset returned by createDataSet.
Decision tree: {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
Visualization result: (figure: the rendered tree, with 'no surfacing' at the root, a 'no' leaf on the 0 branch, and a 'flippers' subtree on the 1 branch)
