# coding:utf-8
# python3.6
# Training samples and code archive: https://pan.baidu.com/s/14iacLr08aucyOTcUBMEhLw
import operator
from math import log
"""
shannon熵思路
1、得到训练样本行数
2、建立字典
3、提取标签
4、用字典统计标签
5、运用shannon公式计算shannon熵
"""
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featvec in dataSet:
        currentLabel = featvec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
'''
A simple sample dataset for testing
'''
def creatDataSet():
    dataSet = [['1', '1', 'yes'], ['1', '1', 'yes'], ['1', '0', 'no'], ['0', '1', 'no'], ['0', '1', 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels
# Quick test:
# ~ dataSet, labels = creatDataSet()
# ~ print(calcShannonEnt(dataSet))
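# A hand-worked sanity check (my own addition, not from the original): the sample
# set has 2 'yes' and 3 'no' labels, so
#   H = -(2/5)*log2(2/5) - (3/5)*log2(3/5) ≈ 0.9710
# Uncomment to verify against the function:
# ~ dataSet, labels = creatDataSet()
# ~ assert abs(calcShannonEnt(dataSet) - 0.9710) < 1e-3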
'''
Dataset-splitting approach:
1. Build a list
2. Split the dataset with list slicing
3. Append each reduced feature vector (reducedfeatVec) to the list
4. Return the list
'''
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        # ~ print(featVec[axis])
        if featVec[axis] == value:
            reducedfeatVec = featVec[:axis]  # the returned dataSet will not contain column axis
            reducedfeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedfeatVec)
    return retDataSet
# Quick test:
# ~ dataSet, labels = creatDataSet()
# ~ print(dataSet)
# ~ print(splitDataSet(dataSet, 0, '0'))
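# Expected result, worked by hand (my own addition): the rows with feature 0 == '0'
# are the last two, and column 0 is dropped from each:
# ~ assert splitDataSet(dataSet, 0, '0') == [['1', 'no'], ['1', 'no']]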
'''
Choosing the best split (by comparing information gain):
1. Get the number of features
2. Collect all values of one feature with a list comprehension
3. Get the distinct values of that feature
4. Compute the proportion of each value within the feature
5. Compute the conditional entropy with the Shannon entropy function
6. The feature with the largest information gain is the best split feature
'''
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1  # number of columns minus the label column
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featlist = [example[i] for example in dataSet]  # list comprehension collecting one feature's values
        uniqueVals = set(featlist)  # drop duplicate values from the list above
        newentropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))  # proportion of this value within the feature
            newentropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newentropy  # information gain = base entropy - conditional entropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
# Quick test:
# ~ dataSet, labels = creatDataSet()
# ~ print(chooseBestFeatureToSplit(dataSet))
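# Expected result (my own check): splitting on feature 0 ('no surfacing') yields
# the larger information gain on this data, so the call should return 0:
# ~ assert chooseBestFeatureToSplit(dataSet) == 0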
'''
Building the decision tree:
A. Majority vote:
   1. Build a dictionary counting each class label
   2. Sort the counts in descending order
   3. Pick the largest
B. Tree-creation function:
   1. Get the list of class labels
   2. Check the two stopping conditions
   3. Choose the best feature as the root node
   4. Collect that feature's values with a list comprehension
   5. Build the tree recursively
'''
def majorityCnt(classList):  # A
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
def createTree(dataSet, labels):  # B
    classList = [example[-1] for example in dataSet]
    # Stop if every example has the same class. Note the difference from the
    # dictionary counting above: count() tallies a single value, while a dict
    # loop can tally every value in one pass.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    if len(dataSet[0]) == 1:  # no features left to split on: majority vote
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    subLabels = labels[:]  # copy so the caller's label list is not mutated
    del(subLabels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
# Quick test:
# ~ dataSet, labels = creatDataSet()
# ~ print(createTree(dataSet, labels))
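# Expected tree (my own check, derived by tracing the code; the feature values
# are strings here, so the branch keys are '0'/'1'):
# ~ assert createTree(dataSet, labels) == {'no surfacing': {'0': 'no', '1': {'flippers': {'0': 'no', '1': 'yes'}}}}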
import matplotlib.pyplot as plt
# Define the text-box and arrow formats; dict(...) below is equivalent to the
# literal {"boxstyle": "sawtooth", "fc": "0.8"}
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
# nodeTxt is the annotation text, centerPt the annotation-box position, parentPt the arrow's start point
# If createPlot.ax1.annotate is unclear, see https://www.cnblogs.com/DaleSong/p/5348489.html
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
def createPlot():  # demo version; superseded by createPlot(inTree) further below
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False)
    plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
    plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()
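# Quick demo of the node-drawing primitives (my own addition; uncomment to run.
# At this point in the file the no-argument version above is the one in scope):
# ~ createPlot()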
'''
Counting the leaf nodes:
1. Check whether each value is a dict
2. Recurse
'''
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]) == dict:  # Python 2 and 3 differ here, see https://blog.youkuaiyun.com/qq_33363973/article/details/77878122
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs
'''
Computing the tree depth:
1. Check whether each value is a dict
2. Recurse
'''
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in list(secondDict.keys()):
        if type(secondDict[key]) == dict:  # same Python 2/3 caveat as in getNumLeafs
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
# Quick test:
# ~ dataSet, labels = creatDataSet()
# ~ myTree = createTree(dataSet, labels)
# ~ print(getNumLeafs(myTree))
# ~ print(getTreeDepth(myTree))
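# Expected results (my own check): the sample tree has 3 leaves and depth 2:
# ~ assert getNumLeafs(myTree) == 3 and getTreeDepth(myTree) == 2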
'''
Prebuilt trees for testing the plotting code.
'''
def retrieveTree(i):
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]
    return listOfTrees[i]
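# Usage sketch (my own addition): the prebuilt trees work with the counting
# functions above without rebuilding a tree from data:
# ~ myTree = retrieveTree(0)
# ~ print(getNumLeafs(myTree), getTreeDepth(myTree))  # expected: 3 2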
'''
Place the edge label (0 or 1) midway between the parent and child nodes.
txtString is a plain string; see how plotTree calls this function.
'''
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)
'''
Plotting the tree:
1. Get the number of leaves and the depth
2. Take the first key of the dict, firstStr
'''
def plotTree(myTree, parentPt, nodeTxt):  # compute all node positions and draw the nodes plus the 0/1 edge labels
    numLeafs = getNumLeafs(myTree)  # first compute the width and depth
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    # cntrPt x = xOff + (1.0 + numLeafs)/2.0/totalW; if the 2.0 seems opaque, try 1.0 or 3.0 and compare the plots
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)  # child-node position
    plotMidText(cntrPt, parentPt, nodeTxt)  # draw the 0 or 1 edge label
    plotNode(firstStr, cntrPt, parentPt, decisionNode)  # draw the decision node at the top of this subtree
    secondDict = myTree[firstStr]
    # yOff is the y coordinate; drawing proceeds from the root downwards
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD  # the parent sits at the top, so offset downwards
    for key in list(secondDict.keys()):
        if type(secondDict[key]) == dict:
            plotTree(secondDict[key], cntrPt, str(key))  # a subtree: recurse
        else:  # a leaf: advance the x offset, then draw the leaf node and its 0/1 edge label
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD  # after all children, restore y to the parent's level
# For the code below, see https://www.cnblogs.com/zy230530/p/6813250.html
# and https://blog.youkuaiyun.com/chuhang_zhqr/article/details/50731489
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))  # function attributes used as globals
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')  # draw the whole tree
    plt.show()
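# Plotting demo (my own addition), using a prebuilt tree so no data file is needed:
# ~ createPlot(retrieveTree(0))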
'''Classifying with the tree'''
def classify(inputTree, featLabels, testVec):  # classify the given vector with an existing decision tree
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)  # translate the label string into a feature index
    for key in secondDict.keys():
        if testVec[featIndex] == key:  # follow the branch whose key matches this feature value
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)  # a subtree: recurse
            else:
                classLabel = secondDict[key]  # reached a leaf: this is the classification
    return classLabel
'''
Building a decision tree is expensive, but classifying with a finished tree is fast,
so it is best to reuse an already-constructed tree for each classification. pickle
can serialize and deserialize objects, dictionaries included. k-nearest neighbors,
by contrast, cannot persist its classifier and must recompute every time.
'''
def storeTree(inputTree, filename):
    import pickle
    fw = open(filename, 'wb')  # binary mode, since pickle writes bytes by default
    pickle.dump(inputTree, fw)
    fw.close()
def grabTree(filename):
    import pickle
    fr = open(filename, 'rb')
    return pickle.load(fr)
'''
Quick end-to-end test
'''
print('Testing classification with testVec')
testVec = ['1', '0']
dataSet, featLabels = creatDataSet()
inputTree = createTree(dataSet, featLabels)
print(inputTree)
createPlot(inputTree)
print(classify(inputTree, featLabels, testVec))
print('=======================================')
print('Testing tree persistence')
storeTree(inputTree, 'saveTree.txt')
print(grabTree('saveTree.txt'))
createPlot(grabTree('saveTree.txt'))
print(classify(grabTree('saveTree.txt'), featLabels, testVec))
print('=======================================')
'''
Build a decision tree from a data file
'''
def predictLensesType(filename):
    # Open the text file
    fr = open(filename)
    # Split each line on tabs and collect the rows into lenses
    lenses = [inst.strip().split('\t') for inst in fr.readlines()]
    # Create the feature-label list
    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
    # Build the decision tree from the file's dataset and the label list
    lensesTree = createTree(lenses, lensesLabels)
    return lensesTree
print('=======================================')
print('Building a decision tree from a file')
lensesTree = predictLensesType('lenses.txt')
createPlot(lensesTree)