from math import log import operator import treePlotter #计算给定数据集的香农熵的函数 def calcShannonEnt(dataSet): # 求list的长度,表示计算参与训练的数据量 numEntries=len(dataSet) labelCounts={} # 计算分类标签label出现的次数 for featVec in dataSet: # 将当前实例的标签存储,即每一行数据的最后一个数据代表的是标签 currentLabel=featVec[-1] # 为所有可能的分类创建字典,如果当前的键值不存在,则扩展字典并将当前键值加入字典。每个键值都记录了当前类别出现的次数。 if currentLabel not in labelCounts.keys(): labelCounts[currentLabel]=0 labelCounts[currentLabel]+=1 shannonEnt=0.0 # 对于 label 标签的占比,求出 label 标签的香农熵 for key in labelCounts: # 使用所有类标签的发生频率计算类别出现的概率。 prob=float(labelCounts[key])/numEntries # 计算香农熵,以 2 为底求对数 shannonEnt-=prob*log(prob,2) return shannonEnt #按照给定特征划分数据集 def splitDataSet(dataSet,axis,value): retDataSet=[] for featVec in dataSet: # axis列为value的数据集【该数据集需要排除index列】 # 判断axis列的值是否为value if featVec[axis]==value: # [:axis]表示前axis行,即若 axis 为2,就是取 featVec 的前axis行 reduceFeatVec=featVec[:axis] # [axis+1:]表示从跳过axis的axis+1行,取接下来的数据 # 收集结果值axis列为value的行【该行需要排除axis列】 reduceFeatVec.extend(featVec[axis+1:]) retDataSet.append(reduceFeatVec) return retDataSet #选择最好的数据集划分方式 def chooseBestFeatureToSplit(dataSet): # 求第一行有多少列的 Feature, 最后一列是label列嘛 numFeatures=len(dataSet[0])-1 # 数据集的原始信息熵 baseEntropy=calcShannonEnt(dataSet) # 最优的信息增益值, 和最优的Featurn编号 bestInfoGain=0.0;bestFeature=-1 #迭代所有特征 for i in range(numFeatures): #创建list# 获取对应的feature下的所有数据 featList=[example[i] for example in dataSet] # 获取剔重后的集合,使用set对list数据进行去重 uniqueVals=set(featList) # 创建一个临时的信息熵 newEntropy=0.0 # 遍历某一列的value集合,计算该列的信息熵 # 遍历当前特征中的所有唯一属性值,对每个唯一属性值划分一次数据集,计算数据集的新熵值,并对所有唯一特征值得到的熵求和。 for value in uniqueVals: subDataSet=splitDataSet(dataSet,i,value) # 计算概率 prob=len(subDataSet)/float(len(dataSet)) # 计算信息熵 newEntropy+=prob*calcShannonEnt(subDataSet) # gain[信息增益]: 划分数据集前后的信息变化, 获取信息熵最大的值 # 信息增益是熵的减少或者是数据无序度的减少。最后,比较所有特征中的信息增益,返回最好特征划分的索引值。 infoGain=baseEntropy-newEntropy if(infoGain>bestInfoGain): bestInfoGain=infoGain bestFeature=i return bestFeature def majorityCnt(classList): classCount={} for vote in classCount: if vote not in classCount.keys():classCount[vote]=0 classCount[vote]+=1 sortedClassCount=sorted(classCount.iteritems(),key=operator.itemgetter(1),reverse=True) return sortedClassCount[0][0] #训练算法:构造树的数据结构 def createTree(dataSet,labels): # 如果数据集的最后一列的第一个值出现的次数=整个集合的数量,也就说只有一个类别,就只直接返回结果就行 # 第一个停止条件:所有的类标签完全相同,则直接返回该类标签。 # count() 函数是统计括号中的值在list中出现的次数 classList=[example[-1] for example in dataSet] if classList.count(classList[0])==len(classList): return classList[0] # 如果数据集只有1列,那么最初出现label次数最多的一类,作为结果 # 第二个停止条件:使用完了所有特征,仍然不能将数据集划分成仅包含唯一类别的分组。 if len(dataSet[0])==1: return majorityCnt(classList) # 选择最优的列,得到最优列对应的label含义 bestFeat=chooseBestFeatureToSplit(dataSet) # 获取label的名称 bestFeatLabel=labels[bestFeat] # 初始化myTree myTree={bestFeatLabel:{}} # 注:labels列表是可变对象,在PYTHON函数中作为参数时传址引用,能够被全局修改 # 所以这行代码导致函数外的同名变量被删除了元素,造成例句无法执行,提示'no surfacing' is not in list del(labels[bestFeat]) # 取出最优列,然后它的branch做分类 featValues=[example[bestFeat] for example in dataSet] uniqueVals=set(featValues) for value in uniqueVals: # 求出剩余的标签label subLabels=labels[:] # 遍历当前选择特征包含的所有属性值,在每个数据集划分上递归调用函数createTree() myTree[bestFeatLabel][value]=createTree(splitDataSet(dataSet,bestFeat,value),subLabels) return myTree def createDataSet(): dataSet=[[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no']] labels=['no surfacing','flippers'] return dataSet,labels #测试算法:使用决策树执行分类 def classify(inputTree,featLabels,testVec): # 获取tree的根节点对于的key值 #根据python3的特性进行修改 firstStr1=list(inputTree.keys()) firstStr=firstStr1[0] # 通过key得到根节点对应的value secondDict=inputTree[firstStr] # 判断根节点名称获取根节点在label中的先后顺序,这样就知道输入的testVec怎么开始对照树来做分类 featIndex=featLabels.index(firstStr) # 测试数据,找到根节点对应的label位置,也就知道从输入的数据的第几位来开始分类 for key in secondDict.keys(): if testVec[featIndex]==key: if type(secondDict[key]).__name__=='dict': classLabel=classify(secondDict[key],featLabels,testVec) else: classLabel=secondDict[key] return classLabel myDat,labels=createDataSet() myTree=treePlotter.retrieveTree(0) #print(classify(myTree,labels,[1,0])) #print(classify(myTree,labels,[1,1])) def storeTree(inputTree,filename): import pickle fw=open(filename,'wb') pickle.dump(inputTree,fw) fw.close() def grabTree(filename): import pickle fr=open(filename,'rb') return pickle.load(fr) #print(storeTree(myTree,'classifierStorage.txt')) #print(grabTree('classifierStorage.txt')) fr=open('lenses.txt') lenses=[inst.strip().split('\t')for inst in fr.readlines()] lensesLabels=['age','prescript','astigmatic','tearRate'] lensesTree=createTree(lenses,lensesLabels)