(机器学习实战)3、决策树的构建和预测(详细注释)

#python3.6

#训练样本和代码压缩包:https://pan.baidu.com/s/14iacLr08aucyOTcUBMEhLw

#coding:utf-8
#定义文本框和箭头格式
from numpy import *
from scipy import *
from math import log

"""
	shannon熵思路
	1、得到训练样本行数
	2、建立字典
	3、提取标签
	4、用字典统计标签
	5、运用shannon公式计算shannon熵
	"""
def calcShannonEnt(dataSet):
	numEntries=len(dataSet)
	labelCounts={}
	for featvec in dataSet:
		currentLabel=featvec[-1]
		if currentLabel not in labelCounts.keys():
			labelCounts[currentLabel]=0
		labelCounts[currentLabel]+=1
	shannonEnt=0.0
	for key in labelCounts:
		prob=float(labelCounts[key])/numEntries
		shannonEnt-=prob*log(prob,2)
	return shannonEnt

'''
	简单的鉴定数据集
	'''	
def creatDataSet():
	dataSet=[['1','1','yes'],['1','1','yes'],['1','0','no'],['0','1','no'],['0','1','no']]
	labels=['no surfacing','flippers']
	return dataSet ,labels

#测试一下结果
# ~ dataSet ,labels=creatDataSet()
# ~ print(calcShannonEnt(dataSet))


'''
	划分数据集思路:
	1、建立列表
	2、利用列表切割实现数据集划分
	3、利用append函数把划分后的reducedFeatVec单元加入列表
	4、返回列表
	'''
def splitDataSet(dataSet,axis,value):
	retDataSet = []
	for featVec in dataSet:
		# ~ print(featVec[axis])
		if(featVec[axis] == value):
			reducedfeatVec = featVec[:axis]#将返回的dataSet中将不会出现axis这一列
			reducedfeatVec.extend(featVec[axis+1:])
			retDataSet.append(reducedfeatVec)
	return retDataSet

#测试一下结果
# ~ dataSet,labels=creatDataSet()
# ~ print(dataSet)
# ~ print(splitDataSet(dataSet,0,'0'))


'''
	选择最好的数据划分方式思路(比较信息熵):
	1、得到特征数
	2、用列表推导得到某一特征的全部
	3、得到这一个特征中的特征值种类
	4、计算该特征值占该特征的比例
	5、用shannon熵计算信息熵
	6、找出信息熵最大的即为最好的划分特征
	'''
def chooseBestFeatureToSplit(dataSet):
	numFeatures=len(dataSet[0])-1              #减去标签后的列数
	baseEntropy=calcShannonEnt(dataSet)        
	bestInfoGain=0.0;bestFeature=-1
	for i in range(numFeatures):
		featlist=[example[i] for example in dataSet] #列表推导建立某个特征的列表
		uniqueVals=set(featlist)                     #对上一步生成的列表去掉重复的特征值
		newentropy=0.0
		for value in uniqueVals:
			subDataSet=splitDataSet(dataSet,i,value)
			prob=len(subDataSet)/float(len(dataSet)) #某特征值占该特征的比重
			newentropy+=prob*calcShannonEnt(subDataSet)
		infoGain=baseEntropy-newentropy              #信息熵的计算方法第一项 - 第二项
		if(infoGain>bestInfoGain):
			bestInfoGain=infoGain
			bestFeature=i
	return bestFeature
		
#测试一下效果
# ~ dataSet,labels=creatDataSet()		
# ~ print(chooseBestFeatureToSplit(dataSet))

'''
	构建决策树思路:
	A:多数表决的方法:
				  1、国建字典统计某特征在该特征中的个数 
				  2、按个数进行降序排序 
				  3、挑选最大的
	B:创建数的函数代码:
					1、得到标签列表
					2、两个停止条件
					3、选出最好的特征做成根节点
					4、列表推导得到某特征列表
					5、递归生成树
	'''
def majorityCnt(classList): #A
	classCount={}
	for vote in classList:
		if vote not in classCount.keys():classCount[vote]=0
		classCount[vote]+=1
	sortedClassCount=sorted(classCount.items(),key=operator.itemgetter(1),reverse=Ture)
	return sortedClassCount[0][0]

def createTree(dataSet,labels): #B
	classList=[example[-1] for example in dataSet]
	if classList.count(classList[0])==len(classList): #统计classList[0]的个数看是否与classList的长度相等,注意和字典统计数字的区别,字典可以用for统计全部特征数目,count统计只能使单个特征,
		return classList[0]
	if len(dataSet[0])==1:
		return majorityCnt(classList)
	bestFeat=chooseBestFeatureToSplit(dataSet)
	bestFeatLabel=labels[bestFeat]
	myTree={bestFeatLabel:{}}
	subLabels=labels[:]
	del(subLabels[bestFeat])
	featValues=[example[bestFeat] for example in dataSet]
	uniqueVals=set(featValues)
	for value in uniqueVals:
		myTree[bestFeatLabel][value]=createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
	return myTree
#测试一下效果
# ~ dataSet,labels=creatDataSet()
# ~ print(createTree(dataSet,labels))







import matplotlib.pyplot as plt
#下面的dict相当字典{"boxstyle":"sawtooth","fc":"0.8"}
decisionNode=dict(boxstyle="sawtooth",fc="0.8")     
leafNode=dict(boxstyle="round4",fc="0.8")
arrow_args=dict(arrowstyle="<-")

#注释nodeTxt ,centerPt注解框位置,parentPt起点位置
#createPlot.ax1.annotate看不懂可以参考https://www.cnblogs.com/DaleSong/p/5348489.html
def plotNode(nodeTxt, centerPt, parentPt, nodeType):  
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',  
    xytext=centerPt, textcoords='axes fraction',  
    va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )  

def createPlot():
	fig=plt.figure(1,facecolor='white')
	fig.clf()
	createPlot.ax1=plt.subplot(111,frameon=False)
	plotNode('a desion node',(0.5,0.1),(0.1,0.5),decisionNode)
	plotNode('a leaf node',(0.8,0.1),(0.3,0.8),leafNode)
	plt.show()
	


'''
	获取叶子结点的数目:
	1、判断数据类型是不是字典。
	2、递归
	'''
def getNumLeafs(myTree):
	numLeafs=0
	firstStr=list(myTree.keys())[0]       
	secondDict=myTree[firstStr]
	for key in secondDict.keys():
		if type(secondDict[key])==dict:      #python2与python3有区别参考https://blog.youkuaiyun.com/qq_33363973/article/details/77878122
			numLeafs+=getNumLeafs(secondDict[key])
		else: numLeafs+=1
	return numLeafs

'''
	获取树的深度:
	1、判断数据类型是不是字典
	2、递归
		'''
def getTreeDepth(myTree):
	maxDepth=0
	firstStr=list(myTree.keys())[0]
	scondDict=myTree[firstStr]
	for key in list(scondDict.keys()):
		if type(scondDict[key])==dict:       #python2与python3有区别参考https://blog.youkuaiyun.com/qq_33363973/article/details/77878122
			thisDepth=1+getTreeDepth(scondDict[key])
		else: thisDepth=1
		if thisDepth>maxDepth:maxDepth=thisDepth
	return  maxDepth
#测试一下结果
# ~ dataSet,labels=creatDataSet()
# ~ myTree=createTree(dataSet,labels)
# ~ print(getNumLeafs(myTree))
# ~ print(getTreeDepth(myTree))


'''构建已知树:
	这个是测试用的
	'''
def retrieveTree(i):
    listOfTrees = [{'no surfacing':{0:'no',1:{'flippers':{0:'no',1:'yes'}}}},{'no surfacing':{0:'no',1:{'flippers':{0:{'head':{0:'no',1:'yes'}},1:'no'}}}}]
    return listOfTrees[i]
    
    
'''找到父节点和字节点之间的中间位置放置0或1
	txtString即为string类型,看plotTree函数中的用法可知
	'''
def plotMidText(cntrPt,parentPt,txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid,yMid,txtString)


'''
	绘制树节点思路:
	1、得到叶子结点和深度
	2、得到键值列表中第一个键值firstStr
	
	'''
def plotTree(myTree,parentPt,nodeTxt):#计算所有叶节点的位置,并绘制叶节点以及0和1的位置
    numLeafs = getNumLeafs(myTree)#首先计算宽和高
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
   #cntrPt = (-0.5*固定叶子数 + (1.0 +该树叶子数))/2.0/固定叶子数,1.0)//不理解2.0的含义可以把它改成1.0.和3.0试一试。
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)#计算字节点的位置
    plotMidText(cntrPt,parentPt,nodeTxt)#绘制0或者1
    plotNode(firstStr,cntrPt,parentPt,decisionNode)#绘制最开始的父节点
    secondDict = myTree[firstStr]
   #plotTree.yOff = plotTree.yOff - 1.0/深度      //这个表示纵坐标,其实就是从根节点从下往下画
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD#因为父节点在最上面,则需要往下减去偏移量
    for key in list(secondDict.keys()):
        if type(secondDict[key])== dict:
            plotTree(secondDict[key],cntrPt,str(key))#如果是字典则递归调用
        else:#如果不是字典,则计算x偏移,就是叶节点的位置,绘制叶节点以及0或者1
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
            plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD#把所有的叶节点都计算完之后,将把y偏移加回来,使最后的y在父节点上

#后面的代码参考https://www.cnblogs.com/zy230530/p/6813250.html
#后面的代码参考https://blog.youkuaiyun.com/chuhang_zhqr/article/details/50731489
def createPlot(inTree):
    fig = plt.figure(1,facecolor='white')
    fig.clf()
    axprops = dict(xticks=[],yticks=[])
    createPlot.ax1 = plt.subplot(111,frameon=False,**axprops)
    plotTree.totalW = float(getNumLeafs(inTree))#这都是全局变量
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0
    plotTree(inTree,(0.5,1.0),'')#绘制节点树形图
    plt.show()
    
'''进行决策'''  
def classify(inputTree,featLabels,testVec):#根据已有的决策树,对给出的数据进行分类
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)#这里是将标签字符串转换成索引数字
    for key in secondDict.keys():
        if testVec[featIndex] == key:#如果key值等于给定的标签时
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key],featLabels,testVec)#递归调用分类
            else: classLabel = secondDict[key]#此数据的分类结果
    return classLabel

'''由于构建决策树是很耗时的,但用创建好的决策树就可以很快解决分类问题,最好每次次执行分类时调用已构造好的决策树,pickle可以存储对象,也可以读出对象,字典对象也不例外,k近邻不能持久分类,必须每次都计算'''
def storeTree(inputTree,filename):
	import pickle
	fw  = open(filename,'wb+') #'wb+'因为pickle存储方式默认是二进制方式
	pickle.dump(inputTree,fw)
	fw.close()

def grabTree(filename):
	import pickle
	fr = open(filename, 'rb+')
	return pickle.load(fr)
    
'''
	测试一下
	'''
print('用testVec测试一下功能')
testVec=['1','0']
dataSet,featLabels=creatDataSet()
inputTree=createTree(dataSet,featLabels)
print(inputTree)
createPlot(inputTree)
print(classify(inputTree,featLabels,testVec))
print('=======================================')
print('测试一下存储功能')
storeTree(inputTree,'saveTree.txt')
print((grabTree('saveTree.txt')))
createPlot(grabTree('saveTree.txt'))
print(classify(grabTree('saveTree.txt'),featLabels,testVec))
print('=======================================')

'''
	读取文件建立决策树
	'''
def predictLensesType(filename):
    #打开文本数据
    fr=open(filename)
    #将文本数据的每一个数据行按照tab键分割,并依次存入lenses
    lenses=[inst.strip().split('\t') for inst in fr.readlines()]
    #创建并存入特征标签列表
    lensesLabels=['age','prescript','astigmatic','tearRate']
    #根据继续文件得到的数据集和特征标签列表创建决策树
    lensesTree=createTree(lenses,lensesLabels)
    return lensesTree
    
print('=======================================')
print('读取文件建立决策树')
lensesTree=predictLensesType('lenses.txt')
createPlot(lensesTree)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值