代码:
import numpy as np
import operator
def calcShannonEnt(dataSet):
    """Compute the Shannon entropy (in bits) of the class labels in ``dataSet``.

    The entropy measures how disordered the class labels are: it is 0.0 when
    every row carries the same label.

    Args:
        dataSet: sequence of rows; the last element of each row is treated as
            the class label.

    Returns:
        ``-sum(p * log2(p))`` over the label frequencies ``p``; 0.0 for an
        empty dataset.
    """
    numEntries = len(dataSet)
    labelCounts = {}
    for featureVect in dataSet:
        currentLabel = featureVect[-1]
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    shannonEnt = 0.0
    for count in labelCounts.values():
        # Cast before dividing: float(a / b) would already have truncated to 0
        # under Python 2 integer division (same form as chooseBestFeatureToSplit).
        prob = count / float(numEntries)
        shannonEnt -= prob * np.log2(prob)
    return shannonEnt
def splitDataSet(dataSet, axis, value):
    """Select the rows whose column ``axis`` equals ``value``, dropping that column.

    Args:
        dataSet: list of rows (each row a list).
        axis: index of the feature column to test and remove.
        value: feature value a row must have in column ``axis`` to be kept.

    Returns:
        A new list of new row lists; ``dataSet`` itself is left unchanged.
    """
    def _without_axis(row):
        # row[:axis] is a fresh copy, so extending it never touches the input row.
        trimmed = row[:axis]
        trimmed.extend(row[axis + 1:])
        return trimmed

    return [_without_axis(row) for row in dataSet if row[axis] == value]
def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature column with the largest information gain.

    Information gain = entropy of the whole set minus the weighted entropy of
    the subsets produced by splitting on that feature. The last column of each
    row is the class label and is not considered a feature.

    Returns:
        The column index of the best feature, or -1 if no feature yields a
        strictly positive gain. Ties keep the lowest index.
    """
    featureCount = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    totalRows = float(len(dataSet))
    bestIndex, bestGain = -1, 0.0
    for column in range(featureCount):
        conditionalEntropy = 0.0
        for candidate in set([row[column] for row in dataSet]):
            subset = splitDataSet(dataSet, column, candidate)
            weight = len(subset) / totalRows
            # calcShannonEnt already applies the minus sign, so just accumulate.
            conditionalEntropy += weight * calcShannonEnt(subset)
        gain = baseEntropy - conditionalEntropy
        if gain > bestGain:
            bestIndex, bestGain = column, gain
    return bestIndex
def maiorityCnt(classList):
    """Return the most frequent class label in ``classList`` (majority vote).

    createTree uses this when all features are used up but a leaf still
    contains more than one class. (The name is a misspelling of "majorityCnt";
    it is kept so existing callers keep working.)

    Ties go to the label counted first (insertion-ordered dicts, Python 3.7+),
    because sorted() is stable even with reverse=True.

    Raises:
        ValueError: if ``classList`` is empty.
    """
    if not classList:
        raise ValueError("maiorityCnt() requires a non-empty classList")
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    # key= must be passed by keyword: positionally, sorted() raises TypeError
    # on Python 3 (and on Python 2 it would be misused as a cmp function).
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
def createTree(dataSet, labels):
    """Recursively build an ID3 decision tree.

    Args:
        dataSet: list of rows; each row holds the feature values followed by
            the class label in the last position.
        labels: feature names, one per feature column of ``dataSet``. The
            caller's list is not modified.

    Returns:
        Either a class label (a leaf) or a nested dict of the form
        ``{feature_name: {feature_value: subtree, ...}}``.
    """
    classList = [example[-1] for example in dataSet]
    # All rows share one class: recursion ends with that class as the leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Only the class column remains (every feature used up): return the
    # majority class of this leaf.
    if len(dataSet[0]) == 1:
        return maiorityCnt(classList)
    # Column index of the best feature to split on.
    bestFeat = chooseBestFeatureToSplit(dataSet)
    # -1 means no feature gives a positive information gain (e.g. identical
    # feature vectors with different classes). Using it as an index would
    # split on the class column via [-1], so stop with a majority leaf.
    if bestFeat < 0:
        return maiorityCnt(classList)
    bestFeatLable = labels[bestFeat]
    myTree = {bestFeatLable: {}}
    # Build the reduced name list instead of del-ing from the caller's list,
    # so the labels passed in stay intact after the call.
    subLables = labels[:bestFeat] + labels[bestFeat + 1:]
    for value in set(example[bestFeat] for example in dataSet):
        myTree[bestFeatLable][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLables)
    return myTree
def createDataSet():
    """Return a five-row toy dataset and the names of its two feature columns.

    Each row is ``[feature0, feature1, class]`` with binary features and a
    'yes'/'no' class label. The feature names mean "can survive underwater"
    and "has flippers".
    """
    samples = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    featureNames = ['水下能生存', '有脚蹼']
    return samples, featureNames
if __name__ == '__main__':
    # Build the toy dataset and print the decision tree learned from it.
    dataSet, labels = createDataSet()
    print(createTree(dataSet, labels))
运行结果: