from math import log
import operator
def calcShannonEnt(dataSet):
    """Return the base-2 Shannon entropy of the class labels in *dataSet*.

    Each row's last element is taken as its class label.  An empty data
    set yields ``0.0`` because there are no labels to accumulate over.
    """
    total = len(dataSet)

    # Tally how many rows carry each class label (first-seen order).
    tally = {}
    for row in dataSet:
        cls = row[-1]
        tally[cls] = tally.get(cls, 0) + 1

    # Accumulate -p * log2(p) label by label.
    entropy = 0.0
    for count in tally.values():
        p = float(count) / total
        entropy -= p * log(p, 2)
    return entropy
def CreateDataSet():
    """Build the small sample data set used by the demo.

    Prints the rows and feature names, then returns them as a
    ``(rows, feature_names)`` pair.  Each row holds two feature values
    followed by a ``'yes'``/``'no'`` class label.  Fresh lists are created
    on every call.
    """
    feature_names = ['no surfacing', 'flippers']
    samples = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [1, 1, 'no'],
        [1, 0, 'no'],
    ]
    print(samples, feature_names)
    return samples, feature_names
def splitDataSet(dataSet, axis, value):
    """Select the rows whose column *axis* equals *value*, minus that column.

    Args:
        dataSet: iterable of rows (indexable sequences).
        axis: index of the column to test and remove.  Negative indices
            count from the end of each row, as in normal indexing.
        value: value that ``row[axis]`` must equal for the row to be kept.

    Returns:
        A new list of new lists; the input rows are not modified.

    Example:
        >>> data = [[2, 4, 3], [1, 2, 2], [4, 5, 3]]
        >>> splitDataSet(data, 2, 3)
        [[2, 4], [4, 5]]
    """
    reData = []
    for featVec in dataSet:
        if featVec[axis] == value:
            # Copy the row and delete by index so exactly one element is
            # removed.  Slicing with ``axis + 1`` breaks for axis == -1,
            # where it would re-append the whole row.
            reduceFeatVec = list(featVec)
            del reduceFeatVec[axis]
            reData.append(reduceFeatVec)
    return reData
def chaooseBestFeatureTosplit(dataSet):
    """Return the index of the feature column with the largest information gain.

    Every column except the last (the class label) is a candidate.  For each
    candidate the rows are partitioned by that column's distinct values and
    the size-weighted entropy of the partitions is subtracted from the
    entropy of the whole set.  Returns ``-1`` when no candidate has a gain
    strictly greater than zero.  *dataSet* must be non-empty.
    """
    feature_count = len(dataSet[0]) - 1
    total_rows = len(dataSet)
    base = calcShannonEnt(dataSet)

    winner, top_gain = -1, 0.0
    for col in range(feature_count):
        distinct = set([row[col] for row in dataSet])

        # Entropy remaining after splitting on this column.
        weighted = 0.0
        for val in distinct:
            part = splitDataSet(dataSet, col, val)
            weighted += (len(part) / total_rows) * calcShannonEnt(part)

        gain = base - weighted
        if gain > top_gain:
            winner, top_gain = col, gain
    return winner
def majoritycnt(classList):
    """Count class labels and rank them by frequency.

    Returns a list of ``(label, count)`` tuples ordered from most to least
    frequent; labels with equal counts keep their first-seen order.  The
    most common label is therefore ``result[0][0]``.  An empty *classList*
    yields an empty list.
    """
    tally = {}
    for label in classList:
        tally[label] = tally.get(label, 0) + 1

    ranked = list(tally.items())
    # Stable sort, so ties stay in first-seen order even when reversed.
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    return ranked
def CreateTree(dataSet,labes):
    """Recursively build an information-gain decision tree as nested dicts.

    Args:
        dataSet: non-empty list of rows; each row is a list of feature
            values followed by the class label as its last element.
        labes: list of feature names, one per feature column, in column
            order.  It is not modified.

    Returns:
        A class label for a leaf, otherwise a dict of the form
        ``{feature_name: {feature_value: subtree, ...}}``.
    """
    classList = [row[-1] for row in dataSet]

    # Pure node: every row has the same class, so stop here.
    if classList.count(classList[0]) == len(classList):
        return classList[0]

    # Only the class column is left: fall back to the most frequent class.
    # majoritycnt returns ranked (label, count) pairs; the leaf is the label.
    if len(dataSet[0]) == 1:
        return majoritycnt(classList)[0][0]

    bestFeature = chaooseBestFeatureTosplit(dataSet)
    # -1 means no remaining feature gives a positive information gain.
    # Using it as an index would split on the class column itself, so
    # emit a majority-vote leaf instead.
    if bestFeature < 0:
        return majoritycnt(classList)[0][0]

    bestFeatureLabe = labes[bestFeature]
    # New list without the chosen feature; the caller's list stays intact.
    subLabel = labes[:bestFeature] + labes[bestFeature + 1:]

    branches = {}
    for value in set(row[bestFeature] for row in dataSet):
        branches[value] = CreateTree(
            splitDataSet(dataSet, bestFeature, value), subLabel
        )
    return {bestFeatureLabe: branches}
# Demo: build the sample data set and print the tree derived from it.
dataSet, labels = CreateDataSet()
print(CreateTree(dataSet, labels))
# 决策树 (Decision tree)
# 最新推荐文章于 2021-10-27 22:27:20 发布 (web-page footer: "latest recommended article published 2021-10-27 22:27:20")