使用ID3算法构建的决策树有如下问题:
- 每次选取当前最佳的特征来分割数据,并按照该特征所有可能的取值来切分。也就是说,一个特征有n个取值,那么数据就会被分割成n份。
- 使用某一特征来分割数据后,该特征在之后的算法执行过程中将不会再起作用,这种切分方式过于迅速。
- 不能直接处理连续型特征,只有事先将连续型特征转换成离散型,才能使用ID3算法。
CART算法是一种基于“基尼指数”的决策树构建算法,常用于回归树的构建中。回归树与分类树思想基本一致,但叶结点的数据类型不是离散型而是连续型。
树回归的一般流程:
- 收集数据;
- 准备数据:需要数值型的数据,标称型数据应该映射成二值型数据;
- 分析数据:绘出数据的二维可视化显示结果,以字典方式生成树;
- 训练算法;
- 测试算法;
- 使用算法:用训练出的树做预测;
连续和离散型特征的树的构建
使用字典来存储树的数据结构,该字典包含以下4个元素:
- 待切分的特征
- 待切分的特征值
- 右子树。当不需要切分时,也可以是单个值
- 左子树:与右子树类似
函数createTree()的伪代码如下:
def createTree():
找到最佳的待切分特征:
if 该节点不能再分:
将该节点村委叶节点
执行二元切分
在右子树调用createTree()
在左子树调用createTree()
构建树的代码如下:
def loadDataSet(fileName):
dataMat=[]
fr=open(fileName)
for line in fr.readlines():
curLine=line.strip().split('\t')
fltLine=map(float,curLine) #将每一行映射成浮点数
dataMat.append(fltLine)
return dataMat
"""切分数据集(注意这里书上错了)"""
def binSplitDataSet(dataSet,feature,value):#参数:数据集、待切分的特征、该特征的某个值
#通过数组过滤方式将数据集切分得到两个子集返回
mat0=dataSet[nonzero(dataSet[:,feature]>value)[0],:] #选出指定特征feature满足大于特征值value的样本数据
mat1=dataSet[nonzero(dataSet[:, feature]<=value)[0],:] #选出指定特征feature满足小于等于特征值value的样本数据
return mat0,mat1
"""树构建函数"""
def createTree(dataSet,leafType=regLeaf,errType=regErr,ops=(1,4)):
#leafType是对创建叶结点的引用;errType是对总误差方差计算函数的引用;ops是用户定义的参数构成的元组,用于树构建
feat,val=chooseBestSplit(dataSet,leafType,errType,ops) #根据基尼指数选择最好的特征用于划分
if feat==None:
return val
retTree={}
retTree['spInd']=feat
retTree['spVal']=val
lSet,rSet=binSplitDataSet(dataSet,feat,val)
retTree['left']=createTree(lSet,leafType,errType,ops) #递归构建左子树
retTree['right']=createTree(lSet,leafType,errType,ops) #递归构建右子树
return retTree
choosBestSplit函数(用于选择最好特征分割)的伪代码如下:
def chooseBestSplit():
对每个特征:
对每个特征值:
将数据集切分成两份
计算切分的误差
如果当前误差小于当前最小误差,那么将当前切分设定为最佳切分并更新最小误差
return 最佳切分的特征和阈值
代码实现如下:
"""负责计算目标变量的平方误差"""
def regErr(dataSet):
return var(dataSet[:,-1])*shape(dataSet)[0] #均方差函数var(),因为要返回总方差,故要乘以样本个数
"""负责生成叶结点"""
#当chooseBestSplit函数确定不再对数据进行切分时,将调用该函数来得到叶结点的模型。在回归树中,该模型其实就是目标变量的均值
def regLeaf(dataSet):
return mean(dataSet[:,-1])
def chooseBestSplit(dataSet,leafType=regLeaf,errType=regErr,ops=(1,4)):
#tolS和tolN用于控制函数的停止时机
tolS=ops[0] #容许的误差下降值
tolN=ops[1] #切分的最少样本数
#将预测值y(特征值)/分类类别转化成一个列表(dataSet[:,-1].T.tolist()[0])
#set函数将这个列表转化成集合,即特征值不同的才会被放入集合
#len计算集合长度,如果为1说明不同剩余特征值的数目为1,那么就不需要在切分,只要直接返回
if len(set(dataSet[:,-1].T.tolist()[0]))==1:
return None,leafType(dataSet) #用leafType对数据集生成叶结点
m,n=shape(dataSet) #n是特征数和y的和
S=errType(dataSet)
bestS=inf
bestIndex=0
bestValue=0
#在所有可能的特征及其可能的取值上遍历
for featIndex in range(n-1):
for splitVal in set(dataSet[:,featIndex].T.tolist()[0]):
mat0,mat1=binSplitDataSet(dataSet,featIndex,splitVal) #切分数据集
if(shape(mat0)[0]<tolN)or(shape(mat1)[0]<tolN): #判断是否还需继续切分
continue
newS=errType(mat0)+errType(mat1)
if newS<bestS:
bestIndex=featIndex
bestValue=splitVal
bestS=newS
if (S-bestS) < tolS:
return None,leafType(dataSet)
mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): #如果提前终止条件均不满足则返回切分特征和特征值
return None,leafType(dataSet)
return bestIndex,bestValue
if __name__=='__main__':
myDat=loadDataSet('ex00.txt')
myMat=mat(myDat)
regTree=createTree(myMat)
print(regTree)
这里要注意python3中map返回的的类型已经不是list而是可迭代对象,故要将map返回的对象做list处理才能使用。
