#使用ID3算法
from math import log
import operator
#计算信息增益,熵,确定最优的划分特征 H = -Σp(xi)log2p(xi)
#信息熵代表着混乱程度,熵越高信息越混乱,需要快速降低熵值
def calcShannonEnt(dataSet):
numEntries = len(dataSet)
labelCounts = {}
for featVec in dataSet:
currentLabel = featVec[-1]
if currentLabel not in labelCounts:
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key])/numEntries
shannonEnt -= prob * log(prob, 2)
return shannonEnt
def createDataSet():
dataSet =[[1,1,'yes'],
[1,1,'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0,1,'no']
]
labels = ['no surfacing','flippers']
return dataSet, labels
#按照给定的特征划分数据集
def splitDataSet(dataSet, axis, value):
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value:
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet
#选择最优的特征进行数据集划分
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1
#数据集的熵
baseEntropy = calcShannonEnt(dataSet)
#信息增益,取最大
bestInfoGain = 0.0
#对应的最优特征
bestFeature = -1
for i in range(numFeatures):
#每种特征对应的值,每一列对应的值
featList = [example[i] for example in dataSet]
#set集合去重
uniqueVals = set(featList)
#以第i个特征进行划分,对应不同的值都会得到一个划分的结果
for uniqueVal in uniqueVals:
subDataSet = splitDataSet(dataSet, i, uniqueVal)
prob = len(subDataSet) / float(len(dataSet))
newEntropy = calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy
if(infoGain > bestInfoGain):
bestInfoGain = infoGain
bestFeature = i
return bestFeature, bestInfoGain
#递归构建决策树 终止条件:属性被用完,划分后所有数据属于同一类别
#投票,最后节点的类别为最多数据的类别
def majorityCnt(classList):
classCount = {}
for vote in classList:
if vote not in classCount.keys() :classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
#创建树
def createTree(dataSet, labels):
classList =[example[-1] for example in dataSet]
#只有一类
if classList.count(classList[0]) == len(classList):
return classList[0]
if len(dataSet[0]) == 1:
#没有用于划分的特征了
return majorityCnt(classList)
bestFeature, bestInfoGain = chooseBestFeatureToSplit(dataSet)
bestLabel = labels[bestFeature]
#字典进行嵌套
myTree = {bestLabel:{}}
del(labels[bestFeature])
featValues = [example[bestFeature] for example in dataSet]
uniqueVals = set(featValues)
for uniqueVal in uniqueVals:
subLabel = labels[:]
#递归
myTree[bestLabel][uniqueVal] = createTree(splitDataSet(dataSet, bestFeature, uniqueVal), subLabel)
return myTree
if __name__ == "__main__":
dataSet, labels = createDataSet()
# tesda = [1,2,3,4,5,6,7,8,9]
# redu = tesda[:2]
# print(redu)
# asx =tesda[3:]
# print(asx)
#
# resdata1 = splitDataSet(dataSet, 0, 1)
# print(resdata1)
# resdata2 = splitDataSet(dataSet, 0, 0)
# print(resdata2)
# #ent = calcShannonEnt(dataSet)
# #print(ent)
# print("=============")
#
# bestFeature, bestInfoGain = chooseBestFeatureToSplit(dataSet)
# print(bestFeature, bestInfoGain)
mytree = createTree(dataSet, labels)
print(mytree)
06-11
1260
1260
04-10
1万+
1万+
10-26
5994
5994
04-25
1万+
1万+
09-15
629
629

被折叠的 条评论
为什么被折叠?



