# -*- coding: utf-8 -*-
"""
Created on Sat Aug 24 11:14:58 2019
@author:wangtao_zuel
E-mail:wangtao_zuel@126.com
决策树CART方法
"""
import numpy as np
import pandas as pd
def loadData(filepath,fileType):
    """
    Load a data file and return it as a numpy matrix.

    Parameters:
        filepath: path of the data file
        fileType: 'xlsx' or 'csv'; any other value is treated as
                  tab-separated text with no header row

    Note: when the class labels are strings, the reading step may need to
    be adapted accordingly.
    """
    if fileType == 'xlsx':
        data = pd.read_excel(filepath)
    elif fileType == 'csv':
        data = pd.read_csv(filepath)
    else:
        # plain text files are expected to be tab separated, without a header
        data = pd.read_csv(filepath,sep='\t',header=None)
    # COMPAT FIX: np.mat is a deprecated alias of np.asmatrix and was removed
    # in NumPy 2.0; np.asmatrix behaves identically on older versions too.
    data = np.asmatrix(data)
    return data
def binSplitDataSet(dataSet,featInd,featVal):
    """
    Binary-split the samples on a feature index and threshold value.
    Rows whose feature value is <= featVal always go to the left matrix.
    """
    leftRows = np.nonzero(dataSet[:, featInd] <= featVal)[0]
    rightRows = np.nonzero(dataSet[:, featInd] > featVal)[0]
    return dataSet[leftRows, :], dataSet[rightRows, :]
def regLeaf(dataSet):
    """
    Leaf builder for regression: the mean of the target (last) column
    of the samples that reached this branch.
    """
    targets = dataSet[:, -1]
    return np.mean(targets)
def maxLeaf(dataSet):
    """
    Leaf builder for classification: the most frequent class label
    among the samples that reached this branch.
    """
    counts = uniqueCount(dataSet)
    label, _ = max(counts.items(), key=lambda pair: pair[1])
    return label
def uniqueCount(dataMat):
    """
    Count how many samples belong to each class.

    Expects matrix-type data with the class label in the last column;
    with other container types the traversal "dataMat[:,-1].T.tolist()[0]"
    must be adapted.
    """
    counts = {}
    for label in dataMat[:, -1].T.tolist()[0]:
        counts[label] = counts.get(label, 0) + 1
    return counts
def regErr(dataSet):
    """
    Error measure for regression: total squared error of the target
    column, i.e. sample count times the (population) variance.
    """
    sampleCount = dataSet.shape[0]
    return sampleCount * np.var(dataSet[:, -1])
def entErr(dataSet):
    """
    Impurity measure: Shannon entropy of the class labels (last column).
    """
    counts = uniqueCount(dataSet)
    total = dataSet.shape[0]
    # H = -sum_k p_k * log2(p_k) over the observed classes
    entropy = 0.0
    for count in counts.values():
        p = count / total
        entropy -= p * np.log2(p)
    return entropy
def giniErr(dataSet):
    """
    Impurity measure: Gini impurity, the sum of p_i * p_j over every
    ordered pair of distinct observed classes.
    """
    total = dataSet.shape[0]
    counts = uniqueCount(dataSet)
    impurity = 0.0
    for labelA in counts:
        pA = counts[labelA] / total
        for labelB in counts:
            if labelA != labelB:
                impurity += pA * (counts[labelB] / total)
    return impurity
def chooseBestSplit(dataSet,leafType,errType,ops):
    """
    Select the best splitting feature (index) and threshold value.

    Parameters:
        dataSet: np.matrix, class/target in the last column
        leafType: callable that builds a leaf value from a data subset
        errType: callable that measures the error of a data subset
        ops: pre-pruning tuple (minErr, minNum)

    Returns (featureIndex, featureValue) for the best split, or
    (None, leafValue) when no acceptable split exists.
    """
    # Pre-pruning parameters: ignore splits whose error reduction is too
    # small or whose branches would contain too few samples.
    minErr = ops[0]
    minNum = ops[1]
    # All samples in this branch share one label -> build a leaf.
    if len(set(dataSet[:,-1].T.tolist()[0])) == 1:
        return None,leafType(dataSet)
    m,n = dataSet.shape
    # Error without splitting, used as the pre-pruning baseline.
    basicErr = errType(dataSet)
    bestErr = np.inf
    bestInd = 0
    bestVal = 0
    # Search every feature / observed value for the minimal split error.
    for featInd in range(n-1):
        for featVal in set(dataSet[:,featInd].T.tolist()[0]):
            matL,matR = binSplitDataSet(dataSet,featInd,featVal)
            # Branch too small: skip (part of pre-pruning).
            if (matL.shape[0] < minNum) or (matR.shape[0] < minNum):
                continue
            newErr = errType(matL) + errType(matR)
            # BUG FIX: compare against the best error found so far, not the
            # unsplit error.  The original `newErr < basicErr` kept the LAST
            # split that merely beat the baseline instead of the best split.
            if newErr < bestErr:
                bestInd = featInd
                bestVal = featVal
                bestErr = newErr
    # Improvement too small (splitting barely helps) -> leaf; this also
    # covers the case where no candidate split passed the size check,
    # because bestErr then stays at infinity.
    if (basicErr - bestErr) < minErr:
        return None,leafType(dataSet)
    # Second guard: re-check branch sizes for the chosen split.
    matL,matR = binSplitDataSet(dataSet,bestInd,bestVal)
    if (matL.shape[0] < minNum) or (matR.shape[0] < minNum):
        return None,leafType(dataSet)
    return bestInd,bestVal
def creatTree(dataSet,leafType,errType,ops):
    """
    Recursively build the CART tree.

    Parameters:
        dataSet: np.matrix, class/target in the last column
        leafType: leaf construction function (regLeaf / maxLeaf)
        errType: error/impurity function (regErr / entErr / giniErr)
        ops: pre-pruning parameters passed to chooseBestSplit

    Returns either a leaf value, or a dict with keys 'spInd' (split
    feature index), 'spVal' (threshold), 'left' and 'right'.
    """
    # Pick the best split; (None, leafValue) signals a leaf.
    spInd,spVal = chooseBestSplit(dataSet,leafType,errType,ops)
    # IDIOM FIX: identity check with `is None` (the original used `== None`);
    # note a plain falsiness test would be wrong, since index 0 is valid.
    if spInd is None:
        return spVal
    # Internal node: record the split, then recurse on both partitions.
    tree = {}
    tree['spInd'] = spInd
    tree['spVal'] = spVal
    matL,matR = binSplitDataSet(dataSet,spInd,spVal)
    tree['left'] = creatTree(matL,leafType,errType,ops)
    tree['right'] = creatTree(matR,leafType,errType,ops)
    return tree
"""
# 后剪枝操作
"""
def isTree(obj):
    """
    Return True when the given branch is an internal (sub-tree) node,
    i.e. a dict; False when it is a leaf value.
    """
    # IDIOM FIX: isinstance is the canonical type check; the original
    # compared the type name as a string (`type(obj).__name__ == 'dict'`).
    return isinstance(obj, dict)
def getMean(tree):
    """
    Collapse a (sub)tree into a single value: recursively average the
    left and right branches and return the mean as the node's value.
    """
    # BUG FIX: the original returned getMean(tree['left']) as soon as the
    # left child was a subtree, silently dropping the whole right branch
    # (and vice versa).  Both sides must be collapsed, then averaged.
    left = getMean(tree['left']) if isinstance(tree['left'], dict) else tree['left']
    right = getMean(tree['right']) if isinstance(tree['right'], dict) else tree['right']
    return (left + right) / 2
def regPrune(tree,testData):
    """
    Recursive post-pruning for a regression tree, driven by a held-out
    test set (ideally of a size comparable to the training set).

    This scheme suits continuous targets only: collapsing nodes by
    averaging leaf values is not meaningful for fixed class labels.

    Parameters:
        tree: dict tree produced by creatTree (or a leaf value)
        testData: np.matrix of test samples, target in the last column
    Returns the possibly-pruned tree (or a single collapsed leaf value).
    Note: sub-trees are also pruned in place via tree['left']/['right'].
    """
    # No test data reaches this node: collapse the subtree to one value.
    if testData.shape[0] == 0:
        return getMean(tree)
    # While either branch is still a subtree, split the test data along
    # this node's rule and recurse until both children are leaves.
    if (isTree(tree['left'])) or (isTree(tree['right'])):
        lSet,rSet = binSplitDataSet(testData,tree['spInd'],tree['spVal'])
        if isTree(tree['left']):
            tree['left'] = regPrune(tree['left'],lSet)
        if isTree(tree['right']):
            tree['right'] = regPrune(tree['right'],rSet)
    # Both children are leaves (possibly after the recursion above):
    # decide whether merging them lowers the test-set squared error.
    if (not isTree(tree['left'])) and (not isTree(tree['right'])):
        lSet,rSet = binSplitDataSet(testData,tree['spInd'],tree['spVal'])
        # Squared error on the test data if the split is kept.
        notMergeErr = sum(np.power(lSet[:,-1]-tree['left'],2)) + sum(np.power(rSet[:,-1]-tree['right'],2))
        # Candidate merged leaf: average of the two child leaves.
        treeMerge = (tree['left']+tree['right'])/2
        # Squared error on the test data if the node becomes that leaf.
        mergeErr = sum(np.power(testData[:,-1]-treeMerge,2))
        if mergeErr < notMergeErr:
            print("Merging!")
            return treeMerge
        else:
            return tree
    # Children are not both leaves: nothing to merge at this node.
    else:
        return tree
def outJudge(dataSet,tree):
    """
    Classify every out-of-sample row with the fitted tree and save the
    resulting table (original columns plus 'classResults') to Excel.
    """
    frame = pd.DataFrame(dataSet)
    # Each matrix row is flattened to a plain 1-D array before lookup.
    frame['classResults'] = [judgeType(row.A[0], tree) for row in dataSet]
    frame.to_excel('./data/machine_learning/mytree.xlsx',index=False,encoding='utf-8-sig')
    print("样本外数据分类(判断)完成!")
def judgeType(data,tree):
    """
    Classify one sample by walking the tree down to a leaf.

    data: indexable 1-D sequence of feature values
    tree: dict node with 'spInd'/'spVal'/'left'/'right' entries
    Returns the leaf value reached.
    """
    node = tree
    while True:
        # Split rule: values <= threshold descend left, others right.
        branch = 'left' if data[node['spInd']] <= node['spVal'] else 'right'
        child = node[branch]
        # A dict child is a subtree; anything else is the final leaf.
        if type(child).__name__ == 'dict':
            node = child
        else:
            return child
def treeCart(trainDataPath,outDataPath='',testDataPath='',leafType=regLeaf,errType=regErr,ops=(1,4),prune=False,fileType='txt'):
    """
    Main entry point: build a CART tree, optionally post-prune it, and
    optionally classify out-of-sample data.

    Parameters:
        trainDataPath: path of the training data
        outDataPath: path of the out-of-sample data (optional)
        testDataPath: path of the test data, required when prune=True
        leafType: leaf construction function (regLeaf / maxLeaf)
        errType: error/impurity function (regErr / entErr / giniErr)
        ops: pre-pruning tuple; first element is the smallest error
             reduction worth splitting for, second is the minimum number
             of samples a branch must contain
        prune: whether to run post-pruning with the test data
        fileType: data format of train/test files ('xlsx', 'csv' or
                  'txt'; txt files must be tab separated)
    """
    dataMat = loadData(trainDataPath,fileType)
    try:
        myTree = creatTree(dataMat,leafType,errType,ops)
        if prune:
            testData = loadData(testDataPath,fileType)
            myTree = regPrune(myTree,testData)
        print('决策树构建完成!')
        print(myTree)
        # Prediction / classification step for out-of-sample data.
        if outDataPath != '':
            outData = loadData(outDataPath,fileType)
            outJudge(outData,myTree)
    except Exception as err:
        # BUG FIX: the original bare `except:` swallowed every error
        # (even KeyboardInterrupt/SystemExit) and hid the real cause;
        # catch Exception only and report it before the usage hints.
        print(err)
        print("检查是否正确输入参数!")
        print('请在函数treeCart中输入叶节点创建方式参数:\n\t1、按平均值创建:leafType=regLeaf\n\t2、按最多样本创建:leafType=maxLeaf')
        print('请在treeCart中输入误差计算方式参数:\n\t1、香农熵:errType=entErr\n\t2、基尼不纯度:errType=giniErr\n\t3、平方误差:regErr')
        print('请在treeCart中输入预剪枝参数ops:\n\t其中第一个元素表示能忽略的最小误差,第二个元素表示当某分支下样本数小于该元素时,不考虑建立该分支')
        print('示例:treeCart(trainDataPath,leafType=regLeaf,errType=regErr,ops=(1,4))')
# 机器学习笔记——决策树(CART方法)
# (blog footer residue, kept as a comment: 最新推荐文章于 2025-02-12 10:36:14 发布)