Decision trees are a classic supervised classification algorithm, and many expert systems today are still built on them. The underlying principle is to split the data so that information gain is maximized, which is the same as making entropy as small as possible. Entropy is computed as:

H = -∑ p(x) · log₂(p(x))

where p(x) is the probability of class x appearing in the dataset. As an extreme example, if every sample in the dataset belongs to the same class A, then p(A) = 1 and the entropy is 0.
For a dataset like [A, A, A, B], the entropy is -(3/4·log₂(3/4) + 1/4·log₂(1/4)) ≈ 0.8113. Keep four samples but add one more class, say [A, A, B, C], and the entropy becomes -(1/2·log₂(1/2) + 2 · 1/4·log₂(1/4)) = 1.5. The entropy grew because there are more classes, so the data is more disordered.
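You can sanity-check these numbers in a few lines of Python (the entropy helper below is my own throwaway, not part of the book's listing):

from collections import Counter
from math import log

def entropy(samples):
    # Shannon entropy of a list of class labels: -sum(p * log2(p)).
    n = len(samples)
    return -sum((c / n) * log(c / n, 2) for c in Counter(samples).values())

print(entropy(["A", "A", "A", "A"]))  # 0.0 -- a single class is perfectly ordered
print(entropy(["A", "A", "A", "B"]))  # 0.8112781244591328
print(entropy(["A", "A", "B", "C"]))  # 1.5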
A decision tree is therefore built by repeatedly choosing the feature whose split yields the lowest entropy (the highest information gain), and recursing on each branch until a tree-shaped classifier emerges. Below is the code from 《机器学习实战》 (Machine Learning in Action) that decides whether something is a fish based on two features; just copy, paste, and run it to try it out.
'''
Decision tree (ID3) example: decide whether something is a fish
from two features.
'''
from math import log
import operator
class DecisionTree(object):

    def __init__(self):
        pass
    def createDataset(self):
        # Toy dataset from the book: the columns are "no surfacing",
        # "flippers", and the class label (is it a fish?).
        dataSet = [[1, 1, "yes"],
                   [1, 1, "yes"],
                   [1, 0, "no"],
                   [0, 1, "no"],
                   [0, 1, "no"]]
        labels = ["No surfacing", "flippers"]
        return dataSet, labels
    def calcShannonEnt(self, dataSet):
        # Shannon entropy of the class labels: -sum(p * log2(p)).
        numEntries = len(dataSet)
        labelCounts = {}
        for featVec in dataSet:
            currentLabel = featVec[-1]
            if currentLabel not in labelCounts:
                labelCounts[currentLabel] = 0
            labelCounts[currentLabel] += 1
        shannonEnt = 0.0
        for key in labelCounts:
            prob = float(labelCounts[key]) / numEntries
            shannonEnt -= prob * log(prob, 2)
        return shannonEnt
    def splitDataSet(self, dataSet, axis, value):
        # Keep the rows where feature `axis` equals `value`,
        # with that feature column removed.
        retDataSet = []
        for featVec in dataSet:
            if featVec[axis] == value:
                reducedFeatVec = featVec[:axis]
                reducedFeatVec.extend(featVec[axis+1:])
                retDataSet.append(reducedFeatVec)
        return retDataSet
    def chooseBestFeatureToSplit(self, dataSet):
        # Try every feature and return the one whose split gives
        # the largest information gain (biggest drop in entropy).
        numFeatures = len(dataSet[0]) - 1  # last column is the class label
        baseEntropy = self.calcShannonEnt(dataSet)
        bestInfoGain = 0.0
        bestFeature = -1
        for i in range(numFeatures):
            featList = [example[i] for example in dataSet]
            uniqueVals = set(featList)
            newEntropy = 0.0
            for value in uniqueVals:
                subDataSet = self.splitDataSet(dataSet, i, value)
                prob = len(subDataSet) / float(len(dataSet))
                newEntropy += prob * self.calcShannonEnt(subDataSet)
            infoGain = baseEntropy - newEntropy
            if infoGain > bestInfoGain:
                bestInfoGain = infoGain
                bestFeature = i
        return bestFeature
    def majorityCnt(self, classList):
        # Majority vote: the most common class label wins.
        classCount = {}
        for vote in classList:
            if vote not in classCount:
                classCount[vote] = 0
            classCount[vote] += 1
        sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
        return sortedClassCount[0][0]
    def createTree(self, dataSet, labels):
        # Recursively build the tree as nested dicts:
        # {feature label: {feature value: subtree or class label}}.
        # Note: `labels` is consumed in place.
        classList = [example[-1] for example in dataSet]
        ### base case 1: only one class left in the dataset ###
        if classList.count(classList[0]) == len(classList):
            return classList[0]
        ### base case 2: all features used, labels still mixed ###
        if len(dataSet[0]) == 1:
            return self.majorityCnt(classList)
        bestFeat = self.chooseBestFeatureToSplit(dataSet)
        bestFeatLabel = labels[bestFeat]
        myTree = {bestFeatLabel: {}}
        del labels[bestFeat]
        featValues = [example[bestFeat] for example in dataSet]
        uniqueVals = set(featValues)
        for value in uniqueVals:
            subLabels = labels[:]  # copy so recursion doesn't clobber sibling branches
            subDataSet = self.splitDataSet(dataSet, bestFeat, value)
            myTree[bestFeatLabel][value] = self.createTree(subDataSet, subLabels)
        return myTree
if __name__ == "__main__":
    tree = DecisionTree()
    dataSet, labels = tree.createDataset()
    print(tree.createTree(dataSet, labels))
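Running this prints the learned tree as nested dicts:

{'No surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

To classify a new sample you just walk those dicts; here is a minimal sketch (this classify helper is my own addition, not part of the book's listing above):

def classify(tree, featLabels, testVec):
    # Descend the nested dicts until we hit a leaf (a class label).
    featLabel = next(iter(tree))  # feature tested at this node
    subtree = tree[featLabel][testVec[featLabels.index(featLabel)]]
    return classify(subtree, featLabels, testVec) if isinstance(subtree, dict) else subtree

print(classify({'No surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
               ["No surfacing", "flippers"], [1, 0]))  # -> 'no'

One caveat: createTree deletes entries from the labels list as it recurses, so pass classify a fresh copy of the label names rather than the list you already fed into createTree.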