统计学习方法笔记 第五章 决策树,CART算法(包含Python代码)

本文深入探讨了决策树模型,包括ID3、C4.5和CART算法。介绍了特征选择、信息增益与信息增益比,以及决策树生成和剪枝过程。通过Python展示了决策树的实现和可视化,适合理解决策树的工作原理和应用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >


决策树既可以用来做回归也可以用来做分类,主要包括特征选择,决策树生成和决策树的修剪。

1. 决策树模型与学习

1.1 决策树模型与学习

用决策树进行分类时,从根节点开始,对实例的某一个特征进行测试,根据测试结果,将实例分配到其子节点,递归的对于子节点进行测试并进行分配,直至到达叶节点,最后将实例分配到叶节点的类中。

1.2 决策树和if-then规则

决策树可以看成if-then规则的集合。

1.3 决策树与条件概率分布

条件概率分布定义在特征空间的一个划分,将特征空间划分成互不相交的单元区域,并在每一个单元定义一个类的概率分布就构成了一个条件概率分布。

1.4 决策树学习

决策树学习的本质是从训练数据集中归纳出一组分类规则,学习的损失函数通常用正则化的极大似然函数,决策树学习的算法通常是一个递归的选择最优特征,并根据该特征对训练数据进行分割,使得各个数据集有一个最好的分类的过程。


2. 特征选择

2.1 特征选择问题

特征选择在于选择对训练数据具有最好分类能力的特征。通常根据信息增益或者信息增益比来选择。

2.2 信息增益

熵: H ( p ) = − ∑ i = 1 n p i log ⁡ p i H(p)=-\sum_{i=1}^{n}p_i\log{p_i} H(p)=i=1npilogpi
条件熵: H ( Y ∣ X ) = ∑ i = 1 n p i H ( Y ∣ X = x i ) H(Y|X)=\sum_{i=1}^{n}{p_iH(Y|X=x_i)} H(YX)=i=1npiH(YX=xi)
信息增益(互信息): g ( D , A ) = H ( D ) − H ( D ∣ A ) g(D,A)=H(D)-H(D|A) g(D,A)=H(D)H(DA)

算法1: 信息增益计算

  1. 计算数据集 D D D的经验熵 H ( D ) H(D) H(D):
    H ( D ) = − ∑ k = 1 K ∣ C k ∣ ∣ D ∣ log ⁡ 2 ∣ C k ∣ ∣ D ∣ H(D)=-\sum_{k=1}^{K}\frac{|C_k|}{|D|}\log_2{\frac{|C_k|}{|D|}} H(D)=k=1KDCklog2DCk
  2. 计算特征 A A A对数据集 D D D的经验条件熵 H ( D ∣ A ) H(D|A) H(DA):
    H ( D ∣ A ) = − ∑ i = 1 n ∣ D i ∣ ∣ D ∣ ∑ k = 1 K ∣ D i k ∣ ∣ D i ∣ log ⁡ 2 ∣ D i k ∣ ∣ D i ∣ H(D|A)=-\sum_{i=1}^{n}\frac{|D_i|}{|D|}\sum_{k=1}^{K}\frac{|D_{ik}|}{|D_i|}\log_2{\frac{|D_{ik}|}{|D_i|}} H(DA)=i=1nDDik=1KDiDiklog2DiDik
  3. 计算信息增益:
    g ( D , A ) = H ( D ) − H ( D ∣ A ) g(D,A)=H(D)-H(D|A) g(D,A)=H(D)H(DA)

2.3 信息增益比

使用信息熵时,会偏向于选择取值较多的特征,所以选择使用信息熵增益比作为特征选择的另一标准。

信息熵增益比:

g R ( D , A ) = g ( D , A ) − ∑ i = 1 n ∣ D i ∣ ∣ D ∣ log ⁡ 2 ∣ D i ∣ ∣ D ∣ g_R(D,A)=\frac{g(D,A)}{-\sum_{i=1}^{n}\frac{|D_i|}{|D|}\log_2{\frac{|D_i|}{|D|}}} gR(D,A)=i=1nDDilog2DDig(D,A)


3. 决策树生成

3.1 ID3 算法

ID3算法的核心算法是在每一个节点上应用信息增益作为选择特征,递归的构建决策树,由于缺少剪枝步骤,ID3算法产生的树容易产生过拟合。

算法2:ID3算法

  1. 若训练集中的所有实例属于同一类 C k C_k Ck,则 T T T为单节点树,并将类 C k C_k Ck作为该结点的类标记,,返回 T T T
  2. A = ∅ A=\varnothing A=,则 D D D为单节点树,并将 C k C_k Ck中实例树最大的类$ 作 为 该 节 点 的 类 标 记 , 返 回 作为该节点的类标记,返回 T$。
  3. 否则,按照算法1计算 A A A中各特征对 D D D的信息增益,选择信息增益最大的特征 A g A_g Ag
  4. 如果的信息增益小于阈值 ϵ \epsilon ϵ,则置 T T T为单节点树,并将 D D D中实例数最大的类 C k C_k Ck作为该节点的类标记,返回 T T T
  5. 否则对 A g A_g Ag的每一个可能取值 a i a_i ai,依 A g = a i A_g=a_i Ag=ai D D D分割为若干非空子集 D i D_i Di,将实例数最大的类作为标记,构建子结点,由结点和子结点构建树,返回。
  6. 对第 i i i个子结点,以 D i D_i Di为训练集,以 A − A g A-{A_g} AAg为特征集,递归地调用1~5,得到子树,返回 T i T_i Ti

3.2 C4.5 的生成算法

C4.5采用信息增益比来选择特征。

算法2:C4.5算法

  1. 若训练集中的所有实例属于同一类 C k C_k Ck,则 T T T为单节点树,并将类 C k C_k Ck作为该结点的类标记,,返回 T T T
  2. A = ∅ A=\varnothing A=,则 D D D为单节点树,并将 C k C_k Ck中实例树最大的类$ 作 为 该 节 点 的 类 标 记 , 返 回 作为该节点的类标记,返回 T$。
  3. 否则,计算 A A A中各特征对 D D D信息增益比,选择信息增益比最大的特征 A g A_g Ag
  4. 如果的信息增益比小于阈值 ϵ \epsilon ϵ,则置 T T T为单节点树,并将 D D D中实例数最大的类 C k C_k Ck作为该节点的类标记,返回 T T T
  5. 否则对 A g A_g Ag的每一个可能取值 a i a_i ai,依 A g = a i A_g=a_i Ag=ai D D D分割为若干非空子集 D i D_i Di,将实例数最大的类作为标记,构建子结点,由结点和子结点构建树,返回。
  6. 对第 i i i个子结点,以 D i D_i Di为训练集,以 A − A g A-{A_g} AAg为特征集,递归地调用1~5,得到子树,返回 T i T_i Ti

4. 决策树剪枝

决策树的剪枝通过极小化决策树的整体损失函数或代价函数来实现。

损失函数:

C a ( T ) = C ( T ) + α ∣ T ∣ C_a(T)=C(T)+\alpha|T| Ca(T)=C(T)+αT

其中 C ( T ) C(T) C(T)表示训练误差, α \alpha α控制模型的复杂程度。

决策树的生成算法只考虑对训练数据进行拟合,而剪枝通过优化损失函数还考虑了减少模型的复杂程度。

算法3:剪枝算法

  1. 计算每个结点的经验熵。
  2. 递归的从叶结点向上回缩,如果回缩后的损失函数比之前的损失函数要小,则将原来的父节点变成叶结点。
  3. 返回执行2,直到不能继续执行,返回得到的新的子树。

5. CART算法

CART算法由以下两步组成:

  1. 决策树生成:基于训练集生成决策树,生成的决策树要尽量大。
  2. 决策树剪枝:用验证集对已生成的树进行剪枝并选择最优子树,选用损失函数最小作为剪枝的标准。

5.1 CART 生成

对回归树用平方误差最小准则,对分类树用基尼指数最小准则。

5.1.1 回归树生成

算法4:最小二乘回归树生成

  1. 选择最优切分变量 j j j和切分点 s s s,求解: min ⁡ j , s [ min ⁡ c 1 ∑ x i ∈ R 1 ( j , s ) ( y i − c 1 ) 2 + min ⁡ c 2 ∑ x i ∈ R 2 ( j , s ) ( y i − c 2 ) 2 ] \min_{j,s}[\min_{c_1}\sum_{x_i\in{R_1(j,s)}}(y_i-c_1)^2+\min_{c_2}\sum_{x_i\in{R_2(j,s)}}(y_i-c_2)^2] j,smin[c1minxiR1(j,s)(yic1)2+c2minxiR2(j,s)(yic2)2]遍历变量 j j j,对固定的切分变量 j j j扫描切分点 s s s,选择使上式达到最小值的对 ( j , s ) (j,s) (j,s)
  2. 用选定的对 ( j , s ) (j,s) (j,s)划分区域并决定相应的输出值: R 1 ( j , s ) = x ∣ x ( j ) ≤ s , R 2 ( j , s ) = x ∣ x ( j ) > s R_1(j,s)={x|x^{(j)}\leq{s}},R_2(j,s)={x|x^{(j)}}>s R1(j,s)=xx(j)s,R2(j,s)=xx(j)>s c ^ m = 1 N m ∑ x i ∈ R m ( j , s ) y i , x ∈ R m , m = 1 , 2 \hat{c}_m=\frac{1}{N_m}\sum_{x_i\in{R_m(j,s)}}y_i,x\in{R_m},m=1,2 c^m=Nm1xiRm(j,s)yi,xRm,m=1,2
  3. 继续对两个子区域调用1,2,直到满足停止条件。
  4. 将输入空间划分成M个区域 R 1 , R 2 , … , R M R_1,R_2,\dots,R_M R1,R2,,RM生成决策树: f ( x ) = ∑ m = 1 M c ^ m I ( x ∈ R m ) f(x)=\sum_{m=1}^{M}\hat{c}_mI(x\in{R_m}) f(x)=m=1Mc^mI(xRm)
5.1.2 分类树生成

基尼指数: G i n i ( p ) = 1 − ∑ k = 1 K p k 2 = 1 − ∑ k = 1 K ( ∣ C k ∣ ∣ D ∣ ) 2 Gini(p)=1-\sum_{k=1}^{K}p_k^2=1-\sum_{k=1}^{K}(\frac{|C_k|}{|D|})^2 Gini(p)=1k=1Kpk2=1k=1K(DCk)2

基尼指数表示集合的不确定性,基尼指数越大,不确定性越大。

算法5:基尼系数分类树生成

  1. 计算训练集的基尼指数,对每一个特征的每一个取值,根据是否将训练集分割成两部分,计算基尼指数。
  2. 对所有的特征的各个取值,选择基尼指数最小的特征和切分点作为最优特征和最优切分点,用最优切分点将训练集切分,将训练集分配到两个子结点中。
  3. 对两个子结点递归地调用1,2,直到满足停止条件。
  4. 生成CART决策树。

5.2 CART 剪枝

算法6:CART树剪枝算法

  1. k = 0 , T = T 0 k=0,T=T_0 k=0,T=T0
  2. α = + ∞ \alpha=+\infty α=+
  3. 自上而下地对各内部结点 t t t计算 C ( T t ) , ∣ T t ∣ C(T_t),|T_t| C(Tt),Tt以及 g ( t ) = C ( t ) − C ( T t ) ∣ T t ∣ − 1 g(t)=\frac{C(t)-C(T_t)}{|T_t|-1} g(t)=Tt1C(t)C(Tt) α = min ⁡ ( α , g ( t ) ) \alpha=\min{(\alpha,g(t))} α=min(α,g(t))
  4. g ( t ) = α g(t)=\alpha g(t)=α地内部结点进行剪枝,并对叶结点以多数表决法决定其类,得到树 T T T
  5. k = k + 1 , α k = α , T k = T k=k+1,\alpha_k=\alpha,T_k=T k=k+1,αk=α,Tk=T
  6. 如果 T k T_k Tk不是由根节点及两个叶结点构成地树,则回到3,否则令 T k = T T_k=T Tk=T
  7. 采用交叉验证法在子树序列 T 0 , T 1 , … , T n T_0,T_1,\dots,T_n T0,T1,,Tn中选取最优子树 T α T_{\alpha} Tα

6. python实现

6.1 决策树实现和可视化

# -*- encoding: utf-8 -*-
'''
@File    :   Chapter3_tree.py
@Contact :   zzldouglas97@gmail.com

@Modify Time      @Author    @Version    @Desciption
------------      -------    --------    -----------
2019/5/4 12:31   Douglas.Z      1.0         None
'''

# import lib
from math import log
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
import operator
import pickle

def calShannonEnt(dataSet):
    """
    计算香农熵
    :param dataSet:
    :return:
    """
    # 计算样本总数
    numEntries = len(dataSet)
    # 生成每个标签的计数
    labelCount = {}
    # 为所有分类创建字典
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCount.keys():
            labelCount[currentLabel] = 0
        labelCount[currentLabel] += 1
    #  计算熵
    shannonEnt = 0.0
    for key in labelCount:
        # p(x)*log(P(x))
        prob = float(labelCount[key])/numEntries
        shannonEnt -= prob*log(prob, 2)
    return shannonEnt

def createDataSet():
    """
    创建测试数据集
    :return:
    """
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels

def splitDataSet(dataSet, axis, value):
    """
    划分数据集
    :param dataSet:
    :param axis:
    :param value:
    :return:
    """
    returnDat = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            returnDat.append(reducedFeatVec)
    return returnDat

def chooseBestFeatureToSplit(dataSet):
    """
    选择最佳特征
    :param dataSet:
    :return:
    """
    # 计算包含多少特征属性
    numFeatures = len(dataSet[0]) - 1
    # 计算整个数据集的香农熵
    baseEntropy = calShannonEnt(dataSet)
    # 初始化信息增益和最佳特征
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        # 创建唯一的分类标签列表
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        # 计算每一种划分的信息熵
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy += prob*calShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        # 计算最好的信息增益
        if(infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

def majorityCnt(classList):
    """
    决定叶子节点分类
    :param classList:
    :return:
    """
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

def createTree(dataSet, labels, featLabels):
    """
    创建递归树
    :param dataSet:
    :param labels:
    :param featLabels:
    :return:
    """
    classList = [example[-1] for example in dataSet]
    #取分类标签
    if classList.count(classList[0]) == len(classList):
        #如果类别完全相同则停止继续划分
        return classList[0]
    if len(dataSet[0]) == 1 or len(labels) == 0:
        #遍历完所有特征时返回出现次数最多的类标签
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    #选择最优特征
    bestFeatLabel = labels[bestFeat]
    #最优特征的标签
    featLabels.append(bestFeatLabel)
    myTree = {bestFeatLabel:{}}
    #根据最优特征的标签生成树
    del(labels[bestFeat])
    #删除已经使用特征标签
    featValues = [example[bestFeat] for example in dataSet]
    #得到训练集中所有最优特征的属性值
    uniqueVals = set(featValues)
    #去掉重复的属性值
    for value in uniqueVals:
        #遍历特征,创建决策树。
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), labels, featLabels)
    return myTree

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """
    绘制结点
    :param nodeTxt:
    :param centerPt:
    :param parentPt:
    :param nodeType:
    :return:
    """
    arrow_args = dict(arrowstyle="<-")
    font = FontProperties(fname=r"C:\Windows\Fonts\simsun.ttc", size=14)
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction', va="center", ha="center", bbox=nodeType, arrowprops=arrow_args, FontProperties=font)

def getNumLeafs(myTree):
    """
    获得叶子个数
    :param myTree:
    :return:
    """
    numLeafs = 0
    # firstStr = myTree.keys()[0]
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs+=1
    return numLeafs

def getTreeDepth(myTree):
    """
    获得树的深度
    :param myTree:
    :return:
    """
    maxDepth = 0
    # firstStr = myTree.keys()[0]
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            thisDepth = 1+ getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

def retrieveTree(i):
    listOfTree = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}},1: 'no'}}}}]
    return listOfTree[i]

def plotMidText(cntrPt, parentPt, txtString):
    """
    绘制文本信息
    :param cntrPt:
    :param parentPt:
    :param txtString:
    :return:
    """
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)

def plotTree(myTree, parentPt, nodeTxt):
    """
    绘制树
    :param myTree:
    :param parentPt:
    :param nodeTxt:
    :return:
    """
    decisionNode = dict(boxstyle='square', fc="0.8")
    leafNode = dict(boxstyle='round4', fc="0.8")
    arrow_args = dict(arrowstyle="<-")
    numLeafs = getNumLeafs(myTree)
    #获取决策树叶结点数目,决定了树的宽度
    depth = getTreeDepth(myTree)
    #获取决策树层数
    firstStr = next(iter(myTree))
    #下个字典
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    #中心位置
    plotMidText(cntrPt, parentPt, nodeTxt)
    #标注有向边属性值
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    #绘制结点
    secondDict = myTree[firstStr]
    #下一个字典,也就是继续绘制子结点
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    #y偏移
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            #测试该结点是否为字典,如果不是字典,代表此结点为叶子结点
            plotTree(secondDict[key],cntrPt,str(key))
            #不是叶结点,递归调用继续绘制
        else:
            #如果是叶结点,绘制叶结点,并标注有向边属性值
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

def createPlot(inTree):
    """
    可视化决策树
    :param inTree:
    :return:
    """
    fig = plt.figure(1, facecolor='white')
    #创建fig
    fig.clf()
    #清空fig
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    #去掉x、y轴
    plotTree.totalW = float(getNumLeafs(inTree))
    #获取决策树叶结点数目
    plotTree.totalD = float(getTreeDepth(inTree))
    #获取决策树层数
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
    #x偏移
    plotTree(inTree, (0.5,1.0), '')
    #绘制决策树
    plt.show()

def classify(inputTree, featLabels, testVec):
    """
    决策树分类
    :param inputTree:
    :param featLabels:
    :param testVec:
    :return:
    """
    firstStr = next(iter(inputTree))
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)
            else: classLabel = secondDict[key]
    return classLabel

def storeTree(inputTree, filename):
    """
    存储决策树
    :param inputTree:
    :param filename:
    :return:
    """
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)

def grabTree(filename):
    """
    提取决策树
    :param filename:
    :return:
    """
    fr = open(filename)
    return pickle.load(fr)

if __name__ == '__main__':
    myDat, labels = createDataSet()
    featLabels = []
    myTree = createTree(myDat, labels, featLabels)
    storeTree(myTree,'tree.txt')
    createPlot(myTree)
    result = classify(myTree, featLabels, [1,0])
    print(result)

6.2 预测隐形眼镜

# -*- encoding: utf-8 -*-
'''
@File    :   Chapter3_lens.py    
@Contact :   zzldouglas97@gmail.com

@Modify Time      @Author    @Version    @Desciption
------------      -------    --------    -----------
2019/5/4 20:25   Douglas.Z      1.0         None
'''

# import lib
from math import log
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
import operator
import pickle
import Chapter3.Chapter3_tree as tree

if __name__ == '__main__':
    fr = open('lenses.txt')
    lenses = [inst.strip().split('\t') for inst in fr.readlines()]
    lensesLabels = ['age','prescript','astigmatic','tearRate']
    featLabels = []
    lensesTree = tree.createTree(lenses, lensesLabels, featLabels)
    tree.createPlot(lensesTree)

6.3 sklearn决策树

# -*- encoding: utf-8 -*-
'''
@File    :   Chapter3_sklearntree.py    
@Contact :   zzldouglas97@gmail.com

@Modify Time      @Author    @Version    @Desciption
------------      -------    --------    -----------
2019/5/4 20:41   Douglas.Z      1.0         None
'''

# import lib
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from sklearn.externals.six import StringIO
from sklearn import tree
import pandas as pd
import numpy as np
import pydotplus

if __name__ == '__main__':
    with open('lenses.txt', 'r') as fr:
        #加载文件
        lenses = [inst.strip().split('\t') for inst in fr.readlines()]
        #处理文件
    lenses_target = []
    # 提取每组数据的类别,保存在列表里
    for each in lenses:
        lenses_target.append(each[-1])

    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
    # 特征标签
    lenses_list = []
    # 保存lenses数据的临时列表
    lenses_dict = {}
    # 保存lenses数据的字典,用于生成pandas
    for each_label in lensesLabels:
        # 提取信息,生成字典
        for each in lenses:
            lenses_list.append(each[lensesLabels.index(each_label)])
        lenses_dict[each_label] = lenses_list
        lenses_list = []
    # print(lenses_dict)
    # #打印字典信息
    lenses_pd = pd.DataFrame(lenses_dict)
    # 生成pandas.DataFrame
    print(lenses_pd)
    # 打印pandas.DataFrame
    le = LabelEncoder()
    # 创建LabelEncoder()对象,用于序列化
    for col in lenses_pd.columns:  # 为每一列序列化
        lenses_pd[col] = le.fit_transform(lenses_pd[col])
    print(lenses_pd)

    clf = tree.DecisionTreeClassifier(max_depth = 4)
    #创建DecisionTreeClassifier()类
    clf = clf.fit(lenses_pd.values.tolist(), lenses_target)
    #使用数据,构建决策树
    dot_data = StringIO()
    tree.export_graphviz(clf, out_file = dot_data,
                         #绘制决策树
                        feature_names = lenses_pd.keys(),
                        class_names = clf.classes_,
                        filled=True, rounded=True,
                        special_characters=True)
    graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
    graph.write_pdf("tree.pdf")

7. 参考资料

Classification and Regression Trees


8. 总结

优点:计算复杂度不高,输出结果易于理解,对中间值缺失不敏感,可以处理不相关特征数据。
缺点:可能产生过拟合问题。
适用数据范围:数值型和标称型。


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值