决策树(ID3)

决策树是一种基于特征选择的分类方法,通过信息增益来选取最佳特征进行数据集划分。本文介绍了决策树的基本思想,以海洋生物分类为例阐述了决策过程,并详细讲解了ID3算法,包括信息熵和信息增益的概念。最后提到了Python实现ID3算法的参考资源。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

简介

决策树(Decision Tree)是一种逼近离散值目标函数的方法,在这个方法中学习到的函数表示为一颗决策树。学习到的决策树也能再表示为多个if-then的规则,提高可读性。决策树算法是最流行的归纳推理算法之一。
决策树的核心思想是对于给定的训练样本集,每次选取一个特征将样本集切分为若干个子集,递归地对每一个子集进行特征选取和样本切分,直到样本子集中的元素分类都相同或者没有剩余特征可以选取。如果出现没有候选特征继续切分的情况,将投票采取“主分类”作为改决策路径下的数据分类。每一次特征选取的过程就是决策的过程,所有的决策路径便构成了一棵决策树。
以海洋生物样本集为例,我们将通过不浮出水面是否可以生存以及是否有脚蹼两个特征来判断一个海洋生物是否属于鱼类。
这里写图片描述
首先,根据第一个特征(不浮出水面是否可以生存)将样本集切分成两个子集s1=(1,2,3), s2=(4,5),s1元素分类不一致,继续切分;s2分类一致,终止;
最后,根据第二个特征(是否有脚蹼)将s1切分成两个子集s3=(1,2), s4=(3),此时s3和s4分类都一致,终止。
整个决策过程就构成了一棵决策树,如下图所示:

ID3算法

决策树构造的整个过程中,关键点在于对于不同的样本集如何决策,即选取哪个特征来实现对数据集的切分。直观地讲,划分数据集的最大原则是:将分类混乱的数据分类更加一致。组织杂乱无章数据的一种方法就是使用信息论度量信息。
下面先引入一些必要的信息论的定义。
信息(information):如果待分类的事物可能划分在多个分类之中,则符号xi的信息定义为:l(xi) = -log(p(xi), 2)
方便起见,这里利用了python math库中的log函数的写法,其中p(xi)是该分类在集合中出现的后验概率。
熵(entropy):熵定义为信息的期望,H = sigma(-p(xi)*log(p(xi),2)).
信息论中熵的一种解释是,熵确定了要编码集合S中任意成员(即以均匀的概率随机抽样出一个成员)的分类需要的最好的二进制位数。
通俗一点理解,熵反应了一个样本集分类的混杂程度,熵越大越混杂;熵越小,分类越单一;特殊地,样本集分类一致的时候,熵等于0。
信息增益(information gain):在划分数据集前后熵发生的变化成为信息增益
也可以通俗地理解为,信息增益反应了切分数据集前后样本集分类趋于分类一致性的程度
ID3算法的核心思想:按照获取最大信息增益的方法选取特征划分数据集。

Python实现ID3算法

#coding=utf-8
'''
@author: slowalker
@date:2015-05-10
'''

from math import log
import operator

def createDataSet():
    '''
    海洋生物数据,根据不浮出水面是否可以生存(no surfacing)、是否有脚蹼(flippers)两个特征来判断是否属于鱼类。
    '''
    dataSet = [[1, 1, 'yes'],
           [1, 1, 'yes'],
           [1, 0, 'no'],
           [0, 1, 'no'],
           [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels
#计算一个样本集的信息熵
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelsCounts = {}
    for featVec in dataSet:
    currentLabel = featVec[-1]
    if currentLabel not in labelsCounts.keys():
        labelsCounts[currentLabel] = 1
    else:
        labelsCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelsCounts:
    prob = 1.0 * labelsCounts[key] / numEntries
    shannonEnt -= prob * log(prob, 2) #log base 2
    return shannonEnt
#获取axis维特征值为value的样本子集(不含第axis维特征)
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
    if featVec[axis] == value:
        reducedFeatVec = featVec[:axis]
        reducedFeatVec.extend(featVec[axis+1:])
        retDataSet.append(reducedFeatVec)
    return retDataSet
#ID3算法选择最优特征决策
#实质上是基于信息熵和信息增益的贪心算法(Greedy Algorithm)
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
    featList = [example[i] for example in dataSet] #列表推导式用法(list comprehension)
    uniqueVals = set(featList)
    newEntropy = 0.0
    for value in uniqueVals:
        subDataSet = splitDataSet(dataSet, i, value)
        prob = 1.0 * len(subDataSet) / len(dataSet)
        newEntropy += prob * calcShannonEnt(subDataSet)
    infoGain = baseEntropy - newEntropy
    if infoGain > bestInfoGain:
        bestInfoGain = infoGain
        bestFeature = i
    return bestFeature
#投票获取主分类
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
    if vote not in classCount.keys():
        classCount[vote] = 1
    else:
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(), key = operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
#构造决策树
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    #Python List count() Mtheod Description: The method count() returns count of how many times obj occurs in list.
    if classList.count(classList[0]) == len(classList):
    return classList[0] #stop splitting when all the classes are equal
    if len(dataSet[0]) == 1:
    return majorityCnt(classList) #stop splitting when there are no more features in dataSet
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel:{}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for val in uniqueVals:
    subLabels = labels[:]
    myTree[bestFeatLabel][val] = createTree(splitDataSet(dataSet, bestFeat, val), subLabels)
    return myTree
#使用决策树进行分类
def classify(inputTree, featLables, testVec):
    firstStr = inputTree.keys()[0]
    secondDict = inputTree[firstStr]
    featIndex = featLables.index(firstStr)
    key = testVec[featIndex]
    valueOfFeat = secondDict[key]
    if isinstance(valueOfFeat, dict):
    classLabel = classify(valueOfFeat, featLables, testVec)
    else:
    classLabel = valueOfFeat
    return classLabel

if __name__ == '__main__':
    print createDataSet.__doc__
    dataSet, labels = createDataSet()
    print dataSet
    print labels
    print calcShannonEnt(dataSet)
    print splitDataSet(dataSet, 0, 1)
    print splitDataSet(dataSet, 1, 1)
    print chooseBestFeatureToSplit(dataSet)
    print majorityCnt(['yes', 'yes', 'no', 'yes', 'no'])
    print createTree(dataSet, labels)

参考文档

  • 《Machine Learning》 by Tom M.Mitchell
  • 《Machine Learning in Action》 by Peter Harrington
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值