原理
通过提问的方式,根据不同的答案选择不同的分支, 完成不同的分类
步骤分解
1.遍历数据集, 循环计算提取每个特征的香农熵和信息增益, 选取信息增益最大的特征。 再递归计算剩余的特征顺序。 将特征排序。 并将分类结果序列化保存到磁盘当中
def chooseBestFeatureToSplit(dataSet): # 选择最好的分类特征
"""
:param dataSet: 原数据集
:return: 最好的划分特征的索引值
"""
numFeatures = len(dataSet[0]) - 1 # 获取特征数
baseEntropy = calcShannonEnt(dataSet) # 计算数据集的信息熵
bestInfoGain = 0.0 # 初始化最好的信息熵
bestFeature = -1 # 初始化最好的用于分割的特征
for i in range(numFeatures):
# 创建唯一的分类标签列表
featList= [example[i] for example in dataSet] # 获取每个元素的第i个特征
uniqueVals = set(featList) # 数据特征去重 (此特征有几种情况)
newEntropy = 0.0
# 计算每种划分方式的信息熵
for value in uniqueVals:
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet) / float(len(dataSet)) # probability,概率,可理解为权重
newEntropy += prob * calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy # 新的熵越小即新划分的数据集混乱程度越小,与原熵的差值就越大, 即信息增益就越大
# 计算最好的信息增益
if(infoGain > bestInfoGain): # 若新的信息增益大于之前的信息增益,则替换
bestInfoGain = infoGain
bestFeature = i # 表示最好的划分特征的索引值
return bestFeature
2.递归构建决策树
def createTree(dataSet, labels):
"""
:param dataSet: 数据集
:param labels: 标签列表, 包含了数据集中的所有特征的标签
:return:
"""
classList = [example[-1] for example in dataSet]
# 类别完全相同则停止继续划分
if classList.count(classList[0]) == len(classList):
return classList[0]
# 遍历完所有特征,仍然不能将数据集划分成包含唯一类别的分组时,返回出现次数最多的类别
if len(dataSet[0]) == 1:
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet) # 选取的最好特征
bestFeatLabel = labels[bestFeat] # 最好特征标签
myTree = {bestFeatLabel:{ }} # 使用字典存储树信息
# 得到列表包含的所有属性值
del(labels[bestFeat])
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
subLabels = labels[:] # 因为下一步传参数时是引用传参
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
return myTree
3.使用Matplotlib注解绘制树形图
import matplotlib.pyplot as plt
import trees
# 定义文本框和箭头格式
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8") # 设置箭头的样式和背景色
arrow_args = dict(arrowstyle="<-") # 设置箭头的样式
def plotNode(nodeTxt, centerPt, parentPt, nodeType): # 绘制带箭头的注解
createPlot.axl.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
def createPlot(inTree):
fig = plt.figure(1, facecolor='white') #