决策树介绍
什么是决策树呢?它有什么特点呢?
-
决策树是一种树形结构(可以是二叉树或非二叉树),其组成包括结点和有向边
-
而结点有两种类型,分别是内部结点和叶结点
-
每个内部结点表示一个特征属性上的测试,每个叶节点表示一个类别
-
决策过程:从根节点开始一步步走到叶子结点,将叶子节点存放的类别作为决策结果
根据下面例子做决策
-
Jack,年龄24岁,男性,无子女,最终决策:生存
标准衡量-熵
-
熵是表示随机变量不确定性的度量
-
公式:
-
解释:
-
熵:信息熵越大,表示该随机变量的不确定性越高。
-
当p=0或p=1时,H(p)=0,随机变量完全没有不确定性
-
当p=0.5时,H(p)=1,此时随机变量的不确定性最大
-
例子1:A集合[1,1,1,1,1,1,1,1,2,2] B集合[1,2,3,4,5,6,7,8,9,1]
显然A的熵更小,因为A中的种类数量少,所以更稳定
例子2:
同理,图1更稳定,熵更小
信息熵的计算
-
数据集:根据14天的天气因素的打球情况
-
属性id表示每个样本的编号。
-
属性outlook表示户外天气。sunny晴天,overcast阴天,rainy雨天。
-
属性temperature表示温度,hot热,mild温暖,cool冷。
-
属性humidity表示湿度。high高,normal正常。
-
属性windy表示是否有风。not没有,yes有。
-
属性play表示是否出去玩。yes出去玩,no不出去玩。
-
1. 计算变量Play的信息熵
该数据集总样本14个,play变量的取值只能是no或yes
变量play的信息熵计算如下所示。
2. 计算变量Outlook的信息熵
该数据集总样本14个,outlook变量的取值只能是overcast或rainy或sunny。
变量outlook的信息熵计算如下所示
条件熵
条件熵用于表示在已知某一条件下,随机事件的不确定性或信息量。它通常用H(Y|X)表示,表示在已知随机变量X的条件下,随机变量Y的不确定性。数学上,条件熵可以用以下公式来表示:
条件熵的计算
对属性Outlook分析并计算如下。
信息增益
-
g(D,A)表示在条件A下对于目标变量D的信息增益。
-
H(D)表示随机变量D的信息熵。
-
H(D|A)表示在随机变量A条件下对于目标变量D的条件熵。
信息增益的计算
计算g(play,outlook),表示在随机变量outlook条件下对于目标变量play的信息增益,计算步骤如下。
-
完整公式:g(play,outlook)=H(play)-H(play|outlook)
-
首先要计算H(play),计算式如下所示。
-
然后计算H(play|outlook),计算式如下。
-
最后计算g(play,outlook),计算式如下。
算法流程
-
初始化:首先,算法将所有训练样本集放在根节点。
-
特征选择:对于当前节点,计算所有候选特征的信息增益。选择信息增益最大的特征作为当前节点的分裂特征。
-
节点分裂:根据所选特征的每个不同取值,将当前节点划分为多个子节点。每个子节点包含该特征取值下对应的所有样本。
-
递归构建:对于每个子节点,递归地执行步骤2和步骤3,直到满足停止条件(如所有样本属于同一类别、没有更多特征可供选择等)。
-
构建完成:最终,当所有节点都无法再进一步划分时,决策树构建完成。
决策树构建例题
-
数据集采用游玩数据集,由于样本数据较简单,例题并没有考虑设置阈值。
-
初始化,构建根节点。具体构建方法如下图(3-1)所示。
解释:计算出随机变量play的信息熵H(play),再计算出每个特征的条件熵,得出每个特征的信息增益,选择最大的信息增益对应的属性为根节点,然后对根节点分裂,出现3条子枝。
2. 递归构建,构建图3-1的D1。具体构建方法如下图(3-2)所示。
解释:上图是构建图3-1的D1,由于图3-1的D1表示的是数据集D在outlook=rainy的条件下的新数据集,D1数据集中的outlook属性都是rainy,故不需要再计算g(play,outlook)。
3. 递归构建,构建图3-2的D2。具体构建方法如下图(3-3)所示。
解释:图3-3的windy节点构建完毕,递归构建humidity节点,仍按照算法流程计算信息增益。
构建完成。到此决策树已经构造完成。由于所给数据集构造的决策树较简单,相对于其他数据集可能并非如此,在构造复杂的决策树时,对每个子集重复上述方法,直到满足停止条件。
案例代码实现
数据集
先对数据集进行属性标注:
-
年龄:0代表青年,1代表中年,2代表老年;
-
有工作:0代表否,1代表是;
-
有自己的房子:0代表否,1代表是;
-
信贷情况:0代表一般,1代表好,2代表非常好;
-
类别(是否给贷款):no代表否,yes代表是。
主要代码
# -*- coding: UTF-8 -*-
from math import log
import operator
# 函数说明:创建测试数据集
def createDataSet():
dataSet = [[0, 0, 0, 0, 'no'], # 数据集
[0, 0, 0, 1, 'no'],
[0, 1, 0, 1, 'yes'],
[0, 1, 1, 0, 'yes'],
[0, 0, 0, 0, 'no'],
[1, 0, 0, 0, 'no'],
[1, 0, 0, 1, 'no'],
[1, 1, 1, 1, 'yes'],
[1, 0, 1, 2, 'yes'],
[1, 0, 1, 2, 'yes'],
[2, 0, 1, 2, 'yes'],
[2, 0, 1, 1, 'yes'],
[2, 1, 0, 1, 'yes'],
[2, 1, 0, 2, 'yes'],
[2, 0, 0, 0, 'no']]
labels = ['年龄', '有工作', '有自己的房子', '信贷情况'] # 分类属性
return dataSet, labels # 返回数据集和分类属性
"""
函数说明:计算给定数据集的经验熵(香农熵)
Parameters:
dataSet - 数据集
Returns:
shannonEnt - 经验熵(香农熵)
"""
def calcShannonEnt(dataSet):
numEntires = len(dataSet) # 返回数据集的行数
labelCounts = {} # 保存每个标签(Label)出现次数的字典
for featVec in dataSet: # 对每组特征向量进行统计
currentLabel = featVec[-1] # 提取标签(Label)信息
if currentLabel not in labelCounts.keys(): # 如果标签(Label)没有放入统计次数的字典,添加进去
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1 # Label计数
shannonEnt = 0.0 # 经验熵(香农熵)
for key in labelCounts: # 计算香农熵
prob = float(labelCounts[key]) / numEntires # 选择该标签(Label)的概率
shannonEnt -= prob * log(prob, 2) # 利用公式计算
return shannonEnt # 返回经验熵(香农熵)
"""
函数说明:按照给定特征划分数据集
Parameters:
dataSet - 待划分的数据集
axis - 划分数据集的特征
value - 需要返回的特征的值
"""
def splitDataSet(dataSet, axis, value):
retDataSet = [] # 创建返回的数据集列表
for featVec in dataSet: # 遍历数据集
if featVec[axis] == value:
reducedFeatVec = featVec[:axis] # 去掉axis特征
reducedFeatVec.extend(featVec[axis + 1:]) # 将符合条件的添加到返回的数据集
retDataSet.append(reducedFeatVec)
return retDataSet # 返回划分后的数据集
"""
函数说明:选择最优特征
Parameters:
dataSet - 数据集
Returns:
bestFeature - 信息增益最大的(最优)特征的索引值
"""
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1 # 特征数量
baseEntropy = calcShannonEnt(dataSet) # 计算数据集的香农熵
bestInfoGain = 0.0 # 信息增益
bestFeature = -1 # 最优特征的索引值
for i in range(numFeatures): # 遍历所有特征
# 获取dataSet的第i个所有特征
featList = [example[i] for example in dataSet]
uniqueVals = set(featList) # 创建set集合{},元素不可重复
newEntropy = 0.0 # 经验条件熵
for value in uniqueVals: # 计算信息增益
subDataSet = splitDataSet(dataSet, i, value) # subDataSet划分后的子集
prob = len(subDataSet) / float(len(dataSet)) # 计算子集的概率
newEntropy += prob * calcShannonEnt(subDataSet) # 根据公式计算经验条件熵
infoGain = baseEntropy - newEntropy # 信息增益
print("第%d个特征的增益为%.3f" % (i, infoGain)) # 打印每个特征的信息增益
if (infoGain > bestInfoGain): # 计算信息增益
bestInfoGain = infoGain # 更新信息增益,找到最大的信息增益
bestFeature = i # 记录信息增益最大的特征的索引值
return bestFeature # 返回信息增益最大的特征的索引值
"""
函数说明:统计classList中出现此处最多的元素(类标签)
Parameters:
classList - 类标签列表
Returns:
sortedClassCount[0][0] - 出现此处最多的元素(类标签)
"""
def majorityCnt(classList):
classCount = {}
for vote in classList: # 统计classList中每个元素出现的次数
if vote not in classCount.keys():
classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True) # 根据字典的值降序排序
return sortedClassCount[0][0] # 返回classList中出现次数最多的元素
"""
函数说明:递归构建决策树
Parameters:
dataSet - 训练数据集
labels - 分类属性标签
featLabels - 存储选择的最优特征标签
Returns:
myTree - 决策树
"""
def createTree(dataSet, labels, featLabels):
classList = [example[-1] for example in dataSet] # 取分类标签(是否放贷:yes or no)
if classList.count(classList[0]) == len(classList): # 如果类别完全相同则停止继续划分
return classList[0]
if len(dataSet[0]) == 1: # 遍历完所有特征时返回出现次数最多的类标签
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet) # 选择最优特征
bestFeatLabel = labels[bestFeat] # 最优特征的标签
featLabels.append(bestFeatLabel)
myTree = {bestFeatLabel: {}} # 根据最优特征的标签生成树
del (labels[bestFeat]) # 删除已经使用特征标签
featValues = [example[bestFeat] for example in dataSet] # 得到训练集中所有最优特征的属性值
uniqueVals = set(featValues) # 去掉重复的属性值
for value in uniqueVals:
subLabels = labels[:]
# 递归调用函数createTree(),遍历特征,创建决策树。
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels, featLabels)
return myTree
"""
函数说明:使用决策树执行分类
Parameters:
inputTree - 已经生成的决策树
featLabels - 存储选择的最优特征标签
testVec - 测试数据列表,顺序对应最优特征标签
Returns:
classLabel - 分类结果
"""
def classify(inputTree, featLabels, testVec):
firstStr = next(iter(inputTree)) # 获取决策树结点
secondDict = inputTree[firstStr] # 下一个字典
featIndex = featLabels.index(firstStr)
for key in secondDict.keys():
if testVec[featIndex] == key:
if type(secondDict[key]).__name__ == 'dict':
classLabel = classify(secondDict[key], featLabels, testVec)
else:
classLabel = secondDict[key]
return classLabel
if __name__ == '__main__':
dataSet, labels = createDataSet()
featLabels = []
myTree = createTree(dataSet, labels, featLabels)
print(myTree)
testVec = [0, 1] # 测试数据
result = classify(myTree, featLabels, testVec)
if result == 'yes':
print('放贷')
if result == 'no':
print('不放贷')
画图
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
from matplotlib.font_manager import FontProperties
import matplotlib.pyplot as plt
# 定义文本框和箭头格式
decisionNode = dict(boxstyle='sawtooth', fc='0.8')
leafNode = dict(boxstyle='round4', fc='0.8')
arrow_args = dict(arrowstyle='<-')
# 设置中文字体
font = FontProperties(fname=r"c:\windows\fonts\simsun.ttc", size=14)
"""
函数说明:获取决策树叶子结点的数目
Parameters:
myTree - 决策树
Returns:
numLeafs - 决策树的叶子结点的数目
"""
def getNumLeafs(myTree):
numLeafs = 0 # 初始化叶子
# python3中myTree.keys()返回的是dict_keys,不在是list,所以不能使用myTree.keys()[0]的方法获取结点属性,
# 可以使用list(myTree.keys())[0]
firstStr = next(iter(myTree))
secondDict = myTree[firstStr] # 获取下一组字典
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict': # 测试该结点是否为字典,如果不是字典,代表此结点为叶子结点
numLeafs += getNumLeafs(secondDict[key])
else:
numLeafs += 1
return numLeafs
"""
函数说明:获取决策树的层数
Parameters:
myTree - 决策树
Returns:
maxDepth - 决策树的层数
"""
def getTreeDepth(myTree):
maxDepth = 0 # 初始化决策树深度
# python3中myTree.keys()返回的是dict_keys,不在是list,所以不能使用myTree.keys()[0]的方法获取结点属性,
# 可以使用list(myTree.keys())[0]
firstStr = next(iter(myTree))
secondDict = myTree[firstStr] # 获取下一个字典
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict': # 测试该结点是否为字典,如果不是字典,代表此结点为叶子结点
thisDepth = 1 + getTreeDepth(secondDict[key])
else:
thisDepth = 1
if thisDepth > maxDepth:
maxDepth = thisDepth # 更新层数
return maxDepth
"""
函数说明:绘制结点
Parameters:
nodeTxt - 结点名
centerPt - 文本位置
parentPt - 标注的箭头位置
nodeType - 结点格式
"""
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
arrow_args = dict(arrowstyle="<-") # 定义箭头格式
font = FontProperties(fname=r"c:\windows\fonts\simsun.ttc", size=14) # 设置中文字体
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', # 绘制结点
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args, fontproperties=font)
"""
函数说明:标注有向边属性值
Parameters:
cntrPt、parentPt - 用于计算标注位置
txtString - 标注的内容
"""
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0] # 计算标注位置
yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
"""
函数说明:绘制决策树
Parameters:
myTree - 决策树(字典)
parentPt - 标注的内容
nodeTxt - 结点名
"""
def plotTree(myTree, parentPt, nodeTxt):
decisionNode = dict(boxstyle="sawtooth", fc="0.8") # 设置结点格式
leafNode = dict(boxstyle="round4", fc="0.8") # 设置叶结点格式
numLeafs = getNumLeafs(myTree) # 获取决策树叶结点数目,决定了树的宽度
depth = getTreeDepth(myTree) # 获取决策树层数
firstStr = next(iter(myTree)) # 下个字典
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff) # 中心位置
plotMidText(cntrPt, parentPt, nodeTxt) # 标注有向边属性值
plotNode(firstStr, cntrPt, parentPt, decisionNode) # 绘制结点
secondDict = myTree[firstStr] # 下一个字典,也就是继续绘制子结点
plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD # y偏移
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict': # 测试该结点是否为字典,如果不是字典,代表此结点为叶子结点
plotTree(secondDict[key], cntrPt, str(key)) # 不是叶结点,递归调用继续绘制
else: # 如果是叶结点,绘制叶结点,并标注有向边属性值
plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
"""
函数说明:创建绘制面板
Parameters:
inTree - 决策树(字典)
"""
def createPlot(inTree):
fig = plt.figure(1, facecolor='white') # 创建fig
fig.clf() # 清空fig
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) # 去掉x、y轴
plotTree.totalW = float(getNumLeafs(inTree)) # 获取决策树叶结点数目
plotTree.totalD = float(getTreeDepth(inTree)) # 获取决策树层数
plotTree.xOff = -0.5 / plotTree.totalW;
plotTree.yOff = 1.0; # x偏移
plotTree(inTree, (0.5, 1.0), '') # 绘制决策树
plt.show()
if __name__ == '__main__':
mytree = {'有自己的房子': {0: {'有工作': {0: 'no', 1: 'yes'}}, 1: 'yes'}}
createPlot(mytree)
运行结果:
用matplotlib可视化
算法优缺点
优点:
-
算法原理简单,易于理解。
-
能够处理分类任务,并且分类速度快。
-
决策树模型的可解释性强,便于理解和分析。
缺点:
-
只能处理离散值特征,对于连续值特征需要预先处理(如离散化)。
-
对缺失值敏感,需要进行额外的处理(如填充缺失值)。
-
倾向于选择取值较多的特征作为分裂特征,这可能导致模型不够稳定。
-
容易发生过拟合,特别是当决策树过于复杂时。
算法的改进
ID3算法虽然在构建决策树方面取得了成功,但它也有一些局限性,比如倾向于选择具有更多值的特征(称为“数值偏差”),以及在处理连续属性和缺失数据时的不足。为了解决这些问题,研究者们提出了多种改进方法,以下是一些主要的改进方向:
-
C4.5算法:
-
C4.5算法是ID3的一个改进版本,它通过使用信息增益比(Information Gain Ratio)代替纯信息增益来减少对具有更多值的特征的偏见。
-
C4.5还能够处理连续属性和缺失数据,使其适用范围更广。
-
-
CART(Classification and Regression Trees):
-
CART算法是另一种流行的决策树算法,它使用基尼不纯度(Gini Impurity)或信息增益作为分裂标准,并且可以处理分类和回归问题。
-
CART可以处理连续属性,通过在连续属性的中间值进行分裂。
-
-
处理连续属性:
-
连续属性需要在某个点上进行分裂,改进的方法包括使用二分法或其他数值方法来找到最佳分裂点。
-
-
处理缺失数据:
-
改进的算法可以通过估计缺失值、使用期望值或者通过分裂时考虑不同缺失数据的情况来处理缺失数据。
-
-
多目标决策树:
-
在多目标决策树中,考虑多个目标或标准来构建树,这样可以在多个指标上取得平衡。
-
-
使用先验知识:
-
将领域专家的先验知识整合到决策树构建过程中,可以帮助算法更好地捕捉数据中的重要特征。
-
-
集成方法:
-
通过构建多个决策树并使用集成方法(如随机森林)来提高预测的准确性和鲁棒性。
-
-
改进的剪枝技术:
-
为了防止过拟合,改进的剪枝技术可以在构建树之后去除一些不必要的分支,以提高模型的泛化能力。
-
-
模糊决策树:
-
引入模糊逻辑来处理数据中的不确定性和模糊性,使决策树更加灵活和适应性强。
-
-
优化搜索策略:
-
使用更高效的搜索策略,如遗传算法或粒子群优化,来选择最佳的特征和分裂点。
-
应用与展望
ID3已被广泛应用于各个领域,包括但不限于:
-
医疗数据分析:在医疗领域,ID3算法可以帮助分析患者的历史数据,从而预测患者的治疗结果。通过分析各种医疗指标,如年龄、性别、病史等,算法能够辅助医生做出更准确的诊断和治疗方案。
-
教育环境:在教育领域,ID3算法能够预测学生的学习表现,并根据这些预测结果适应性地推荐学习路径。这有助于教育工作者更好地理解学生的学习模式,从而提供个性化的教学支持。
-
农业数据分析:在农业领域,ID3算法被用于预测作物产量和检测植物疾病。通过对气候条件、土壤质量、作物种类等数据的分析,算法可以帮助农民做出更合理的种植决策。
关于ID3算法的论文
-
基于简化信息熵和协调度的ID3算法改进
该论文提出了一个基于简化信息熵和协调度的ID3算法改进方法。通过这种方法,能够更有效地处理具有多值属性的数据集,改善了决策树的结构和运行时间。特别是在大数据样本集上,新算法表现出更优的结构和运行效率。
-
An Improved ID3 Classification Algorithm Based On Correlation Function and Weighted Attribute
这篇文章提出了一种基于相关函数和加权属性的改进ID3分类算法。通过引入加权属性和相关函数优化信息增益的计算,使得在处理具有相关属性的数据集时,能够更准确地选择分裂属性,从而提高决策树的构建效率和分类准确率。
-
Semantic decision Trees: A new learning system for the ID3-Based algorithm using a knowledge base
决策树是分类信息的重要工具,但其效率和准确性受输入数据和构建算法的影响。本研究提出了语义决策树(SDT),基于ID3算法,通过知识库解决多值偏差问题,并采用本体技术优化分裂条件。SDT在四个数据集上的测试显示,其准确性优于传统算法,且结构更符合人类决策逻辑。