Theory
Find the most discriminative feature and split on it first; then, like a tree, keep branching on further features, branch after branch, until the sample is finally assigned to a class.
Pros
- Low computational complexity
- Output is easy to interpret, and the internal decision logic is visible
- Insensitive to missing values
- Can handle irrelevant features
Cons
- Prone to overfitting
Applicable to
- Discrete (categorical) data
- Continuous data, after discretization
Summary:
Split the data set by information gain.
The feature with the highest information gain is the best one to split on.
Information gain: the change in information before and after splitting the data set.
Information is measured by entropy, the expected value of the information content, which quantifies how disordered a set is (Gini impurity is an alternative measure).
Entropy formula:

H = -\sum_{i=1}^{n} p(x_i) \log_2 p(x_i)

where p(x_i) is the probability of choosing class x_i, and n is the number of classes.
Example
dataSet has two classes, yes and no, with 5 samples in total: 2 yes and 3 no.
H = -(0.4 log2(0.4) + 0.6 log2(0.6)) ≈ 0.971
from math import log

def calcShannonEnt(dataSet):
    '''
    Compute the Shannon entropy of the data set.
    '''
    numEntries = len(dataSet)
    labelCounts = {}  # count of samples per class label
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
        # equivalently: labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries  # probability of each class
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
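As a quick sanity check against the hand computation above (a usage sketch assuming the five-sample data set from the example):

# 2 'yes' and 3 'no' samples should give H ≈ 0.971
dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
print(calcShannonEnt(dataSet))  # 0.9709505944546686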
Information gain
The information gain of feature A on training set D, written g(D, A), is defined as the difference between the entropy of set D, H(D), and the conditional entropy of D given feature A, H(D|A):

g(D, A) = H(D) - H(D|A)
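For example, splitting the five-sample data set above on its first feature (a hand computation on the same example data: the subset with value 1 holds 2 yes and 1 no, the subset with value 0 holds 2 no):

H(D) = 0.971
H(D|A) = (3/5) × 0.918 + (2/5) × 0 = 0.551
g(D, A) = 0.971 - 0.551 = 0.420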
Computing entropy and information gain
Splitting the data set by a given feature
def splitDataSet(dataSet, axis, value):
    '''
    dataSet : the data set
    axis    : index of the feature to split on
    value   : the feature value to select
    Splits the data set on the given feature: returns every sample
    whose value at position axis equals value, with that feature column removed.
    '''
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            # list.extend vs. list.append: append adds its argument as a single
            # element, while extend adds each element of the argument individually
            retDataSet.append(reducedFeatVec)
    return retDataSet
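A usage sketch, assuming the five-sample data set from the entropy example; note that the selected feature column is removed from the returned samples:

dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
splitDataSet(dataSet, 0, 1)  # [[1, 'yes'], [1, 'yes'], [0, 'no']]
splitDataSet(dataSet, 0, 0)  # [[1, 'no'], [1, 'no']]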
Choosing the "best" feature to split on
def chooseBestFeatureToSplit(dataSet):
    '''
    Choose the "best" feature to split the data set on.
    dataSet : training set as a 2-D list; the last element of each row is the label
    Returns the index of the best feature.
    '''
    numFeatures = len(dataSet[0]) - 1      # number of features
    baseEntropy = calcShannonEnt(dataSet)  # entropy of the whole data set, H(D)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]  # values of feature i across all samples
        uniqueVals = set(featList)  # deduplicate via a set
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)     # split the samples
            prob = len(subDataSet) / float(len(dataSet))     # probability of the subset
            newEntropy += prob * calcShannonEnt(subDataSet)  # conditional entropy H(D|A)
        infoGain = baseEntropy - newEntropy
        # the feature with the largest information gain is the best
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
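On the same example data this returns the first feature, matching the hand computation of g(D, A) above:

chooseBestFeatureToSplit(dataSet)  # 0 (gain 0.420 for feature 0 vs. 0.171 for feature 1)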
Building the tree
import operator

def majorityCnt(classList):
    '''
    Return the most frequent label in the list.
    '''
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
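A quick check with a hypothetical label list:

majorityCnt(['yes', 'no', 'no'])  # 'no'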
def createTree(dataSet, labels):
    '''
    Recursively build the decision tree.
    dataSet : the data set
    labels  : list of feature names
    '''
    classList = [example[-1] for example in dataSet]
    # stop recursing when the split contains only one class; return that class
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # stop recursing when only the label column is left (no features remain);
    # return the majority class of that subset
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)  # choose the best feature
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}  # note the nested dict when creating the tree
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]  # values of the best feature
    uniqueVals = set(featValues)  # all possible values of that feature
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
# Usage
mydata = [[1, 1, 'yes'],
          [1, 1, 'yes'],
          [1, 0, 'no'],
          [0, 1, 'no'],
          [0, 1, 'no']]
labels = ['no surfacing', 'flippers']
myTree = createTree(mydata, labels)
myTree
Result:
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
The variable myTree holds nested dictionaries encoding the tree structure: the outermost key is the first feature used to split the data, and its value is another dictionary; when a value is a class label rather than a dictionary, a leaf node has been reached.
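The nested dictionaries can be indexed directly, walking feature values down to a leaf:

myTree['no surfacing'][0]                 # 'no'  -- a leaf one level down
myTree['no surfacing'][1]['flippers'][1]  # 'yes' -- a leaf two levels down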
Drawing the tree with Matplotlib
import matplotlib.pyplot as plt
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")  # arrow direction: "<-" points at the node, "->" points at the annotation text

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    '''
    Draw a node and the arrow from its parent.
    '''
    createPlot.axl.annotate(nodeTxt, xy=parentPt, xycoords="axes fraction",
                            xytext=centerPt, textcoords="axes fraction",
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
def getNumLeafs(myTree):
    '''
    Count the leaf nodes of the tree.
    '''
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == "dict":
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs
def getTreeDepth(myTree):
    '''
    Measure the depth of the tree.
    '''
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == "dict":
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
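Applied to the tree built earlier, these two helpers give the layout dimensions used below:

myTree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
getNumLeafs(myTree)   # 3
getTreeDepth(myTree)  # 2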
def plotMidText(cntrPt, parentPt, txtString):
    '''
    Fill in text between parent and child nodes.
    '''
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.axl.text(xMid, yMid, txtString)
def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xoff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yoff)
    plotMidText(cntrPt, parentPt, nodeTxt)  # text on the connecting line
    plotNode(firstStr, cntrPt, parentPt, decisionNode)  # node and arrow
    secondDict = myTree[firstStr]
    plotTree.yoff = plotTree.yoff - 1.0 / plotTree.totalD  # move the y coordinate down one level
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == "dict":
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xoff = plotTree.xoff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xoff, plotTree.yoff), cntrPt, leafNode)
            plotMidText((plotTree.xoff, plotTree.yoff), cntrPt, str(key))
    plotTree.yoff = plotTree.yoff + 1.0 / plotTree.totalD
def createPlot(inTree):
    '''
    Set up the figure and axes, then draw the tree.
    '''
    fig = plt.figure(1, facecolor="white")
    fig.clf()  # clear the figure: removes all axes but keeps the window open for reuse
    axprops = dict(xticks=[], yticks=[])
    createPlot.axl = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))  # global width: total number of leaf nodes
    plotTree.totalD = float(getTreeDepth(inTree))  # global depth
    plotTree.xoff = -0.5 / plotTree.totalW  # with n leaves, split the width into 2n slots and start half a slot to the left
    plotTree.yoff = 1.0
    plotTree(inTree, (0.5, 1.0), '')  # treating the plot as a 1x1 square, the root sits at (0.5, 1.0)
    plt.show()
# Usage
myTree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
createPlot(myTree)
Result:
[Figure: the rendered tree diagram; a second figure illustrated the initial plotTree.xoff position]
Classifying with the decision tree
def classify(inputTree, featLabels, testVec):
    '''
    Classify a sample with the decision tree.
    inputTree  : the tree as nested dictionaries
    featLabels : list of feature names
    testVec    : feature values of the sample to classify
    '''
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)  # map the feature name to its column index
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == "dict":
                classLabel = classify(secondDict[key], featLabels, testVec)  # recurse into the subtree
            else:
                classLabel = secondDict[key]  # leaf node: this is the class
    return classLabel
# Usage
myTree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
labels = ['no surfacing','flippers']
classify(myTree, labels, [1,0])
classify(myTree, labels, [1,1])
Result:
'no'
'yes'
Saving and loading the decision tree
def storeTree(inputTree, filename):
    '''
    Save the decision tree to disk.
    '''
    import pickle
    fw = open(filename, 'wb')  # pickle writes bytes, so open in binary mode
    pickle.dump(inputTree, fw)
    fw.close()

def grabTree(filename):
    '''
    Load a decision tree from disk.
    '''
    import pickle
    fr = open(filename, 'rb')  # binary mode to match storeTree
    return pickle.load(fr)
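A usage sketch (the filename here is just an example):

storeTree(myTree, 'classifierStorage.pkl')  # hypothetical filename
grabTree('classifierStorage.pkl')           # returns the same nested dict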