Decision Trees and Their Visualization
The decision tree is one of the most widely used data mining algorithms; it reaches a conclusion through a series of binary (0/1) splits of the data.
- When splitting a data set we follow one guiding principle: make disordered data more ordered, which we measure with information gain. Information gain is the reduction in entropy, i.e. the reduction in the disorder of the data (a small worked example follows this list).
- The second code section shows how the tree is built as a nested dict.
- The greatest strength of a decision tree is how intuitive it is, yet the value our code prints is a dict, which is hard to read. We therefore visualize it with Python's Matplotlib package, using its annotation tool annotate.
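As a quick worked example (using the five-sample dataset defined in the code below, which has two 'yes' labels and three 'no' labels), the starting entropy is just under one bit:
from math import log
probs = [2/5, 3/5]  # P(yes), P(no)
print(-sum(p * log(p, 2) for p in probs))  # ≈ 0.9710
Any split that sorts the labels into purer groups lowers this number; the size of the drop is the information gain.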
1. Core code for constructing the decision tree
from math import log
# Compute the Shannon entropy of the given data set.
# Expected information (entropy): H = -(p(x1)*log2 p(x1) + p(x2)*log2 p(x2) + ...)
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:  # featVec is one row of the data set; its last element is the class label
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt -= prob * log(prob,2)
    return shannonEnt
def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels
# test
# myData,labels = createDataSet()
# print(myData)
# print(calcShannonEnt(myData))
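# Entropy grows as the labels get more mixed. One extra check (my own
# illustration): turning the first label into a third class raises the
# entropy from about 0.9710 to about 1.3710:
# myData[0][-1] = 'maybe'
# print(calcShannonEnt(myData))  # ≈ 1.3710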
# Split the data set (axis is the feature index, value is a feature value).
# Collect in reDataSet every row whose feature `axis` equals `value`,
# with the axis column removed.
def splitDataSet(dataSet, axis, value):  # axis is the dimension to split on
    reDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])  # note the difference between extend and append
            reDataSet.append(reducedFeatVec)
    return reDataSet
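# The extend/append distinction noted above, in miniature:
#   a = [1, 2]; a.append([3, 4])  ->  a == [1, 2, [3, 4]]  (nested as one element)
#   b = [1, 2]; b.extend([3, 4])  ->  b == [1, 2, 3, 4]    (elements spliced in)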
# test
# myData,labels = createDataSet()
# print(myData)
# print(splitDataSet(myData,0,1))
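# Expected output of the test above:
# [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
# [[1, 'yes'], [1, 'yes'], [0, 'no']]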
# Choose the best way to split the data set.
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0])-1  # number of features, excluding the class label
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0 ; bestFeature = -1  # initialize the best information gain and best feature index
    for i in range(numFeatures):
        # Take each row of dataSet as example and collect its i-th element.
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0  # entropy of the current split
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet,i,value)
            prob = len(subDataSet)/float(len(dataSet))
            # Accumulate the weighted entropy over every value of feature i,
            # giving the entropy after splitting on feature i.
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        # Information gain is the reduction in entropy, i.e. the reduction in disorder.
        if (infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
# test
# myData,labels = createDataSet()
# print(myData)
# print(chooseBestFeatureToSplit(myData))
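# Sanity check of the selection on the sample data (my own trace):
# base entropy ≈ 0.9710 (two 'yes', three 'no')
# split on feature 0: (2/5)*0 + (3/5)*0.9183 ≈ 0.5510 -> gain ≈ 0.4200
# split on feature 1: (1/5)*0 + (4/5)*1.0000 = 0.8000 -> gain ≈ 0.1710
# so the test above prints 0: 'no surfacing' is the best first split.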
2. Building the tree with a dict structure:
Testing the dict structure:
Tree = {'bestFeatLabel':{1:{}}}
Tree['bestFeatLabel'][1][0] = {'flippers':1}
print(Tree)
Output: {'bestFeatLabel': {1: {0: {'flippers': 1}}}}
Building the tree
import operator
# Return the class that occurs most often in classList
# (used when all features have been consumed but the classes are still mixed).
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
def createTree(dataSet,labels):
    classList = [example[-1] for example in dataSet]
    # If only one class remains, stop splitting and return that class.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # If only the class column is left in dataSet, return the majority class.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    # Pick the position bestFeat of the best feature to split on, and its label bestFeatLabel.
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    # Build the tree recursively as a nested dict.
    myTree = {bestFeatLabel:{}}
    # Remove the label that has just been added to the tree.
    del labels[bestFeat]
    # Collect the chosen feature's values, deduplicated and sorted.
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = sorted(set(featValues))
    for value in uniqueVals:
        subLabels = labels[:]
        # {bestFeatLabel: {value: subtree}}
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
    return myTree
# myData,labels = createDataSet()
# print(createTree(myData,labels))
For the dataset from createDataSet, the test prints:
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
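The section stops once the tree is built, but it helps to see how the nested dict is consumed. The sketch below (my addition, not part of the original listings) classifies a test vector by looking up the root feature, following the branch that matches the vector's value, and recursing until it reaches a leaf:
def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]     # feature name at this node
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)   # which column of testVec to test
    branch = secondDict[testVec[featIndex]]  # follow the matching branch
    if isinstance(branch, dict):             # internal node: keep descending
        return classify(branch, featLabels, testVec)
    return branch                            # leaf: the predicted class
# classify(myTree, ['no surfacing', 'flippers'], [1, 0])  ->  'no'
Note that createTree deletes entries from the labels list it is given, so pass classify a fresh copy of the label names.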
3. Visualization with Matplotlib
# Draw tree nodes with text annotations.
import matplotlib.pyplot as plt
from DecisionTree import *  # the code from sections 1 and 2, saved as DecisionTree.py
decisionNode = dict(boxstyle = 'sawtooth',fc='0.8')
leafNode = dict(boxstyle = 'round4',fc='0.8')
arrow_args = dict(arrowstyle = '<-')
# Draw one node: nodeTxt in a box at centerPt, with an arrow coming from parentPt.
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',\
        xytext = centerPt,textcoords = 'axes fraction',\
        va ='center',ha = 'center',bbox = nodeType,arrowprops = arrow_args)
# Throwaway first version of createPlot: draws one decision node and one leaf
# to check that plotNode works (it is redefined below to draw whole trees).
def createPlot():
    fig = plt.figure(1,facecolor='white')
    fig.clf()
    createPlot.ax1 = plt.subplot(111,frameon = False)
    plotNode('node',(0.5,0.1),(0.1,0.5),decisionNode)
    plotNode('leaf',(0.8,0.1),(0.3,0.8),leafNode)
    plt.show()
# createPlot()
# Count the leaf nodes of a tree.
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        # Test whether the child node is a dict (i.e. an internal node).
        if type(secondDict[key]).__name__=='dict':
            numLeafs += getNumLeafs(secondDict[key])
        else: numLeafs += 1
    return numLeafs
# Get the depth of a tree (the number of decision levels).
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else: thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth
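# For the sample tree {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
# these helpers return getNumLeafs -> 3 and getTreeDepth -> 2.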
# Write txtString (the branch value) midway along the edge from parent to child.
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)
def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)  # width of this subtree, in leaves
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    # Center this node above its leaves; xOff tracks the last leaf drawn.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW,\
        plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD  # descend one level
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            plotTree(secondDict[key], cntrPt, str(key))  # internal node: recurse
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD  # climb back up after this subtree
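# Hand trace of the layout for the sample tree (3 leaves, depth 2; my own
# annotation): totalW = 3, totalD = 2, xOff starts at -1/6, yOff at 1.0.
#   root 'no surfacing'        -> cntrPt (0.5,   1.0)
#   leaf 'no' (branch 0)       -> (0.167, 0.5)
#   node 'flippers' (branch 1) -> cntrPt (0.667, 0.5)
#   leaf 'no' (branch 0)       -> (0.5,   0.0)
#   leaf 'yes' (branch 1)      -> (0.833, 0.0)
# Leaves land 1/totalW apart on the x-axis; each level sits 1/totalD lower.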
# Real version of createPlot: sizes the canvas to the tree and draws it.
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks = [], yticks = [])
    createPlot.ax1 = plt.subplot(111, frameon = False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))   # total leaves sets the x scale
    plotTree.totalD = float(getTreeDepth(inTree))  # total depth sets the y scale
    plotTree.xOff = -0.5/plotTree.totalW ; plotTree.yOff = 1.0
    plotTree(inTree, (0.5,1.0), '')
    plt.show()
myData,labels = createDataSet()
myTree = createTree(myData,labels)  # note: createTree empties `labels` as it runs
createPlot(myTree)
For the sample dataset, the resulting tree is {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}, and createPlot renders it with a sawtooth box for each decision node, a rounded box for each leaf, and the branch values 0 and 1 written along the connecting arrows.