Python中使用Matplotlib注解绘制决策树图形

Python中使用Matplotlib注解绘制决策树图形

import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8") #这行创建了一个名为decisionNode的字典,用于定义决策节点的样式。boxstyle="sawtooth"表示节点的边框样式为锯齿形,fc="0.8"表示填充颜色的透明度(这里是一个相对值,0 为完全透明,1 为完全不透明)。
leafNode = dict(boxstyle="round4", fc="0.8")       #这里创建了leafNode字典,用于定义叶子节点的样式。boxstyle="round4"表示叶子节点的边框是一种圆形样式,fc="0.8"同样是设置填充颜色的透明度。
arrow_args = dict(arrowstyle="<-") #全局变量定义箭头类型,详见:https://matplotlib.org/stable/api/_as_gen/matplotlib.patches.FancyArrowPatch.html#matplotlib.patches.FancyArrowPatch

def getNumLeafs(myTree):
    """
    计算树的叶子节点数
    :param myTree:传入的参数是一个字典类型的树结构
    """
    numLeafs = 0
    keysList=list( myTree.keys())   #获取决策树最顶层的Keys,返回的是一个List,但需要显示地转换为List后才能用下标索引。如果是根节点,那么就是决策树选择分类的第一个特征名(根节点只有一个Key)
    firstStr = keysList[0]          #取决策树根节点的特征名
    #firstStr = myTree.keys()[0]    Does not work,get Error
    secondDict = myTree[firstStr]   #用根节点的key去索引字典的value,返回的是一个字典,可能包含多个Key,既多个分支
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
            numLeafs += getNumLeafs(secondDict[key])    #如果不是叶子节点,则递归进入它的分支
        else:   numLeafs +=1                #如果当前分支下是叶子节点,则叶子数+1
    return numLeafs

def getTreeDepth(myTree):
    """
    计算树的深度
    :param myTree:传入的参数是一个字典类型的树结构
    :return:返回各分支的深度最大值
    """
    maxDepth = 0
    keysList=list( myTree.keys())    #获取决策树最顶层的Keys,返回的是一个List,但需要显示地转换为List后才能用下标索引。如果是根节点,那么就是决策树选择分类的第一个特征名(根节点只有一个Key)
    firstStr = keysList[0]           #取决策树根节点的特征名
    secondDict = myTree[firstStr]    #用根节点的key去索引字典的value,返回的是一个字典,可能包含多个Key,既多个分支
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
            thisDepth = 1 + getTreeDepth(secondDict[key])   #如果是子树则递归调用,每进入下一层子树,深度+1
        else:   thisDepth = 1                               #如果是叶子则深度+1
        if thisDepth > maxDepth: maxDepth = thisDepth       #每遍历完一个分支后与上个分支的深度做对比,取较大者
    return maxDepth


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """
    matplotlib的手册见网址:https://matplotlib.org/stable/api/pyplot_summary.html
    Matplotlib提供了一个注解工具annotate()具体参数含义:https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.annotate.html
    利用添加文本注释的方式来绘制节点和箭头(箭头作为分支)
    :param nodeTxt:要生成的节点的文本
    :param centerPt:文本的坐标
    :param parentPt:父节点的坐标
    :param nodeType:节点类型
    """
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',
             xytext=centerPt, textcoords='axes fraction',
             va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )
    """
    xycoords='axes fraction' 坐标从左下角以小数的形式体现坐标位置,如果xycoords='axes pixels'则以像素的形式体现坐标位置,则可能图形缩放时会出现问题
    va="center", ha="center", bbox=nodeType, arrowprops=arrow_args 
    :va="center"和ha="center"分别设置垂直和水平对齐方式为居中,bbox=nodeType设置节点的边框样式,根据传入的nodeType,可以是决策节点或叶子节点的样式
    arrowprops=arrow_args 使用全局变量,定义箭头类型为<-
    """



def plotMidText(cntrPt, parentPt, txtString):
    """
    def plotMidText(cntrPt, parentPt, txtString)::定义了一个名为plotMidText的函数,用于在两个点(cntrPt和parentPt)之间的中间位置添加文本。
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]:计算文本在x轴方向的中间位置坐标。
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]:计算文本在y轴方向的中间位置坐标。
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30):在计算得到的中间位置(xMid, yMid)添加文本txtString,va="center"和ha="center"设置垂直和水平对齐方式为居中,rotation=30设置文本旋转 30 度。
    """
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on
    numLeafs = getNumLeafs(myTree)  #this determines the x width of this tree
    depth = getTreeDepth(myTree)    #获取树的深度
    keysList=list( myTree.keys())
    firstStr = keysList[0]
    #firstStr = myTree.keys()[0]     #the text label for this node should be this 
    """
    上次叶子节点x坐标:plotTree.xOff
    按叶子节点的总数等分划分x轴,则每个叶子节点间的距离为:disLef=1.0/numLeafs
    则当前decisionNode节点x坐标绝对位置:cntx=plotTree.xOff + disLef + (numLeafs-1)*disLef/2

    (numLeafs-1)*disLef/2 当前decisionNode节点节点总是位于它的第一个叶子和最后一个叶子的x坐标中间,假如有numLeafs个叶子,则最后
    一个叶子与第一个叶子之间有numLeafs-1段间隔,而每段间隔的距离是disLef,则最后一个叶子与第一个叶子间的距离是(numLeafs-1)*disLef,
    这时我们可以计算出当前decisionNode节点相对于它的一个叶子节点的相对距离为(numLeafs-1)*disLef/2,因为decisionNode节点总是位于
    第一个节点与最后一个节点中间。
    又因为当前decisionNode节点的第一个叶子节点的x坐标与上个叶子节点x坐标为disLef,所以当前当前decisionNode节点第一个叶子节点的
    x坐标绝对位置=plotTree.xOff + disLef ,于是当前decisionNode节点x坐标绝对位置cntx=cntx=plotTree.xOff + disLef + (numLeafs-1)*disLef/2
    将disLef换为1.0/numLeafs 则有cntx=cntx=plotTree.xOff + 1.0/numLeafs + (numLeafs-1)/2/numLeafs
    有cntx=cntx=plotTree.xOff +(2.0/2 + (numLeafs-1)/2)/numLeafs=plotTree.xOff +(1+numLeafs)/2/numLeafs
    """
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)  #在当前节点和父节点之间的中间位置添加文本nodeTxt
    plotNode(firstStr, cntrPt, parentPt, decisionNode)#:绘制当前节点,使用decisionNode样式,节点文本为第一个分类属性,位置根据cntrPt和parentPt确定。
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD  #进入下一层则下一层的y坐标要递减,减少的数值为总高度/树的深度
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes   
            plotTree(secondDict[key],cntrPt,str(key))        # 递归调用plotTree函数来绘制子树,str(key)是传递给子树的节点文本。
        else:   #it's a leaf node print the leaf node
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW #叶子节点每次的坐标增加都是固定值
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode) #如果是叶子节点那么只剩一个Key作为类别标签
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))   #在叶子节点和当前节点之间添加文本。
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD     #绘制完当前层的所有子节点后,调整y轴偏移量,恢复到上一层的位置。
#if you do get a dictonary you know it's a tree, and the first element will be another dict

def createPlot(inTree):
    """
    定义createPlot函数,用于创建绘制决策树的图形环境并绘制决策树。
    matplot手册:https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.figure.html

    """
    fig = plt.figure(1, facecolor='white')  #创建一个新的图形,编号为 1,背景颜色为白色。
    fig.clf()
    axprops = dict(xticks=[], yticks=[]) #创建一个字典axprops,用于设置坐标轴的属性,这里将x和y轴的刻度设置为空列表,表示不显示刻度。
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    #不显示坐标轴和边框的图形
    #createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses 
    plotTree.totalW = float(getNumLeafs(inTree))        #计算决策树的叶子数
    plotTree.totalD = float(getTreeDepth(inTree))       #计算决策树的深度
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0; #将整个图形的x坐标整体向左偏移两个叶子间距离的1/2,第一个节点的y坐标为最高点1.0
    plotTree(inTree, (0.5,1.0), '') #指定根节点的父节点与根节点重合,因为最终在plotTree()计算出的第一个根节点位置也会是(0.5,1.0)
    plt.show()   #显示所绘制的决策树图形

#def createPlot():
#    fig = plt.figure(1, facecolor='white')
#    fig.clf()
#    createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses 
#    plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
#    plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
#    plt.show()

def retrieveTree(i):
    """
    这是一个用于测试的数列表,包含2个树
    """
    listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                  ]
    return listOfTrees[i]

#createPlot(thisTree)


if __name__ =='__main__':
    testTree=retrieveTree(1)
    lefs=getNumLeafs(testTree)
    dep=getTreeDepth(testTree)
    print('Leaves:%s  Depth:%s'%(lefs,dep))
    createPlot(testTree)


在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值