前言
接着上一篇继续学习。
首先感谢博主:Jack-Cui
主页:http://blog.youkuaiyun.com/c406495762
决策树博文地址:https://blog.youkuaiyun.com/c406495762/article/details/76262487
这篇博文对书上的内容很形象的进行了表达,通俗易懂,用自己的实例来进行讲解,比书上讲的清楚太多,于是我才开始了学习,感激不尽,真心推荐。
我这篇博文大多从它的博文中摘抄,但也是我一个字一个敲出来的,算法我也是自己算过的,算是学完它的博文的一个总结吧,如果还看不明白的可以直接看他的吧。
明天我会继续按照它的博客学习。
进入正题。
决策树的构建
能构建决策树的算法有很多,本篇只介绍使用ID3算法构建决策树。
1 ID3算法
ID3算法的核心是在决策树各个结点上对应信息增益准则选择特征,递归地构建决策树。具体方法是:从根结点开始,对结点计算所有可能的特征的信息增益,选择信息增益最大的特征作为结点的特征,由该特征的不同取值建立子节点;再对子结点递归地调用以上方法,构建决策树;直到所有特征的信息增益均很小或没有特征可以选择为止,最后得到一个决策树。ID3相当于用极大似然法进行概率模型的选择。
根据上篇文章求得的信息增益结果,由于特征A3(有自己的房子)的信息增益值最大,所以选择特征A3作为根结点的特征。它将训练集D划分为两个子集D1(A3取值为”是”)和D2(A3取值为”否”)。由于D1只有同一类的样本点,所以它成为一个叶结点,结点的类标记为“是”。
对D2则需要从特征A1(年龄),A2(有工作)和A4(信贷情况)中选择新的特征,计算各个特征的信息增益:
-
g(D2,A1) = H(D2) - H(D2 | A1) = 0.251
-
g(D2,A2) = H(D2) - H(D2 | A2) = 0.918
-
g(D2,A3) = H(D2) - H(D2 | A3) = 0.474
根据计算,选择信息增益最大的特征A2(有工作)作为结点的特征。由于A2有两个可能取值,从这一结点引出两个子结点:一个对应”是”(有工作)的子结点,包含3个样本,它们属于同一类,所以这是一个叶结点,类标记为”是”;另一个是对应”否”(无工作)的子结点,包含6个样本,它们也属于同一类,所以这也是一个叶结点,类标记为”否”。
这样就生成了一个决策树,该决策树只用了两个特征(有两个内部结点),生成的决策树如下图所示。
代码实现:
# -*- coding: UTF-8 -*-
from math import log
import operator
"""
函数说明:计算给定数据集的经验熵(香农熵)
Parameters:
dataSet - 数据集
Returns:
shannonEnt - 经验熵(香农熵)
Author:
Jack Cui
Blog:
http://blog.youkuaiyun.com/c406495762
Modify:
2017-07-24
"""
def calcShannonEnt(dataSet):
numEntires = len(dataSet) #返回数据集的行数
labelCounts = {} #保存每个标签(Label)出现次数的字典
for featVec in dataSet: #对每组特征向量进行统计
currentLabel = featVec[-1] #提取标签(Label)信息
if currentLabel not in labelCounts.keys(): #如果标签(Label)没有放入统计次数的字典,添加进去
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1 #Label计数
shannonEnt = 0.0 #经验熵(香农熵)
for key in labelCounts: #计算香农熵
prob = float(labelCounts[key]) / numEntires #选择该标签(Label)的概率
shannonEnt -= prob * log(prob, 2) #利用公式计算
return shannonEnt #返回经验熵(香农熵)
"""
函数说明:创建测试数据集
Parameters:
无
Returns:
dataSet - 数据集
labels - 特征标签
Author:
Jack Cui
Blog:
http://blog.youkuaiyun.com/c406495762
Modify:
2017-07-20
"""
def createDataSet():
dataSet = [[0, 0, 0, 0, 'no'], #数据集
[0, 0, 0, 1, 'no'],
[0, 1, 0, 1, 'yes'],
[0, 1, 1, 0, 'yes'],
[0, 0, 0, 0, 'no'],
[1, 0, 0, 0, 'no'],
[1, 0, 0, 1, 'no'],
[1, 1, 1, 1, 'yes'],
[1, 0, 1, 2, 'yes'],
[1, 0, 1, 2, 'yes'],
[2, 0, 1, 2, 'yes'],
[2, 0, 1, 1, 'yes'],
[2, 1, 0, 1, 'yes'],
[2, 1, 0, 2, 'yes'],
[2, 0, 0, 0, 'no']]
labels = ['年龄', '有工作', '有自己的房子', '信贷情况'] #特征标签
return dataSet, labels #返回数据集和分类属性
"""
函数说明:按照给定特征划分数据集
Parameters:
dataSet - 待划分的数据集
axis - 划分数据集的特征
value - 需要返回的特征的值
Returns:
无
Author:
Jack Cui
Blog:
http://blog.youkuaiyun.com/c406495762
Modify:
2017-07-24
"""
def splitDataSet(dataSet, axis, value):
retDataSet = [] #创建返回的数据集列表
for featVec in dataSet: #遍历数据集
if featVec[axis] == value:
reducedFeatVec = featVec[:axis] #去掉axis特征
reducedFeatVec.extend(featVec[axis+1:]) #将符合条件的添加到返回的数据集
retDataSet.append(reducedFeatVec)
return retDataSet #返回划分后的数据集
"""
函数说明:选择最优特征
Parameters:
dataSet - 数据集
Returns:
bestFeature - 信息增益最大的(最优)特征的索引值
Author:
Jack Cui
Blog:
http://blog.youkuaiyun.com/c406495762
Modify:
2017-07-20
"""
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1 #特征数量
baseEntropy = calcShannonEnt(dataSet) #计算数据集的香农熵
bestInfoGain = 0.0 #信息增益
bestFeature = -1 #最优特征的索引值
for i in range(numFeatures): #遍历所有特征
#获取dataSet的第i个所有特征
featList = [example[i] for example in dataSet]
uniqueVals = set(featList) #创建set集合{},元素不可重复
newEntropy = 0.0 #经验条件熵
for value in uniqueVals: #计算信息增益
subDataSet = splitDataSet(dataSet, i, value) #subDataSet划分后的子集
prob = len(subDataSet) / float(len(dataSet)) #计算子集的概率
newEntropy += prob * calcShannonEnt(subDataSet) #根据公式计算经验条件熵
infoGain = baseEntropy - newEntropy #信息增益
# print("第%d个特征的增益为%.3f" % (i, infoGain)) #打印每个特征的信息增益
if (infoGain > bestInfoGain): #计算信息增益
bestInfoGain = infoGain #更新信息增益,找到最大的信息增益
bestFeature = i #记录信息增益最大的特征的索引值
return bestFeature #返回信息增益最大的特征的索引值
"""
函数说明:统计classList中出现此处最多的元素(类标签)
Parameters:
classList - 类标签列表
Returns:
sortedClassCount[0][0] - 出现此处最多的元素(类标签)
Author:
Jack Cui
Blog:
http://blog.youkuaiyun.com/c406495762
Modify:
2017-07-24
"""
def majorityCnt(classList):
classCount = {}
for vote in classList: #统计classList中每个元素出现的次数
if vote not in classCount.keys():classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1), reverse = True) #根据字典的值降序排序
return sortedClassCount[0][0] #返回classList中出现次数最多的元素
"""
函数说明:创建决策树
Parameters:
dataSet - 训练数据集
labels - 分类属性标签
featLabels - 存储选择的最优特征标签
Returns:
myTree - 决策树
Author:
Jack Cui
Blog:
http://blog.youkuaiyun.com/c406495762
Modify:
2017-07-25
"""
def createTree(dataSet, labels, featLabels):
classList = [example[-1] for example in dataSet] #取分类标签(是否放贷:yes or no)
if classList.count(classList[0]) == len(classList): #如果类别完全相同则停止继续划分
return classList[0]
if len(dataSet[0]) == 1: #遍历完所有特征时返回出现次数最多的类标签
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet) #选择最优特征
bestFeatLabel = labels[bestFeat] #最优特征的标签
featLabels.append(bestFeatLabel)
myTree = {bestFeatLabel:{}} #根据最优特征的标签生成树
del(labels[bestFeat]) #删除已经使用特征标签
featValues = [example[bestFeat] for example in dataSet] #得到训练集中所有最优特征的属性值
uniqueVals = set(featValues) #去掉重复的属性值
for value in uniqueVals: #遍历特征,创建决策树。
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), labels, featLabels)
return myTree
if __name__ == '__main__':
dataSet, labels = createDataSet()
featLabels = []
myTree = createTree(dataSet, labels, featLabels)
print(myTree)
此时决策树就已经构建好了,但是光把字典内容打印出来就比较抽象了,我们可以通过Matplotlib对其进行可视化操作。
2 可视化决策树
代码都是有关Matplotlib的,需要用到的函数有:
- getNumLeafs:获取决策树叶子结点的数目
- getTreeDepth:获取决策树的层数
- plotNode:绘制结点
- plotMidText:标注有向边属性值
- plotTree:绘制决策树
- createPlot:创建绘制面板
直接上代码:
# -*- coding: UTF-8 -*-
from matplotlib.font_manager import FontProperties
import matplotlib.pyplot as plt
from math import log
import operator
"""
函数说明:计算给定数据集的经验熵(香农熵)
Parameters:
dataSet - 数据集
Returns:
shannonEnt - 经验熵(香农熵)
Author:
Jack Cui
Blog:
http://blog.youkuaiyun.com/c406495762
Modify:
2017-07-24
"""
def calcShannonEnt(dataSet):
numEntires = len(dataSet) #返回数据集的行数
labelCounts = {} #保存每个标签(Label)出现次数的字典
for featVec in dataSet: #对每组特征向量进行统计
currentLabel = featVec[-1] #提取标签(Label)信息
if currentLabel not in labelCounts.keys(): #如果标签(Label)没有放入统计次数的字典,添加进去
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1 #Label计数
shannonEnt = 0.0 #经验熵(香农熵)
for key in labelCounts: #计算香农熵
prob = float(labelCounts[key]) / numEntires #选择该标签(Label)的概率
shannonEnt -= prob * log(prob, 2) #利用公式计算
return shannonEnt #返回经验熵(香农熵)
"""
函数说明:创建测试数据集
Parameters:
无
Returns:
dataSet - 数据集
labels - 特征标签
Author:
Jack Cui
Blog:
http://blog.youkuaiyun.com/c406495762
Modify:
2017-07-20
"""
def createDataSet():
dataSet = [[0, 0, 0, 0, 'no'], #数据集
[0, 0, 0, 1, 'no'],
[0, 1, 0, 1, 'yes'],
[0, 1, 1, 0, 'yes'],
[0, 0, 0, 0, 'no'],
[1, 0, 0, 0, 'no'],
[1, 0, 0, 1, 'no'],
[1, 1, 1, 1, 'yes'],
[1, 0, 1, 2, 'yes'],
[1, 0, 1, 2, 'yes'],
[2, 0, 1, 2, 'yes'],
[2, 0, 1, 1, 'yes'],
[2, 1, 0, 1, 'yes'],
[2, 1, 0, 2, 'yes'],
[2, 0, 0, 0, 'no']]
labels = ['年龄', '有工作', '有自己的房子', '信贷情况'] #特征标签
return dataSet, labels #返回数据集和分类属性
"""
函数说明:按照给定特征划分数据集
Parameters:
dataSet - 待划分的数据集
axis - 划分数据集的特征
value - 需要返回的特征的值
Returns:
无
Author:
Jack Cui
Blog:
http://blog.youkuaiyun.com/c406495762
Modify:
2017-07-24
"""
def splitDataSet(dataSet, axis, value):
retDataSet = [] #创建返回的数据集列表
for featVec in dataSet: #遍历数据集
if featVec[axis] == value:
reducedFeatVec = featVec[:axis] #去掉axis特征
reducedFeatVec.extend(featVec[axis+1:]) #将符合条件的添加到返回的数据集
retDataSet.append(reducedFeatVec)
return retDataSet #返回划分后的数据集
"""
函数说明:选择最优特征
Parameters:
dataSet - 数据集
Returns:
bestFeature - 信息增益最大的(最优)特征的索引值
Author:
Jack Cui
Blog:
http://blog.youkuaiyun.com/c406495762
Modify:
2017-07-20
"""
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1 #特征数量
baseEntropy = calcShannonEnt(dataSet) #计算数据集的香农熵
bestInfoGain = 0.0 #信息增益
bestFeature = -1 #最优特征的索引值
for i in range(numFeatures): #遍历所有特征
#获取dataSet的第i个所有特征
featList = [example[i] for example in dataSet]
uniqueVals = set(featList) #创建set集合{},元素不可重复
newEntropy = 0.0 #经验条件熵
for value in uniqueVals: #计算信息增益
subDataSet = splitDataSet(dataSet, i, value) #subDataSet划分后的子集
prob = len(subDataSet) / float(len(dataSet)) #计算子集的概率
newEntropy += prob * calcShannonEnt(subDataSet) #根据公式计算经验条件熵
infoGain = baseEntropy - newEntropy #信息增益
# print("第%d个特征的增益为%.3f" % (i, infoGain)) #打印每个特征的信息增益
if (infoGain > bestInfoGain): #计算信息增益
bestInfoGain = infoGain #更新信息增益,找到最大的信息增益
bestFeature = i #记录信息增益最大的特征的索引值
return bestFeature #返回信息增益最大的特征的索引值
"""
函数说明:统计classList中出现此处最多的元素(类标签)
Parameters:
classList - 类标签列表
Returns:
sortedClassCount[0][0] - 出现此处最多的元素(类标签)
Author:
Jack Cui
Blog:
http://blog.youkuaiyun.com/c406495762
Modify:
2017-07-24
"""
def majorityCnt(classList):
classCount = {}
for vote in classList: #统计classList中每个元素出现的次数
if vote not in classCount.keys():classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1), reverse = True) #根据字典的值降序排序
return sortedClassCount[0][0] #返回classList中出现次数最多的元素
"""
函数说明:创建决策树
Parameters:
dataSet - 训练数据集
labels - 分类属性标签
featLabels - 存储选择的最优特征标签
Returns:
myTree - 决策树
Author:
Jack Cui
Blog:
http://blog.youkuaiyun.com/c406495762
Modify:
2017-07-25
"""
def createTree(dataSet, labels, featLabels):
classList = [example[-1] for example in dataSet] #取分类标签(是否放贷:yes or no)
if classList.count(classList[0]) == len(classList): #如果类别完全相同则停止继续划分
return classList[0]
if len(dataSet[0]) == 1: #遍历完所有特征时返回出现次数最多的类标签
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet) #选择最优特征
bestFeatLabel = labels[bestFeat] #最优特征的标签
featLabels.append(bestFeatLabel)
myTree = {bestFeatLabel:{}} #根据最优特征的标签生成树
del(labels[bestFeat]) #删除已经使用特征标签
featValues = [example[bestFeat] for example in dataSet] #得到训练集中所有最优特征的属性值
uniqueVals = set(featValues) #去掉重复的属性值
for value in uniqueVals: #遍历特征,创建决策树。
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), labels, featLabels)
return myTree
"""
函数说明:获取决策树叶子结点的数目
Parameters:
myTree - 决策树
Returns:
numLeafs - 决策树的叶子结点的数目
Author:
Jack Cui
Blog:
http://blog.youkuaiyun.com/c406495762
Modify:
2017-07-24
"""
def getNumLeafs(myTree):
numLeafs = 0 #初始化叶子
firstStr = next(iter(myTree)) #python3中myTree.keys()返回的是dict_keys,不在是list,所以不能使用myTree.keys()[0]的方法获取结点属性,可以使用list(myTree.keys())[0]
secondDict = myTree[firstStr] #获取下一组字典
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict': #测试该结点是否为字典,如果不是字典,代表此结点为叶子结点
numLeafs += getNumLeafs(secondDict[key])
else: numLeafs +=1
return numLeafs
"""
函数说明:获取决策树的层数
Parameters:
myTree - 决策树
Returns:
maxDepth - 决策树的层数
Author:
Jack Cui
Blog:
http://blog.youkuaiyun.com/c406495762
Modify:
2017-07-24
"""
def getTreeDepth(myTree):
maxDepth = 0 #初始化决策树深度
firstStr = next(iter(myTree)) #python3中myTree.keys()返回的是dict_keys,不在是list,所以不能使用myTree.keys()[0]的方法获取结点属性,可以使用list(myTree.keys())[0]
secondDict = myTree[firstStr] #获取下一个字典
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict': #测试该结点是否为字典,如果不是字典,代表此结点为叶子结点
thisDepth = 1 + getTreeDepth(secondDict[key])
else: thisDepth = 1
if thisDepth > maxDepth: maxDepth = thisDepth #更新层数
return maxDepth
"""
函数说明:绘制结点
Parameters:
nodeTxt - 结点名
centerPt - 文本位置
parentPt - 标注的箭头位置
nodeType - 结点格式
Returns:
无
Author:
Jack Cui
Blog:
http://blog.youkuaiyun.com/c406495762
Modify:
2017-07-24
"""
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
arrow_args = dict(arrowstyle="<-") #定义箭头格式
font = FontProperties(fname=r"c:\windows\fonts\simsun.ttc", size=14) #设置中文字体
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', #绘制结点
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args, FontProperties=font)
"""
函数说明:标注有向边属性值
Parameters:
cntrPt、parentPt - 用于计算标注位置
txtString - 标注的内容
Returns:
无
Author:
Jack Cui
Blog:
http://blog.youkuaiyun.com/c406495762
Modify:
2017-07-24
"""
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0] #计算标注位置
yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
"""
函数说明:绘制决策树
Parameters:
myTree - 决策树(字典)
parentPt - 标注的内容
nodeTxt - 结点名
Returns:
无
Author:
Jack Cui
Blog:
http://blog.youkuaiyun.com/c406495762
Modify:
2017-07-24
"""
def plotTree(myTree, parentPt, nodeTxt):
decisionNode = dict(boxstyle="sawtooth", fc="0.8") #设置结点格式
leafNode = dict(boxstyle="round4", fc="0.8") #设置叶结点格式
numLeafs = getNumLeafs(myTree) #获取决策树叶结点数目,决定了树的宽度
depth = getTreeDepth(myTree) #获取决策树层数
firstStr = next(iter(myTree)) #下个字典
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff) #中心位置
plotMidText(cntrPt, parentPt, nodeTxt) #标注有向边属性值
plotNode(firstStr, cntrPt, parentPt, decisionNode) #绘制结点
secondDict = myTree[firstStr] #下一个字典,也就是继续绘制子结点
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD #y偏移
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict': #测试该结点是否为字典,如果不是字典,代表此结点为叶子结点
plotTree(secondDict[key],cntrPt,str(key)) #不是叶结点,递归调用继续绘制
else: #如果是叶结点,绘制叶结点,并标注有向边属性值
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
"""
函数说明:创建绘制面板
Parameters:
inTree - 决策树(字典)
Returns:
无
Author:
Jack Cui
Blog:
http://blog.youkuaiyun.com/c406495762
Modify:
2017-07-24
"""
def createPlot(inTree):
fig = plt.figure(1, facecolor='white') #创建fig
fig.clf() #清空fig
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) #去掉x、y轴
plotTree.totalW = float(getNumLeafs(inTree)) #获取决策树叶结点数目
plotTree.totalD = float(getTreeDepth(inTree)) #获取决策树层数
plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0; #x偏移
plotTree(inTree, (0.5,1.0), '') #绘制决策树
plt.show() #显示绘制结果
if __name__ == '__main__':
dataSet, labels = createDataSet()
featLabels = []
myTree = createTree(dataSet, labels, featLabels)
print(myTree)
createPlot(myTree)
plotNode函数的工作就是绘制各个结点,比如有自己的房子
、有工作
、yes
、no
,包括内结点和叶子结点。plotMidText函数的工作就是绘制各个有向边的属性,例如各个有向边的0
和1
。得到结果如下:
3 使用决策树执行分类
依靠训练数据构造了决策树之后,我们可以将它用于实际数据的分类。在执行数据分类时,需要决策树以及用于构造树的标签向量。然后,程序比较测试数据与决策树上的数值,递归执行该过程直到进入叶子结点;最后将测试数据定义为叶子结点所属的类型。在构建决策树的代码,可以看到,有个featLabels参数。它是用来干什么的?它就是用来记录各个分类结点的,在用决策树做预测的时候,我们按顺序输入需要的分类结点的属性值即可。举个例子,比如我用上述已经训练好的决策树做分类,那么我只需要提供这个人是否有房子,是否有工作这两个信息即可,无需提供冗余的信息。
看代码:
# -*- coding: UTF-8 -*-
from math import log
import operator
"""
函数说明:计算给定数据集的经验熵(香农熵)
Parameters:
dataSet - 数据集
Returns:
shannonEnt - 经验熵(香农熵)
Author:
Jack Cui
Blog:
http://blog.youkuaiyun.com/c406495762
Modify:
2017-07-24
"""
def calcShannonEnt(dataSet):
numEntires = len(dataSet) #返回数据集的行数
labelCounts = {} #保存每个标签(Label)出现次数的字典
for featVec in dataSet: #对每组特征向量进行统计
currentLabel = featVec[-1] #提取标签(Label)信息
if currentLabel not in labelCounts.keys(): #如果标签(Label)没有放入统计次数的字典,添加进去
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1 #Label计数
shannonEnt = 0.0 #经验熵(香农熵)
for key in labelCounts: #计算香农熵
prob = float(labelCounts[key]) / numEntires #选择该标签(Label)的概率
shannonEnt -= prob * log(prob, 2) #利用公式计算
return shannonEnt #返回经验熵(香农熵)
"""
函数说明:创建测试数据集
Parameters:
无
Returns:
dataSet - 数据集
labels - 特征标签
Author:
Jack Cui
Blog:
http://blog.youkuaiyun.com/c406495762
Modify:
2017-07-20
"""
def createDataSet():
dataSet = [[0, 0, 0, 0, 'no'], #数据集
[0, 0, 0, 1, 'no'],
[0, 1, 0, 1, 'yes'],
[0, 1, 1, 0, 'yes'],
[0, 0, 0, 0, 'no'],
[1, 0, 0, 0, 'no'],
[1, 0, 0, 1, 'no'],
[1, 1, 1, 1, 'yes'],
[1, 0, 1, 2, 'yes'],
[1, 0, 1, 2, 'yes'],
[2, 0, 1, 2, 'yes'],
[2, 0, 1, 1, 'yes'],
[2, 1, 0, 1, 'yes'],
[2, 1, 0, 2, 'yes'],
[2, 0, 0, 0, 'no']]
labels = ['年龄', '有工作', '有自己的房子', '信贷情况'] #特征标签
return dataSet, labels #返回数据集和分类属性
"""
函数说明:按照给定特征划分数据集
Parameters:
dataSet - 待划分的数据集
axis - 划分数据集的特征
value - 需要返回的特征的值
Returns:
无
Author:
Jack Cui
Blog:
http://blog.youkuaiyun.com/c406495762
Modify:
2017-07-24
"""
def splitDataSet(dataSet, axis, value):
retDataSet = [] #创建返回的数据集列表
for featVec in dataSet: #遍历数据集
if featVec[axis] == value:
reducedFeatVec = featVec[:axis] #去掉axis特征
reducedFeatVec.extend(featVec[axis+1:]) #将符合条件的添加到返回的数据集
retDataSet.append(reducedFeatVec)
return retDataSet #返回划分后的数据集
"""
函数说明:选择最优特征
Parameters:
dataSet - 数据集
Returns:
bestFeature - 信息增益最大的(最优)特征的索引值
Author:
Jack Cui
Blog:
http://blog.youkuaiyun.com/c406495762
Modify:
2017-07-20
"""
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1 #特征数量
baseEntropy = calcShannonEnt(dataSet) #计算数据集的香农熵
bestInfoGain = 0.0 #信息增益
bestFeature = -1 #最优特征的索引值
for i in range(numFeatures): #遍历所有特征
#获取dataSet的第i个所有特征
featList = [example[i] for example in dataSet]
uniqueVals = set(featList) #创建set集合{},元素不可重复
newEntropy = 0.0 #经验条件熵
for value in uniqueVals: #计算信息增益
subDataSet = splitDataSet(dataSet, i, value) #subDataSet划分后的子集
prob = len(subDataSet) / float(len(dataSet)) #计算子集的概率
newEntropy += prob * calcShannonEnt(subDataSet) #根据公式计算经验条件熵
infoGain = baseEntropy - newEntropy #信息增益
# print("第%d个特征的增益为%.3f" % (i, infoGain)) #打印每个特征的信息增益
if (infoGain > bestInfoGain): #计算信息增益
bestInfoGain = infoGain #更新信息增益,找到最大的信息增益
bestFeature = i #记录信息增益最大的特征的索引值
return bestFeature #返回信息增益最大的特征的索引值
"""
函数说明:统计classList中出现此处最多的元素(类标签)
Parameters:
classList - 类标签列表
Returns:
sortedClassCount[0][0] - 出现此处最多的元素(类标签)
Author:
Jack Cui
Blog:
http://blog.youkuaiyun.com/c406495762
Modify:
2017-07-24
"""
def majorityCnt(classList):
classCount = {}
for vote in classList: #统计classList中每个元素出现的次数
if vote not in classCount.keys():classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1), reverse = True) #根据字典的值降序排序
return sortedClassCount[0][0] #返回classList中出现次数最多的元素
"""
函数说明:创建决策树
Parameters:
dataSet - 训练数据集
labels - 分类属性标签
featLabels - 存储选择的最优特征标签
Returns:
myTree - 决策树
Author:
Jack Cui
Blog:
http://blog.youkuaiyun.com/c406495762
Modify:
2017-07-25
"""
def createTree(dataSet, labels, featLabels):
classList = [example[-1] for example in dataSet] #取分类标签(是否放贷:yes or no)
if classList.count(classList[0]) == len(classList): #如果类别完全相同则停止继续划分
return classList[0]
if len(dataSet[0]) == 1: #遍历完所有特征时返回出现次数最多的类标签
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet) #选择最优特征
bestFeatLabel = labels[bestFeat] #最优特征的标签
featLabels.append(bestFeatLabel)
myTree = {bestFeatLabel:{}} #根据最优特征的标签生成树
del(labels[bestFeat]) #删除已经使用特征标签
featValues = [example[bestFeat] for example in dataSet] #得到训练集中所有最优特征的属性值
uniqueVals = set(featValues) #去掉重复的属性值
for value in uniqueVals: #遍历特征,创建决策树。
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), labels, featLabels)
return myTree
"""
函数说明:使用决策树分类
Parameters:
inputTree - 已经生成的决策树
featLabels - 存储选择的最优特征标签
testVec - 测试数据列表,顺序对应最优特征标签
Returns:
classLabel - 分类结果
Author:
Jack Cui
Blog:
http://blog.youkuaiyun.com/c406495762
Modify:
2017-07-25
"""
def classify(inputTree, featLabels, testVec):
firstStr = next(iter(inputTree)) #获取决策树结点
secondDict = inputTree[firstStr] #下一个字典
featIndex = featLabels.index(firstStr)
for key in secondDict.keys():
if testVec[featIndex] == key:
if type(secondDict[key]).__name__ == 'dict':
classLabel = classify(secondDict[key], featLabels, testVec)
else: classLabel = secondDict[key]
return classLabel
if __name__ == '__main__':
dataSet, labels = createDataSet()
featLabels = []
myTree = createTree(dataSet, labels, featLabels)
testVec = [0,1] #测试数据
result = classify(myTree, featLabels, testVec)
if result == 'yes':
print('放贷')
if result == 'no':
print('不放贷')
这里只增加了classify函数,用于决策树分类。输入测试数据[0,1]
,它代表没有房子,但是有工作,结果结果如下所示:
但是,如果按照原来的数据集进行分类的话,我们都得对原有训练数据来创建决策树,这样是非常麻烦的,所以我们希望提供一种可以存储决策树的方法,不必每次修改测试数据都对决策树重建,浪费资源与时间。
3.1 决策树的存储
构造决策树是很耗时的任务,即使处理很小的数据集,如前面的样本数据,也要花费几秒的时间,如果数据集很大,将会耗费很多计算时间。然而用创建好的决策树解决分类问题,则可以很快完成。因此,为了节省计算时间,最好能够在每次执行分类时调用已经构造好的决策树。为了解决这个问题,需要使用Python模块pickle序列化对象。序列化对象可以在磁盘上保存对象,并在需要的时候读取出来。
假设我们已经得到决策树{‘有自己的房子’: {0: {‘有工作’: {0: ‘no’, 1: ‘yes’}}, 1: ‘yes’}},使用pickle.dump存储决策树。
看代码:
# -*- coding: UTF-8 -*-
import pickle
"""
函数说明:存储决策树
Parameters:
inputTree - 已经生成的决策树
filename - 决策树的存储文件名
Returns:
无
Author:
Jack Cui
Blog:
http://blog.youkuaiyun.com/c406495762
Modify:
2017-07-25
"""
def storeTree(inputTree, filename):
with open(filename, 'wb') as fw:
pickle.dump(inputTree, fw)
if __name__ == '__main__':
myTree = {'有自己的房子': {0: {'有工作': {0: 'no', 1: 'yes'}}, 1: 'yes'}}
storeTree(myTree, 'classifierStorage.txt')
在存储好决策树之后,可以用pickle.load方法进行载入。
代码如下:
# -*- coding: UTF-8 -*-
import pickle
"""
函数说明:读取决策树
Parameters:
filename - 决策树的存储文件名
Returns:
pickle.load(fr) - 决策树字典
Author:
Jack Cui
Blog:
http://blog.youkuaiyun.com/c406495762
Modify:
2017-07-25
"""
def grabTree(filename):
fr = open(filename, 'rb')
return pickle.load(fr)
if __name__ == '__main__':
myTree = grabTree('classifierStorage.txt')
print(myTree)
4.使用决策树预测隐形眼镜的类型(实战)
4.1 背景
隐形眼镜数据集是非常著名的数据集,它包含很多换着眼部状态的观察条件以及医生推荐的隐形眼镜类型。隐形眼镜类型包括硬材质(hard)、软材质(soft)以及不适合佩戴隐形眼镜(no lenses)。
数据集如下:
一共有24组数据,数据的Labels依次是age
、prescript
、astigmatic
、tearRate
、class
,也就是第一列是年龄,第二列是症状,第三列是是否散光,第四列是眼泪数量,第五列是最终的分类标签。
接下来可以直接使用Python程序构建决策树进行分类。也可以使用Sklearn构建决策树
4.2 Sklearn构建决策树
sklearn.tree模块提供了决策树模型,用于解决分类问题和回归问题。
官方英文文档地址:http://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html
本次实战内容使用的是DecisionTreeClassifier和export_graphviz,前者用于决策树构建,后者用于决策树可视化。
-
使用DecisionTreeClassifier构建决策树
一般来说,直接通过读取文件获取数据集,取出标签再直接调用函数即可。如下:
# -*- coding: UTF-8 -*- from sklearn import tree if __name__ == '__main__': fr = open('lenses.txt') lenses = [inst.strip().split('\t') for inst in fr.readlines()] print(lenses) lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate'] clf = tree.DecisionTreeClassifier() lenses = clf.fit(lenses, lensesLabels)
但结果却不尽人意,因为fit函数不能接收String类型的数据,通过打印的信息可以看到,数据都是
string
类型的。在使用fit()函数之前,我们需要对数据集进行编码。为了对String类型的数据序列化,需要先生成pandas数据,方便我们序列化。这里我使用的方法是,原始数据->字典->pandas数据,编写代码如下:# -*- coding: UTF-8 -*- import pandas as pd if __name__ == '__main__': with open('lenses.txt', 'r') as fr: #加载文件 lenses = [inst.strip().split('\t') for inst in fr.readlines()] #处理文件 lenses_target = [] #提取每组数据的类别,保存在列表里 for each in lenses: lenses_target.append(each[-1]) lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate'] #特征标签 lenses_list = [] #保存lenses数据的临时列表 lenses_dict = {} #保存lenses数据的字典,用于生成pandas for each_label in lensesLabels: #提取信息,生成字典 for each in lenses: lenses_list.append(each[lensesLabels.index(each_label)]) lenses_dict[each_label] = lenses_list lenses_list = [] print(lenses_dict) #打印字典信息 lenses_pd = pd.DataFrame(lenses_dict) #生成pandas.DataFrame print(lenses_pd)
记得安装pandas模块,我的是Python3.8版本的,可以直接百度下载一个
执行代码看结果可知,这样就生成了pandas数据
然后再进行数据序列化,代码如下:
# -*- coding: UTF-8 -*- import pandas as pd from sklearn.preprocessing import LabelEncoder import pydotplus from six import StringIO##在0.21版本被弃用,0.23版本被删除,所以直接安装six包,然后改成six导入就可 if __name__ == '__main__': with open('lenses.txt', 'r') as fr: #加载文件 lenses = [inst.strip().split('\t') for inst in fr.readlines()] #处理文件 lenses_target = [] #提取每组数据的类别,保存在列表里 for each in lenses: lenses_target.append(each[-1]) lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate'] #特征标签 lenses_list = [] #保存lenses数据的临时列表 lenses_dict = {} #保存lenses数据的字典,用于生成pandas for each_label in lensesLabels: #提取信息,生成字典 for each in lenses: lenses_list.append(each[lensesLabels.index(each_label)]) lenses_dict[each_label] = lenses_list lenses_list = [] # print(lenses_dict) #打印字典信息 lenses_pd = pd.DataFrame(lenses_dict) #生成pandas.DataFrame print(lenses_pd) #打印pandas.DataFrame le = LabelEncoder() #创建LabelEncoder()对象,用于序列化 for col in lenses_pd.columns: #为每一列序列化 lenses_pd[col] = le.fit_transform(lenses_pd[col]) print(lenses_pd)
这样就得到了结果:
接下来就可以使用fit()函数构建决策树了,在使用如上代码之前记得安装sklearn库(版本低一点0.22版本就好)和six库、pydotplus库、pandas库还有安装sklearn库必须用到的scipy库。
-
使用Graphviz可视化决策树
安装方法和步骤都可以直接参照https://blog.youkuaiyun.com/c406495762/article/details/76262487
安装成功并配置完成之后,直接上代码:
# -*- coding: UTF-8 -*- from sklearn.preprocessing import LabelEncoder, OneHotEncoder from six import StringIO#直接修改成six,因为scikit_learn在21版本就弃用,在23版本就被删除,现在装的版本是22 from sklearn import tree import pandas as pd import numpy as np import pydotplus if __name__ == '__main__': with open('lenses.txt', 'r') as fr: #加载文件 lenses = [inst.strip().split('\t') for inst in fr.readlines()] #处理文件 lenses_target = [] #提取每组数据的类别,保存在列表里 for each in lenses: lenses_target.append(each[-1]) print(lenses_target) lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate'] #特征标签 lenses_list = [] #保存lenses数据的临时列表 lenses_dict = {} #保存lenses数据的字典,用于生成pandas for each_label in lensesLabels: #提取信息,生成字典 for each in lenses: lenses_list.append(each[lensesLabels.index(each_label)]) lenses_dict[each_label] = lenses_list lenses_list = [] # print(lenses_dict) #打印字典信息 lenses_pd = pd.DataFrame(lenses_dict) #生成pandas.DataFrame # print(lenses_pd) #打印pandas.DataFrame le = LabelEncoder() #创建LabelEncoder()对象,用于序列化 for col in lenses_pd.columns: #序列化 lenses_pd[col] = le.fit_transform(lenses_pd[col]) # print(lenses_pd) #打印编码信息 clf = tree.DecisionTreeClassifier(max_depth = 4) #创建DecisionTreeClassifier()类 clf = clf.fit(lenses_pd.values.tolist(), lenses_target) #使用数据,构建决策树 dot_data = StringIO() tree.export_graphviz(clf, out_file = dot_data, #绘制决策树 feature_names = lenses_pd.keys(), class_names = clf.classes_, filled=True, rounded=True, special_characters=True) graph = pydotplus.graph_from_dot_data(dot_data.getvalue()) graph.write_pdf("tree.pdf") print(clf.predict([[1,1,1,0]]))#预测
会生成一个树形图pdf如下:
然后进行预测,输入测试数据在最后一行代码,结果为hard,硬材质。
学习总结
- 这节内容概念虽然通俗易懂,但是代码还是存在很大的问题,需要细抠。
- 代码掌握不了就先放一放,先懂算法,代码先看个大概,会用就好。后期做项目再好好看。