import numpy as np
import math
import operator
def calcShannonEnt(dataSet):
    """Return the base-2 Shannon entropy of the class labels in ``dataSet``.

    The class label of every record is taken to be its last element.
    An empty ``dataSet`` yields ``0``.
    """
    total = len(dataSet)
    # Tally how many records carry each class label.
    tally = {}
    for record in dataSet:
        label = record[-1]
        tally[label] = tally.get(label, 0) + 1
    # Accumulate -p * log2(p) over the observed labels.
    entropy = 0
    for count in tally.values():
        p = float(count) / total
        entropy -= p * math.log(p, 2)
    return entropy
def createDataSet():
    """Return a hard-coded sample data set and its feature names.

    Returns:
        A ``(dataSet, labels)`` pair.  ``dataSet`` is a list of 14 records,
        each a list of four integer feature values followed by a ``'yes'`` /
        ``'no'`` class label; ``labels`` names the four feature columns.
        Fresh list objects are built on every call.
    """
    samples = (
        (1, 1, 1, 1, 'yes'),
        (1, 1, 1, 1, 'yes'),
        (1, 1, 1, 1, 'no'),
        (1, 1, 1, 1, 'no'),
        (1, 1, 2, 2, 'no'),
        (2, 1, 1, 2, 'yes'),
        (2, 2, 1, 2, 'yes'),
        (2, 2, 1, 2, 'yes'),
        (2, 2, 1, 2, 'yes'),
        (3, 2, 2, 3, 'yes'),
        (3, 2, 2, 3, 'yes'),
        (3, 2, 2, 3, 'yes'),
        (3, 1, 2, 3, 'no'),
        (3, 2, 2, 2, 'no'),
    )
    featureNames = ['no surfacing', 'flippers', 'a', 'b']
    return [list(sample) for sample in samples], featureNames
def splitDataSet(dataSet, axis, value):
    """Return the records whose column ``axis`` equals ``value``.

    Each returned record is a new list with the ``axis`` column removed;
    the records in ``dataSet`` are left untouched.
    """
    return [
        row[:axis] + row[axis + 1:]
        for row in dataSet
        if row[axis] == value
    ]
def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature with the highest information gain.

    Every column except the last (the class label) is treated as a feature.
    The base entropy, and each feature's conditional entropy and gain, are
    printed as they are computed.

    Returns:
        The column index of the first feature with the largest strictly
        positive gain, or ``-1`` when no feature has a positive gain.
    """
    featureCount = len(dataSet[0]) - 1
    rootEntropy = calcShannonEnt(dataSet)
    print('原图的香农熵为%f' % rootEntropy)
    sampleCount = float(len(dataSet))
    winner, winnerGain = -1, 0.0
    for col in range(featureCount):
        distinct = set([row[col] for row in dataSet])
        # Entropy after the split: subset entropies weighted by subset size.
        splitEntropy = 0
        for v in distinct:
            part = splitDataSet(dataSet, col, v)
            splitEntropy += len(part) / sampleCount * calcShannonEnt(part)
        gain = rootEntropy - splitEntropy
        print('第', col + 1, '个香农熵为', splitEntropy, '信息增益为', gain)
        # Strict comparison keeps the earliest feature on ties.
        if gain > winnerGain:
            winner, winnerGain = col, gain
    return winner
def majorityCut(classList):
    """Return the most frequent class label in ``classList``.

    Ties are resolved in favour of the label that occurs first.

    Raises:
        ValueError: if ``classList`` is empty.
    """
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    # max() returns the first maximal key, and dicts keep insertion order,
    # so the earliest-seen label wins a tie.
    return max(classCount, key=classCount.get)


def createTree(dataSet, labels):
    """Recursively build a decision tree from ``dataSet``.

    Args:
        dataSet: Non-empty list of records; the last element of each record
            is its class label, the preceding elements are feature values.
        labels: Feature names, one per feature column.  Not modified.

    Returns:
        A class label when the node is a leaf, otherwise a nested dict of
        the form ``{feature_name: {feature_value: subtree, ...}}``.

    Raises:
        IndexError: if ``dataSet`` is empty.
    """
    classList = [example[-1] for example in dataSet]
    # Leaf: every record carries the same class label.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Leaf: only the class column is left, so fall back to a majority vote.
    if len(dataSet[0]) == 1:
        return majorityCut(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    # Leaf: -1 means no feature gives a positive gain (e.g. records with
    # identical features but different classes); splitting would be bogus.
    if bestFeat == -1:
        return majorityCut(classList)
    bestFeatLabel = labels[bestFeat]
    # Build the reduced label list without touching the caller's list.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    myTree = {bestFeatLabel: {}}
    uniqueValues = set(example[bestFeat] for example in dataSet)
    for value in uniqueValues:
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
# Demo run: load the sample records and feature names, then build a tree.
dataSet,labels=createDataSet()
myTree=createTree(dataSet,labels)
# Print whatever createTree returned.
print(myTree)