决策树选取最优特征

Python3《机器学习实战》学习笔记(二):决策树基础篇之让我们从相亲说起 —— Jack-Cui 的 CSDN 博客

本文的理论知识来源于上述文章,以下是对该文章代码的个人实现。

import math


def createDataSet():
    """Return the sample loan-application data set and its feature names.

    Every row holds four integer-encoded feature values followed by the
    class label ('yes' or 'no') as its last element. The feature names are,
    in column order: age, has a job, owns a house, credit situation.

    A fresh pair of lists is built on every call, so callers may mutate the
    result freely.

    Returns:
        tuple: ``(dataSet, labels)`` -- the list of rows and the list of
        feature names.
    """
    feature_names = ['年龄', '有工作', '有自己的房子', '信贷情况']
    rows = [
        # age, has job, owns house, credit, label
        [0, 0, 0, 0, 'no'],
        [0, 0, 0, 1, 'no'],
        [0, 1, 0, 1, 'yes'],
        [0, 1, 1, 0, 'yes'],
        [0, 0, 0, 0, 'no'],
        [1, 0, 0, 0, 'no'],
        [1, 0, 0, 1, 'no'],
        [1, 1, 1, 1, 'yes'],
        [1, 0, 1, 2, 'yes'],
        [1, 0, 1, 2, 'yes'],
        [2, 0, 1, 2, 'yes'],
        [2, 0, 1, 1, 'yes'],
        [2, 1, 0, 1, 'yes'],
        [2, 1, 0, 2, 'yes'],
        [2, 0, 0, 0, 'no'],
    ]
    return rows, feature_names


def calShannonEnt(dataSet):
    """Compute the base-2 Shannon entropy of the class labels in *dataSet*.

    The class label of each row is taken to be its last element.
    Entropy is ``-sum(p * log2(p))`` over the relative frequency ``p`` of
    each distinct label.

    Args:
        dataSet: Sequence of rows, each ending with its class label.

    Returns:
        float: The entropy; 0.0 when *dataSet* is empty or holds a single
        label.
    """
    total = len(dataSet)

    # Count occurrences of each label, preserving first-seen order.
    tally = {}
    for row in dataSet:
        tag = row[-1]
        tally[tag] = tally.get(tag, 0) + 1

    # Accumulate with an explicit loop (not sum()) so the floating-point
    # result is built term by term in label order.
    entropy = 0.0
    for count in tally.values():
        p = count / total
        entropy -= p * math.log2(p)
    return entropy


def handle_feature(i, value, dataset):
    """Measure the label entropy of the rows whose feature *i* equals *value*.

    Args:
        i: Column index of the feature to test.
        value: Feature value that selects the rows.
        dataset: Sequence of rows, each holding feature values followed by
            the class label as its last element.

    Returns:
        tuple: ``(entropy, count)`` -- the Shannon entropy of the selected
        rows' labels, as computed by ``calShannonEnt``, and the number of
        selected rows (0 when no row matches).
    """
    # Filter the ``dataset`` argument itself rather than a module-level
    # global, so the function works on whatever data set it is given.
    subset = [row for row in dataset if row[i] == value]
    return calShannonEnt(subset), len(subset)


def calEachFeature(dataSet):
    """Return the index of the feature with the largest information gain.

    For each feature column the rows are partitioned by that feature's
    value, and the information gain
    ``H(D) - sum(|D_v| / |D| * H(D_v))`` is computed, where ``H`` is
    ``calShannonEnt``. Each feature's gain is printed, rounded to three
    decimals.

    Args:
        dataSet: Non-empty sequence of rows; every row holds the feature
            values followed by the class label as its last element.

    Returns:
        int: Index of the feature with the highest gain. Gains that agree to
        within floating-point noise count as tied, and the lowest index wins
        a tie.

    Raises:
        IndexError: If *dataSet* is empty.
    """
    feature_num = len(dataSet[0]) - 1
    totalnum = len(dataSet)
    base_ent = calShannonEnt(dataSet)
    # Gains closer than this are treated as equal, so float rounding noise
    # never decides between mathematically tied features.
    tie_tol = 1e-9

    gains = []
    for i in range(feature_num):
        # Partition the rows by the value of feature i in a single pass.
        subsets = {}
        for row in dataSet:
            subsets.setdefault(row[i], []).append(row)

        cond_ent = sum(
            len(sub) / totalnum * calShannonEnt(sub) for sub in subsets.values()
        )
        gain = base_ent - cond_ent
        # Keep full precision for ranking; round only for the printout.
        gains.append(gain)
        print(f'第{i}个特征的增益为:{round(gain, 3)}')

    best = 0
    for i in range(1, len(gains)):
        if gains[i] - gains[best] > tie_tol:
            best = i
    return best


if __name__ == '__main__':
    # Demo: load the sample data, report its base entropy, then report which
    # feature gives the largest information gain. ``dataSet`` is kept as a
    # module-level name.
    dataSet, labels = createDataSet()
    print(dataSet)

    base_entropy = calShannonEnt(dataSet)
    print('香农熵为:', base_entropy)

    best_index = calEachFeature(dataSet)
    print(f'最优特征索引值:{best_index}')

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值