手写代码实现基于信息熵划分的决策树算法
1. 简介
阅读本文需要以下背景知识:
-掌握周志华《西瓜书》第四章决策树原理
-Python3.0基础语法及数据类型及操作
不了解决策树请点击下面链接西瓜书第四章决策树学习笔记
本文是基于信息熵准则进行划分选择的决策树算法的手写实现,不使用现有的机器学习包。算法流程见《西瓜书》第四章第一节。数据集使用西瓜数据集3.0(数据集在代码中不需要另外下载),实现语言为Python3.0。代码注解详细,适合新手,欢迎转载
2. 算法实现思路
算法流程是现成的,关键是如何把数据集嵌入到算法中并实现递归,我的思路如下:
对决策树不同功能进行划分,每个功能封装成函数,不同功能的函数有
-def createDataSet() #对数据集进行加工,返回数据集dataSet和特征集labels
-def get_Value(dataSet, labels) #以字典labelsCounts返回数据集dataSet中所有的特征,和对应特征的所有取值
-def calcShannonEnt(dataSet) #计算dataSet的信息熵。返回信息熵数值
-def chooseBestFeatureToSplit(dataSet) #计算出信息增益,选择信息增益最大的特征作为最优划分属性。返回最优属性在特征集labels中的索引
-def splitDataSet(dataSet, bestFeat, value) #由给定的父数据集dataSet,最优特征 bestFeat,和最优特征的取值value(由labelsCounts获得)划分出数据子集,返回数据子集
-def majorityCnt(classList) #输入数据集dataSet的类别标签列classList得到在数据集dataSet中类别最多的样本的类别名(字符串)
-def createTree(dataSet, labels, labelscounts) #这是一个递归函数,输入数据集dataSet,特征集labels和所有特征取值字典labelscounts得到一个具有一层分支的树,要是这层分支中每个子集subdataSet都是叶节点,创建字典,以被划分的最优属性的取值value为键,对应这个取值的叶节点类型为值(叶节点判定标准:集合中样本都相同标签也相同标为叶节点,叶类型为集合中样本标签;集合中样本都相同但是标签不同标为叶节点,叶类型为集合中众数样本类别;集合为空集标为叶结点,叶类别为其父节点众数样本类别)。若这层分支中不全为叶节点,还有内部节点。则对于叶节点,创建字典,以被划分的最优属性的取值为键,对应这个取值的叶节点类型为值。对于内部节点,把这个子集subdataSet作为新的父集,以新父集的划分最优属性键,值是一个字典,并调用函数def createTree(subdataSet, sublabels, labelscounts)完成递归。返回一个以字典形式存储的决策树
-treePlotter.createPlot(desicionTree) #调用库函数将决策树绘出,treePlotter包是自定义包,代码及使用方法见此treePlotter
3.代码如下
#基于ID3算法的信息增益来实现的决策树
#调用库
from math import log
import operator
import treePlotter #自定义包,包和源程序应在同一文件夹,包代码见链接
'''
西瓜数据集3.0,
dataset=[
# 1
['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', 0.697, 0.460, '好瓜'],
# 2
['乌黑', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', 0.774, 0.376, '好瓜'],
# 3
['乌黑', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', 0.634, 0.264, '好瓜'],
# 4
['青绿', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', 0.608, 0.318, '好瓜'],
# 5
['浅白', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', 0.556, 0.215, '好瓜'],
# 6
['青绿', '稍蜷', '浊响', '清晰', '稍凹', '软粘', 0.403, 0.237, '好瓜'],
# 7
['乌黑', '稍蜷', '浊响', '稍糊', '稍凹', '软粘', 0.481, 0.149, '好瓜'],
# 8
['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '硬滑', 0.437, 0.211, '好瓜'],
# ----------------------------------------------------
# 9
['乌黑', '稍蜷', '沉闷', '稍糊', '稍凹', '硬滑', 0.666, 0.091, '坏瓜'],
# 10
['青绿', '硬挺', '清脆', '清晰', '平坦', '软粘', 0.243, 0.267, '坏瓜'],
# 11
['浅白', '硬挺', '清脆', '模糊', '平坦', '硬滑', 0.245, 0.057, '坏瓜'],
# 12
['浅白', '蜷缩', '浊响', '模糊', '平坦', '软粘', 0.343, 0.099, '坏瓜'],
# 13
['青绿', '稍蜷', '浊响', '稍糊', '凹陷', '硬滑', 0.639, 0.161, '坏瓜'],
# 14
['浅白', '稍蜷', '沉闷', '稍糊', '凹陷', '硬滑', 0.657, 0.198, '坏瓜'],
# 15
['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '软粘', 0.360, 0.370, '坏瓜'],
# 16
['浅白', '蜷缩', '浊响', '模糊', '平坦', '硬滑', 0.593, 0.042, '坏瓜'],
# 17
['青绿', '蜷缩', '沉闷', '稍糊', '稍凹', '硬滑', 0.719, 0.103, '坏瓜']
]
'''
#导入数据,数据集有八个特征 '色泽', '根蒂', '敲声', '纹理','脐部','触感','密度','含糖率' ,
#其中密度和含糖率是连续值,为了简略程序,我们忽略他们。为接下来要计算它们的信息增益率,来选择节点的构成方式做准备。
def createDataSet():
"""
对数据集进行一定处理,以方便显示,不出现乱码
色泽Color-> 0: 浅白 | 1: 青绿 | 2: 乌黑
根蒂Root-> 0: 硬挺 | 1: 稍蜷 | 2: 蜷缩
敲声Knock-> 0: 清脆 | 1: 浊响 | 2:沉闷
纹理Texture-> 0: 清晰 | 1: 稍糊 | 2:模糊
脐部Umbilical-> 0: 平坦 | 1: 稍凹 | 2: 凹陷
触感Touch-> 0: 硬滑 | 1: 软粘
标签lab->'GoodMalen'| 'BadMalen'
"""
dataSet = [[1, 2, 1, 0, 2, 0, 'GoodMalen'],
[2, 2, 2, 0, 2, 0, 'GoodMalen'],
[2, 2, 1, 0, 2, 0, 'GoodMalen'],
[1, 2, 2, 0, 2, 0, 'GoodMalen'],
[0, 2, 1, 0, 2, 0, 'GoodMalen'],
[1, 1, 1, 0, 1, 1, 'GoodMalen'],
[2, 1, 1, 1, 1, 1, 'GoodMalen'],
[2, 1, 1, 0, 1, 0, 'GoodMalen'],
[2, 1, 2, 1, 1