Picking Good Watermelons with a Decision Tree

This post walks through classifying the watermelon dataset two ways: with a hand-rolled implementation of the ID3 algorithm, and with sklearn's DecisionTreeClassifier configured for entropy-based (ID3-style) and CART splitting. The manual version is built from an entropy computation and a dataset-splitting routine; the sklearn version then applies the entropy and Gini criteria. ID3 chooses splits by information gain, while CART is a binary recursive partitioning technique; both produce decision-tree models. The post closes with a summary of the two algorithms, highlighting CART's binary splits and pruning strategy.


I. The ID3 Algorithm

1. Imports

import math

import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt

2. Loading the Data

# Load the watermelon dataset; the CSV has no header row
data = pd.read_csv('C:/西瓜数据集.csv', header=None)
data

[Figure: preview of the loaded watermelon dataset]

3. Implementation

def calcEntropy(dataSet):
    """Compute the information entropy Ent(D) of a dataset."""
    mD = len(dataSet)
    dataLabelList = [x[-1] for x in dataSet]   # the class label is the last column
    dataLabelSet = set(dataLabelList)
    ent = 0
    for label in dataLabelSet:
        mDv = dataLabelList.count(label)       # number of samples with this label
        prop = float(mDv) / mD
        ent = ent - prop * math.log(prop, 2)   # np.math was removed in NumPy 1.25+

    return ent
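
For reference, the quantity this function computes is the standard information entropy, which the gain computation below builds on:

\mathrm{Ent}(D) = -\sum_{k=1}^{|\mathcal{Y}|} p_k \log_2 p_k

where p_k is the proportion of samples in D belonging to class k.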

Splitting the dataset

def splitDataSet(dataSet, index, feature):
    """Return the rows whose column `index` equals `feature`, with that column removed."""
    splitedDataSet = []
    for data in dataSet:
        if data[index] == feature:
            sliceTmp = data[:index]
            sliceTmp.extend(data[index + 1:])  # drop the column that was split on
            splitedDataSet.append(sliceTmp)
    return splitedDataSet
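
A quick sanity check on a toy list (hypothetical values, not rows from the CSV) shows the behaviour: rows matching the value are kept, and the tested column disappears.

toy = [['green', 'curled', 'yes'],
       ['black', 'curled', 'yes'],
       ['white', 'stiff',  'no']]
# Keep rows whose column 0 equals 'green', dropping that column
print(splitDataSet(toy, 0, 'green'))   # [['curled', 'yes']]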

Choosing the best feature

def chooseBestFeature(dataSet):
    """Return the index of the feature with the highest information gain."""
    entD = calcEntropy(dataSet)
    mD = len(dataSet)
    featureNumber = len(dataSet[0]) - 1   # the last column is the label
    maxGain = -100
    maxIndex = -1
    for i in range(featureNumber):
        entDCopy = entD
        featureI = [x[i] for x in dataSet]
        featureSet = set(featureI)
        for feature in featureSet:
            splitedDataSet = splitDataSet(dataSet, i, feature)  # split the dataset
            mDv = len(splitedDataSet)
            # subtract the weighted entropy of each branch; what remains is the gain
            entDCopy = entDCopy - float(mDv) / mD * calcEntropy(splitedDataSet)
        if maxIndex == -1 or maxGain < entDCopy:
            maxGain = entDCopy
            maxIndex = i

    return maxIndex
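
The loop above is exactly the information-gain computation, written out as a formula for reference:

\mathrm{Gain}(D, a) = \mathrm{Ent}(D) - \sum_{v=1}^{V} \frac{|D^v|}{|D|} \mathrm{Ent}(D^v)

where D^v is the subset of D taking value v on feature a; the feature maximizing the gain is chosen.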

Finding the majority label

def mainLabel(labelList):
    """Return the most frequent label in labelList."""
    labelRec = labelList[0]
    maxLabelCount = -1
    labelSet = set(labelList)
    for label in labelSet:
        if labelList.count(label) > maxLabelCount:
            maxLabelCount = labelList.count(label)
            labelRec = label
    return labelRec

def createFullDecisionTree(dataSet, featureNames, featureNamesSet, labelListParent):
    """Recursively build an unpruned ID3 tree as a nested dict."""
    labelList = [x[-1] for x in dataSet]
    if len(dataSet) == 0:
        return mainLabel(labelListParent)
    elif len(dataSet[0]) == 1:              # no attributes left to split on
        return mainLabel(labelList)         # use the majority label of this subset
    elif labelList.count(labelList[0]) == len(labelList):  # all samples share one label
        return labelList[0]

    bestFeatureIndex = chooseBestFeature(dataSet)
    bestFeatureName = featureNames.pop(bestFeatureIndex)
    myTree = {bestFeatureName: {}}
    featureList = featureNamesSet.pop(bestFeatureIndex)
    featureSet = set(featureList)
    for feature in featureSet:
        featureNamesNext = featureNames[:]
        # copy the inner lists as well; the original `featureNamesSet[:][:]` only
        # made two shallow copies of the outer list
        featureNamesSetNext = [values[:] for values in featureNamesSet]
        splitedDataSet = splitDataSet(dataSet, bestFeatureIndex, feature)
        myTree[bestFeatureName][feature] = createFullDecisionTree(
            splitedDataSet, featureNamesNext, featureNamesSetNext, labelList)
    return myTree
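
The post never shows how to query the finished tree, so here is a minimal sketch (the classify helper is not part of the original code) that walks the nested dict for one sample:

# Hypothetical helper, not in the original post: classify one sample by walking
# the nested-dict tree. `featureNames` must be the full, unmutated list of
# feature names (createFullDecisionTree pops from the list passed to it).
def classify(tree, featureNames, sample):
    if not isinstance(tree, dict):        # reached a leaf: it holds the label
        return tree
    featureName = next(iter(tree))        # the feature tested at this node
    value = sample[featureNames.index(featureName)]
    subtree = tree[featureName].get(value)
    if subtree is None:                   # feature value never seen in training
        return None
    return classify(subtree, featureNames, sample)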

Plotting the tree

def readWatermelonDataSet():
    dataSet = data.values.tolist()
    featureNames = ['色泽', '根蒂', '敲击', '纹理', '脐部', '触感']
    # collect the set of observed values for each feature
    featureNamesSet = []
    for i in range(len(dataSet[0]) - 1):
        col = [x[i] for x in dataSet]
        colSet = set(col)
        featureNamesSet.append(list(colSet))

    return dataSet, featureNames, featureNamesSet


# Use SimHei so the Chinese feature names render correctly in matplotlib
matplotlib.rcParams['font.sans-serif'] = ['SimHei']
matplotlib.rcParams['font.serif'] = ['SimHei']

decisionNode = dict(boxstyle="sawtooth", fc="0.8")  # style for internal nodes
leafNode = dict(boxstyle="round4", fc="0.8")        # style for leaf nodes
arrow_args = dict(arrowstyle="<-")                  # style for edges


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)


def getNumLeafs(myTree):
    """Count the leaves of the nested-dict tree (used to size the plot)."""
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    """Return the depth of the nested-dict tree."""
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


def plotMidText(cntrPt, parentPt, txtString):
    """Label the edge between a parent node and a child node."""
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)


def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw the tree; layout offsets are tracked as attributes on plotTree."""
    numLeafs = getNumLeafs(myTree=myTree)
    depth = getTreeDepth(myTree=myTree)
    firstStr = list(myTree.keys())[0]
    # center this node above its leaves
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD  # descend one level
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):
            plotTree(secondDict[key], cntrPt, str(key))    # internal node: recurse
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD  # back up one level


def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

dataSet, featureNames, featureNamesSet = readWatermelonDataSet()
# The fourth argument is the parent label list; the original passed featureNames
# here, which only matters when an empty split is reached
labels = [x[-1] for x in dataSet]
testTree = createFullDecisionTree(dataSet, featureNames, featureNamesSet, labels)
createPlot(testTree)

4. Result

[Figure: the unpruned ID3 decision tree drawn by createPlot]

II. ID3 and CART with sklearn

1. ID3

import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.tree import DecisionTreeClassifier

# Reload the raw dataset for the sklearn experiments
data = pd.read_csv('C:/西瓜数据集.csv', header=None)
data

# Encode each categorical feature column as integers
label = LabelEncoder()

for col in data.columns[:-1]:
    data[col] = label.fit_transform(data[col])
data
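
As an aside, sklearn's documented tool for encoding input features is OrdinalEncoder; LabelEncoder is intended for the target column. An equivalent one-liner over the raw string columns would look like this (a sketch; both encoders assign integer codes from the sorted unique values, so the codes should match):

from sklearn.preprocessing import OrdinalEncoder
X = OrdinalEncoder().fit_transform(data.iloc[:, :-1])  # column-wise integer codes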

# Fit with the entropy criterion (ID3-style splits)
dtc = DecisionTreeClassifier(criterion='entropy')
# Train on the encoded features and the raw labels
dtc.fit(data.iloc[:, :-1].values.tolist(), data.iloc[:, -1].values)
# Predict one sample given as encoded feature values
result = dtc.predict([[0, 0, 0, 0, 0, 0]])
# Prediction result
result

[Output: the predicted class label]
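
Note that sklearn's DecisionTreeClassifier is an optimized CART implementation: it always grows binary trees over numerically encoded features, so criterion='entropy' approximates ID3 rather than reproducing it exactly. To inspect what was actually learned, the fitted tree can be dumped as text with export_text (the str() conversion is needed because the headerless CSV gives integer column names):

from sklearn.tree import export_text
print(export_text(dtc, feature_names=[str(c) for c in data.columns[:-1]]))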

2. CART

# Fit with the default Gini criterion (CART)
dtc = DecisionTreeClassifier()
# Train on the encoded features and the raw labels
dtc.fit(data.iloc[:, :-1].values.tolist(), data.iloc[:, -1].values)
# Predict one sample given as encoded feature values
result = dtc.predict([[0, 0, 0, 1, 0, 0]])
# Prediction result
result

[Output: the predicted class label]
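
For a graphical view analogous to the hand-rolled createPlot above, sklearn ships a plot_tree helper (a minimal sketch; the figure size is an arbitrary choice):

import matplotlib.pyplot as plt
from sklearn import tree

plt.figure(figsize=(12, 6))
tree.plot_tree(dtc, filled=True)  # draw the fitted CART tree
plt.show()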

III. Summary

1. The ID3 Algorithm

The basic flow of ID3 is as follows. If one feature separates the training data better than any other, place it at the root node. Then, for each possible value of that feature, create a child node and recursively choose the best feature for that child. If every sample at a node belongs to the same class, or if no remaining feature separates the node's data better than a given threshold, the node becomes a leaf whose class is the majority class of its training samples. Repeat until the features are exhausted or every feature's discriminative power falls below the threshold. How do we measure a feature's power to separate the training data? ID3 answers with information gain (see the worked example below).
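As a quick worked example (assuming the classic watermelon 2.0 class balance of 8 good and 9 bad melons; check your own CSV), the entropy at the root works out to:

\mathrm{Ent}(D) = -\frac{8}{17}\log_2\frac{8}{17} - \frac{9}{17}\log_2\frac{9}{17} \approx 0.998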

2. The CART Algorithm

The CART (Classification And Regression Tree) algorithm can build both classification trees and regression trees. Its three most important characteristics are:
Binary split: every test partitions the current sample set into exactly two subsets, so each non-leaf node has exactly two branches and the resulting tree is a compact binary tree. Because the tree is binary, every decision is a yes/no question; even when a feature has multiple values, the data is still split into two parts.
Split based on one variable: each optimal split is chosen over a single variable.
Pruning strategy: the key step of CART, and of tree-based algorithms in general. Research suggests pruning matters even more than growing: maximum trees grown under different split criteria retain largely the same important splits after pruning, so the pruning method is the more decisive factor in obtaining a good final tree. (A sketch of CART's split criterion follows this list.)
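
For completeness, classification CART typically scores candidate splits with the Gini index rather than entropy (this is standard background, not from the original post):

\mathrm{Gini}(D) = 1 - \sum_{k=1}^{|\mathcal{Y}|} p_k^2

A candidate binary split of D into D_1 and D_2 is scored by the weighted sum (|D_1|/|D|)\,\mathrm{Gini}(D_1) + (|D_2|/|D|)\,\mathrm{Gini}(D_2), and the split minimizing this value is chosen.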

