Decision trees. The later CART post will cover pruning and regression; that is the real focus. Here I'll just post the code.
from math import log
import operator

def createDataset():
    # Toy dataset: two binary features ('no surfacing', 'flippers')
    # plus a class label in the last column.
    dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels

def calEnt(dataset):
    # Shannon entropy over the class labels: H = -sum_i p_i * log2(p_i)
    numEnt = len(dataset)
    labelCounts = {}
    for featureVec in dataset:
        curLabel = featureVec[-1]
        labelCounts[curLabel] = labelCounts.get(curLabel, 0) + 1
    ent = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEnt
        ent -= prob * log(prob, 2)
    return ent

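# Sanity check (run in an interpreter): the toy dataset has 2 'yes'
# and 3 'no' labels, so H = -(2/5)*log2(2/5) - (3/5)*log2(3/5) ≈ 0.971,
# and calEnt(createDataset()[0]) should return about 0.9709505944546686.
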
def splitDataset(dataset, axis, value):
    # Split the dataset by one value of one feature: keep the rows whose
    # feature at `axis` equals `value`, with that feature column removed.
    retDataset = []
    for feavec in dataset:
        if feavec[axis] == value:
            reduceFeavec = feavec[:axis]
            reduceFeavec.extend(feavec[axis+1:])
            retDataset.append(reduceFeavec)
    return retDataset

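# Example: splitDataset(dataSet, 0, 1) keeps the three rows whose first
# feature is 1 and drops that column, giving [[1, 'yes'], [1, 'yes'], [0, 'no']].
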
def chooseBestFeatureToSplit(dataset):
    numFeatures = len(dataset[0]) - 1  # last column is the class label
    baseEnt = calEnt(dataset)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        fealist = [example[i] for example in dataset]
        uniqueVal = set(fealist)
        newEnt = 0.0
        for value in uniqueVal:
            subDataset = splitDataset(dataset, i, value)
            prob = len(subDataset) / float(len(dataset))
            newEnt += prob * calEnt(subDataset)
        infoGain = baseEnt - newEnt  # information gain of splitting on feature i
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i  # index of feature
    return bestFeature

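# On the toy data, splitting on feature 0 ('no surfacing') yields an
# information gain of about 0.420 versus about 0.171 for feature 1
# ('flippers'), so chooseBestFeatureToSplit should return index 0.
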
def majorityCnt(classlist):
    # Majority vote: return the most frequent class label.
    classCount = {}
    for vote in classlist:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

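# Example: majorityCnt(['yes', 'no', 'no']) returns 'no'.
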
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]  # all the data belong to one class
    if len(dataSet[0]) == 1:
        # no features left in dataSet: stop splitting and take a
        # majority vote (equivalently, one could check that labels is empty)
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del labels[bestFeat]  # remove the label of the chosen feature
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]  # copy: lists are passed by reference
        myTree[bestFeatLabel][value] = createTree(splitDataset(dataSet, bestFeat, value), subLabels)
    return myTree

if __name__ == '__main__':
    a, b = createDataset()
    mytree = createTree(a, b)
    print(mytree)
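Running the script recursively picks the highest-gain feature at each level; on the toy dataset it should print a nested dict along these lines:

{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}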
A few notes:
The most important parts of any recursion are the base cases and the recurrence.
The base cases here: stop splitting when every sample in a subset belongs to the same class, or when a subset has no features left to split on. In the latter case, note that the subset's class must be decided by majority vote.
Also, Python lists are passed by reference, so labels must be copied before each recursive call, as the short sketch below shows.
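A minimal standalone sketch of the aliasing pitfall (separate from the tree code): deleting through one reference to a list mutates it for every other reference, which is exactly why createTree copies labels before recursing.

labels = ['no surfacing', 'flippers']
alias = labels        # same list object, not a copy
snapshot = labels[:]  # independent shallow copy
del alias[0]          # mutates the one shared list
print(labels)         # ['flippers'] -- changed through the alias
print(snapshot)       # ['no surfacing', 'flippers'] -- unaffected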
To sum up, this post walked through a simple decision-tree implementation: creating the dataset, computing entropy, splitting the dataset, and choosing the best feature to split on, and finally built the tree itself.