Picking Good Watermelons with Decision Trees

This article takes a close look at decision-tree algorithms, focusing on implementations of ID3 and CART. ID3 selects the best splitting attribute by computing information gain, while CART partitions the feature space with binary splits on continuous or categorical variables. The article walks through how both algorithms work on a concrete example, provides code implementations, and closes with a summary.


I. The ID3 Algorithm

1. Imports

import math                      # log2 for the entropy computation
import matplotlib
import matplotlib.pyplot as plt  # for plotting the finished tree
import pandas as pd              # for loading the CSV

2. Loading the data

# Load the watermelon dataset (西瓜数据集); the CSV has no header row
data = pd.read_csv('C:/西瓜数据集.csv', header=None)
data

(Figure: the loaded data frame.)

3. Writing the code

def calcEntropy(dataSet):
    # Information entropy Ent(D) = -sum_k p_k * log2(p_k)
    mD = len(dataSet)                         # number of samples |D|
    dataLabelList = [x[-1] for x in dataSet]  # class labels (last column)
    dataLabelSet = set(dataLabelList)
    ent = 0
    for label in dataLabelSet:
        mDv = dataLabelList.count(label)      # samples in class k
        prop = float(mDv) / mD                # class proportion p_k
        ent = ent - prop * math.log(prop, 2)  # np.math is deprecated; use math.log

    return ent
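
calcEntropy implements the information entropy Ent(D) = -Σ_k p_k · log2(p_k) over the class labels. As a quick sanity check, a perfectly balanced two-class set should come out at exactly 1 bit (the toy rows below are hypothetical, not from the watermelon dataset):

# Hypothetical toy data: two classes, perfectly mixed
toy = [['green', 'yes'], ['black', 'yes'], ['pale', 'no'], ['green', 'no']]
print(calcEntropy(toy))  # 1.0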

Splitting the dataset

def splitDataSet(dataSet, index, feature):
    # Return the subset whose value at column `index` equals `feature`,
    # with that column removed.
    splitedDataSet = []
    for data in dataSet:
        if data[index] == feature:
            sliceTmp = data[:index]
            sliceTmp.extend(data[index + 1:])
            splitedDataSet.append(sliceTmp)
    return splitedDataSet

Choosing the best feature

def chooseBestFeature(dataSet):
    # Pick the feature with the largest information gain:
    # Gain(D, a) = Ent(D) - sum_v |D_v|/|D| * Ent(D_v)
    entD = calcEntropy(dataSet)          # Ent(D)
    mD = len(dataSet)                    # |D|
    featureNumber = len(dataSet[0]) - 1  # last column is the label
    maxGain = -100
    maxIndex = -1
    for i in range(featureNumber):
        entDCopy = entD
        featureI = [x[i] for x in dataSet]
        featureSet = set(featureI)
        for feature in featureSet:
            splitedDataSet = splitDataSet(dataSet, i, feature)  # subset D_v
            mDv = len(splitedDataSet)                           # |D_v|
            entDCopy = entDCopy - float(mDv) / mD * calcEntropy(splitedDataSet)
        # entDCopy now holds Gain(D, i); keep the best feature seen so far
        if maxIndex == -1 or maxGain < entDCopy:
            maxGain = entDCopy
            maxIndex = i

    return maxIndex
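
chooseBestFeature scores every feature with the ID3 information gain and returns the index of the highest-scoring one. A minimal hypothetical check: in the toy set below, the second column separates the labels perfectly, so index 1 should win:

# Hypothetical toy rows (feature0, feature1, label)
toy = [['hard', 'clear', 'good'],
       ['hard', 'blurry', 'bad'],
       ['soft', 'clear', 'good'],
       ['soft', 'blurry', 'bad']]
print(chooseBestFeature(toy))  # 1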

Finding the majority label

def mainLabel(labelList):
    # Return the most frequent label in labelList.
    labelRec = labelList[0]
    maxLabelCount = -1
    for label in set(labelList):
        if labelList.count(label) > maxLabelCount:
            maxLabelCount = labelList.count(label)
            labelRec = label
    return labelRec
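
mainLabel implements majority voting and serves as the fallback when a branch runs out of samples or attributes. A one-line hypothetical check:

print(mainLabel(['好瓜', '好瓜', '坏瓜']))  # 好瓜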

Recursively building the full tree

def createFullDecisionTree(dataSet, featureNames, featureNamesSet, labelListParent):
    labelList = [x[-1] for x in dataSet]
    if len(dataSet) == 0:                                   # no samples left
        return mainLabel(labelListParent)
    elif len(dataSet[0]) == 1:                              # no attributes left to split on
        return mainLabel(labelList)                         # majority label wins
    elif labelList.count(labelList[0]) == len(labelList):   # all samples share one label
        return labelList[0]

    bestFeatureIndex = chooseBestFeature(dataSet)
    bestFeatureName = featureNames.pop(bestFeatureIndex)
    myTree = {bestFeatureName: {}}
    featureList = featureNamesSet.pop(bestFeatureIndex)     # all known values of that feature
    featureSet = set(featureList)
    for feature in featureSet:
        featureNamesNext = featureNames[:]   # copy so sibling branches are unaffected
        featureNamesSetNext = featureNamesSet[:]  # shallow copy suffices: inner lists are never mutated
        splitedDataSet = splitDataSet(dataSet, bestFeatureIndex, feature)
        myTree[bestFeatureName][feature] = createFullDecisionTree(
            splitedDataSet, featureNamesNext, featureNamesSetNext, labelList)
    return myTree

4. Plotting the tree

def readWatermelonDataSet():
    dataSet = data.values.tolist()
    featureNames = ['色泽', '根蒂', '敲击', '纹理', '脐部', '触感']
    # Collect every observed value of each feature (featureNamesSet)
    featureNamesSet = []
    for i in range(len(dataSet[0]) - 1):
        col = [x[i] for x in dataSet]
        featureNamesSet.append(list(set(col)))

    return dataSet, featureNames, featureNamesSet


# Use SimHei so the Chinese feature names render correctly
matplotlib.rcParams['font.sans-serif'] = ['SimHei']
matplotlib.rcParams['font.serif'] = ['SimHei']

decisionNode = dict(boxstyle="sawtooth", fc="0.8")  # style for internal decision nodes
leafNode = dict(boxstyle="round4", fc="0.8")        # style for leaf nodes
arrow_args = dict(arrowstyle="<-")                  # edges point from child back to parent


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    # Draw one annotated node with an arrow back to its parent.
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)


def getNumLeafs(myTree):
    # Count leaves to size the horizontal layout.
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    # Depth of the tree, used to size the vertical layout.
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


def plotMidText(cntrPt, parentPt, txtString):
    # Label the edge between parent and child with the feature value.
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)


def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree=myTree)   # width of this subtree in leaves
    firstStr = list(myTree.keys())[0]       # the feature tested at this node

    # Center this node above its leaves
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)

    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD  # step down one level
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):
            # Internal node: recurse into the subtree
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            # Leaf: advance one horizontal slot and draw it
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))

    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD  # step back up


def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))   # total width in leaves
    plotTree.totalD = float(getTreeDepth(inTree))  # total depth in levels
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

dataSet, featureNames, featureNamesSet = readWatermelonDataSet()
# Pass the label list (not featureNames) as the parent-label fallback
labels = [row[-1] for row in dataSet]
testTree = createFullDecisionTree(dataSet, featureNames, featureNamesSet, labels)
createPlot(testTree)

Result

(Figure: the decision tree rendered by createPlot.)

II. ID3 and CART with sklearn

1. ID3

import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.tree import DecisionTreeClassifier

# Reload the raw dataset
data = pd.read_csv('C:/西瓜数据集.csv', header=None)
data

label = LabelEncoder()

# Encode each categorical feature column as integers.
# Note: fit_transform refits the same encoder on every column, so only the
# last column's value-to-integer mapping is retained afterwards.
for col in data.columns[:-1]:
    data[col] = label.fit_transform(data[col])
data

# Fit an ID3-style tree: criterion='entropy' splits on information gain
# (sklearn grows binary CART trees under the hood, so this approximates ID3)
dtc = DecisionTreeClassifier(criterion='entropy')
# Fit on the encoded features and labels
dtc.fit(data.iloc[:, :-1].values, data.iloc[:, -1].values)
# Predict for one sample given as encoded feature values
result = dtc.predict([[0, 0, 0, 0, 0, 0]])
# The fitted prediction
result
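
Because the encoding loop above refits one shared LabelEncoder, only the last column's mapping survives. To encode new samples or decode predictions reliably, keep one encoder per column; a minimal sketch, assuming the CSV is re-read in raw form (the variable names below are hypothetical):

# Sketch: one LabelEncoder per feature column
raw = pd.read_csv('C:/西瓜数据集.csv', header=None)
encoders = {col: LabelEncoder().fit(raw[col]) for col in raw.columns[:-1]}
# Encode the first row's features and predict with the fitted tree
sample = [encoders[col].transform([raw[col].iloc[0]])[0] for col in raw.columns[:-1]]
print(dtc.predict([sample]))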

2. CART

# Fit a CART tree: the default criterion='gini' is the CART splitting rule
dtc = DecisionTreeClassifier()
# Fit on the encoded features and labels
dtc.fit(data.iloc[:, :-1].values, data.iloc[:, -1].values)
# Predict for one sample given as encoded feature values
result = dtc.predict([[0, 0, 0, 1, 0, 0]])
# The fitted prediction
result

III. Summary

1. The ID3 algorithm

ID3 computes the information gain of every attribute, treats high-gain attributes as good discriminators, and at each step splits on the attribute with the highest gain. It repeats this process until it produces a decision tree that classifies the training examples perfectly.

A decision tree classifies data in order to make predictions. The method first builds a tree from the training set; if that tree cannot classify every object correctly, some of the misclassified exceptions are added to the training data and the process is repeated until a correct decision set is formed. The decision tree is the tree-shaped representation of that decision set.

A decision tree consists of decision nodes, branches, and leaves. The topmost node is the root; each branch leads either to a new decision node or to a leaf. Each decision node represents a question or decision, usually corresponding to an attribute of the object being classified, and each leaf represents a possible classification outcome. Traversing the tree from top to bottom, each node applies a test, and the different outcomes of that test lead down different branches until a leaf is reached. This traversal, which uses several variables to determine the class, is exactly how a decision tree classifies a sample.
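
Classifying a new sample with the hand-built tree from Part I is exactly this top-down traversal; a minimal sketch of it (a hypothetical helper, assuming featureNames is the full ordered feature-name list and every branch value in the sample appears in the tree):

def classify(tree, featureNames, sample):
    # Walk the nested dict until a leaf (a plain label) is reached
    if not isinstance(tree, dict):
        return tree
    feature = list(tree.keys())[0]               # feature tested at this node
    value = sample[featureNames.index(feature)]  # the sample's value for it
    return classify(tree[feature][value], featureNames, sample)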

2. The CART algorithm

(1) Pick a feature x_i and a split value v_i, and partition the n-dimensional feature space into two parts: one in which every point satisfies x_i ≤ v_i and another in which every point satisfies x_i > v_i. For a categorical feature the split is also binary: a point's value is either equal to the chosen value or not equal to it.

(2) Recurse: apply step (1) again to each of the two resulting parts, choosing a new feature and split value each time, until the whole feature space has been partitioned.
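
Where ID3 uses entropy, CART scores candidate binary splits by the weighted Gini impurity, Gini(D) = 1 − Σ_k p_k². A minimal sketch of the impurity itself (a hypothetical helper, not part of sklearn):

def gini(labels):
    # Gini impurity: 1 - sum of squared class proportions
    m = len(labels)
    return 1.0 - sum((labels.count(c) / m) ** 2 for c in set(labels))

print(gini(['good', 'good', 'bad', 'bad']))  # 0.5 (maximally mixed)
print(gini(['good', 'good', 'good']))        # 0.0 (pure)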
