This post works through the programming exercise from the decision-tree chapter of the watermelon book (Zhou Zhihua's Machine Learning). It is a step up in difficulty from the corresponding chapter of Machine Learning in Action, mainly because the dataset now contains continuous-valued attributes. Continuous values cannot be handled the same way as discrete ones: if every distinct value became its own branch, information gain would trivially be maximal, yet such a split is useless in practice and can even be counterproductive. So the key new ingredient is a proper treatment of continuous values.
Handling continuous values
The simplest strategy is bi-partition: choose a threshold and split the samples on the continuous attribute into the two sides of it. This is exactly the mechanism adopted by the C4.5 decision-tree algorithm.
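In symbols, following equation (4.8) of the book: if a continuous attribute a takes n distinct values on sample set D, sort them as a^1 < a^2 < ... < a^n and take the midpoints of adjacent values as the candidate thresholds:

\[ T_a = \left\{ \frac{a^i + a^{i+1}}{2} \;\middle|\; 1 \le i \le n-1 \right\} \]

Each threshold t splits D into D_t^- (samples with value <= t) and D_t^+ (the rest), and the attribute's information gain is the best gain over all candidates:

\[ \mathrm{Gain}(D, a) = \max_{t \in T_a} \; \mathrm{Ent}(D) - \sum_{\lambda \in \{-, +\}} \frac{|D_t^{\lambda}|}{|D|} \, \mathrm{Ent}(D_t^{\lambda}) \]

This is exactly what getEntropyForFloat and chooseBestFeatToSplit implement in the code below.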
Exercise 4.3
Dataset (save it as ex4-3.txt, which the loader in the code below expects):
青绿,蜷缩,浊响,清晰,凹陷,硬滑,0.697,0.46,好瓜
乌黑,蜷缩,沉闷,清晰,凹陷,硬滑,0.744,0.376,好瓜
乌黑,蜷缩,浊响,清晰,凹陷,硬滑,0.634,0.264,好瓜
青绿,蜷缩,沉闷,清晰,凹陷,硬滑,0.608,0.318,好瓜
浅白,蜷缩,浊响,清晰,凹陷,硬滑,0.556,0.215,好瓜
青绿,稍蜷,浊响,清晰,稍凹,软粘,0.403,0.237,好瓜
乌黑,稍蜷,浊响,稍糊,稍凹,软粘,0.481,0.149,好瓜
乌黑,稍蜷,浊响,清晰,稍凹,硬滑,0.437,0.211,好瓜
乌黑,稍蜷,沉闷,稍糊,稍凹,硬滑,0.666,0.091,坏瓜
青绿,硬挺,清脆,清晰,平坦,软粘,0.243,0.267,坏瓜
浅白,硬挺,清脆,模糊,平坦,硬滑,0.245,0.057,坏瓜
浅白,蜷缩,浊响,模糊,平坦,软粘,0.343,0.099,坏瓜
青绿,稍蜷,浊响,稍糊,凹陷,硬滑,0.639,0.161,坏瓜
浅白,稍蜷,沉闷,稍糊,凹陷,硬滑,0.657,0.198,坏瓜
乌黑,稍蜷,浊响,清晰,稍凹,软粘,0.36,0.37,坏瓜
浅白,蜷缩,浊响,模糊,平坦,硬滑,0.593,0.042,坏瓜
青绿,蜷缩,沉闷,稍糊,稍凹,硬滑,0.719,0.103,坏瓜
The corresponding code:
import math
import operator
def getDataSet():
    # Each line of ex4-3.txt holds six discrete attributes, two continuous
    # ones (density, sugar content) and the class label, comma-separated.
    dataSet = []
    labelSet = []
    with open('ex4-3.txt', 'r') as f:
        for line in f:
            items = line.strip().split(',')
            sample = items[:-3]              # discrete attributes stay strings
            sample.append(float(items[-3]))  # 密度 (density)
            sample.append(float(items[-2]))  # 含糖率 (sugar content)
            sample.append(items[-1])         # class label
            dataSet.append(sample)
            labelSet.append(items[-1])
    return dataSet, labelSet
def getEntropy(dataSet):
    # Shannon entropy of the class labels (last column) in dataSet
    nums = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        curLabel = featVec[-1]
        labelCounts[curLabel] = labelCounts.get(curLabel, 0) + 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / nums
        shannonEnt -= prob * math.log(prob, 2)
    return shannonEnt
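# Quick sanity check (assuming ex4-3.txt holds the 17 samples above): the set
# contains 8 good melons and 9 bad ones, so
#   dataSet, _ = getDataSet()
#   getEntropy(dataSet)   # -> about 0.998, matching Ent(D) in the book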
# Split out the samples where a discrete attribute equals `value`,
# dropping that attribute's column from the result
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            retDataSet.append(featVec[:axis] + featVec[axis + 1:])
    return retDataSet
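# Example: splitDataSet(dataSet, 3, '清晰') should return the nine samples
# whose 纹理 value is 清晰, each with the 纹理 column removed.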
def getEntropyForFloat(dataSet, axis, Ta):
    # Scan every candidate threshold in Ta for the continuous attribute at
    # `axis`; return the lowest weighted entropy and the threshold achieving it
    newEntropy = float('inf')
    threshValue = 0.0
    for t in Ta:
        D0 = [sample for sample in dataSet if sample[axis] <= t]
        D1 = [sample for sample in dataSet if sample[axis] > t]
        tempEntropy = (float(len(D0)) / len(dataSet) * getEntropy(D0)
                       + float(len(D1)) / len(dataSet) * getEntropy(D1))
        if tempEntropy < newEntropy:
            newEntropy = tempEntropy
            threshValue = t
    return newEntropy, threshValue
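# Cross-check against the book's worked example: for 密度 (column 6), building
# Ta from the midpoints of the sorted densities and calling
# getEntropyForFloat(dataSet, 6, Ta) should select the threshold 0.3815
# (the book rounds it to 0.381), giving an information gain of about 0.262.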
def chooseBestFeatToSplit(dataSet):
    # Pick the attribute with the highest information gain.
    # Returns (feature index, is-continuous flag, threshold for continuous splits).
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = getEntropy(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    retBool = False
    retValue = 0.0
    for i in range(numFeatures):
        isFloat = False
        threshValue = 0.0
        newEntropy = 0.0
        # distinguish discrete values (strings) from continuous ones (floats)
        if isinstance(dataSet[0][i], str):
            featList = [example[i] for example in dataSet]
            for value in set(featList):
                subDataSet = splitDataSet(dataSet, i, value)
                prob = len(subDataSet) / float(len(dataSet))
                newEntropy += prob * getEntropy(subDataSet)
        else:
            isFloat = True
            sortedList = sorted(example[i] for example in dataSet)
            # candidate thresholds: midpoints of adjacent sorted values
            Ta = [(sortedList[k] + sortedList[k + 1]) / 2
                  for k in range(len(sortedList) - 1)]
            newEntropy, threshValue = getEntropyForFloat(dataSet, i, Ta)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
            retBool = isFloat
            retValue = threshValue
    return bestFeature, retBool, retValue
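# On the full dataset the book finds that 纹理 has the largest information
# gain (about 0.381), so chooseBestFeatToSplit(dataSet) should return
# (3, False, 0.0), i.e. the discrete attribute at index 3 as the root split.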
def majorityCnt(classList):
    # Majority vote over the remaining class labels
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    # all samples share one class: make a leaf
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # no attributes left to split on: fall back to a majority vote
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat, isFloat, threshValue = chooseBestFeatToSplit(dataSet)
    if not isFloat:
        bestFeatLabel = labels[bestFeat]
        myTree = {bestFeatLabel: {}}
        del labels[bestFeat]   # a discrete attribute is consumed by the split
        featValues = [example[bestFeat] for example in dataSet]
        for value in set(featValues):
            subLabels = labels[:]
            myTree[bestFeatLabel][value] = createTree(
                splitDataSet(dataSet, bestFeat, value), subLabels)
    else:
        # a continuous attribute is kept: it may be split on again deeper down
        bestFeatLabel = labels[bestFeat] + '<=' + str(threshValue)
        myTree = {bestFeatLabel: {}}
        subSet0 = [sample for sample in dataSet if sample[bestFeat] <= threshValue]
        subSet1 = [sample for sample in dataSet if sample[bestFeat] > threshValue]
        myTree[bestFeatLabel]['是'] = createTree(subSet0, labels[:])
        myTree[bestFeatLabel]['否'] = createTree(subSet1, labels[:])
    return myTree
if __name__ == '__main__':
    dataSet, labelSet = getDataSet()
    featureLabels = ['色泽', '根蒂', '敲声', '纹理', '脐部', '触感', '密度', '含糖率']
    print(createTree(dataSet, featureLabels))
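If everything is wired up correctly, the printed tree should match Figure 4.8 of the book (the tree induced on watermelon dataset 3.0 by information gain): the root splits on 纹理; the 清晰 branch splits on 密度<=0.3815 (是 -> 坏瓜, 否 -> 好瓜); the 稍糊 branch splits on 触感 (软粘 -> 好瓜, 硬滑 -> 坏瓜); and the 模糊 branch is a 坏瓜 leaf.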