决策树Python实现代码

本文介绍了一种基于西瓜数据集2.0的决策树算法实现,通过ID3算法创建决策树来预测西瓜的好坏。文章详细展示了数据集的内容、特征值及所有可能情况,并提供了完整的Python代码实现。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

引用数据集获取:

西瓜数据集2.0获取

程序:

# -*- coding: utf-8 -*-
"""
Created on Sun Jan  6 23:02:02 2019

@author: Jack Lee
"""

import math

def createDataSet():
    """Return the watermelon 2.0 data set.

    Returns:
        dataSet: list of 17 samples; each sample is six feature values
            followed by the class label ('好瓜' or '坏瓜').
        labels: the six feature names, in column order.
        labels_full: dict mapping each feature name to the list of
            distinct values that feature takes in the data.
    """
    dataSet = [
        # samples 1-8: positive class ('好瓜')
        ['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],
        ['乌黑', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '好瓜'],
        ['乌黑', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],
        ['青绿', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '好瓜'],
        ['浅白', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],
        ['青绿', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '好瓜'],
        ['乌黑', '稍蜷', '浊响', '稍糊', '稍凹', '软粘', '好瓜'],
        ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '硬滑', '好瓜'],
        # samples 9-17: negative class ('坏瓜')
        ['乌黑', '稍蜷', '沉闷', '稍糊', '稍凹', '硬滑', '坏瓜'],
        ['青绿', '硬挺', '清脆', '清晰', '平坦', '软粘', '坏瓜'],
        ['浅白', '硬挺', '清脆', '模糊', '平坦', '硬滑', '坏瓜'],
        ['浅白', '蜷缩', '浊响', '模糊', '平坦', '软粘', '坏瓜'],
        ['青绿', '稍蜷', '浊响', '稍糊', '凹陷', '硬滑', '坏瓜'],
        ['浅白', '稍蜷', '沉闷', '稍糊', '凹陷', '硬滑', '坏瓜'],
        ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '坏瓜'],
        ['浅白', '蜷缩', '浊响', '模糊', '平坦', '硬滑', '坏瓜'],
        ['青绿', '蜷缩', '沉闷', '稍糊', '稍凹', '硬滑', '坏瓜'],
    ]

    # Feature names, one per data column (the last column is the label).
    labels = ['色泽', '根蒂', '敲击', '纹理', '脐部', '触感']

    # Value domain of every feature, collected from the samples.
    labels_full = {
        name: list({row[col] for row in dataSet})
        for col, name in enumerate(labels)
    }

    return dataSet, labels, labels_full

class TreeNode():
    """A node of the decision tree.

    Attributes:
        childs: child TreeNode/Foliage nodes.
        cls_feature: name of the feature this node tests (None until the
            tree builder sets it).
        cls_value: the feature value that leads to this node.
        data: samples reaching this node — a list of rows, label last.
        features: feature names still available for splitting.
    """

    def __init__(self, cls_feature, cls_value, data, features):
        self.childs = []
        self.cls_feature = cls_feature
        self.cls_value = cls_value
        self.data = data
        self.features = features

    def generate_childs(self, TreeNodes):
        """Append *TreeNodes* (one node or an iterable of nodes) as children."""
        # Fixed: the original loop variable shadowed the TreeNode class itself.
        try:
            for node in TreeNodes:
                self.childs.append(node)
        except TypeError:
            # A single node object is not iterable — append it directly.
            self.childs.append(TreeNodes)

    def traverse(self):
        """Print the decision rules of the subtree rooted at this node."""
        for child in self.childs:
            try:
                # Only Foliage leaves carry a .label attribute.
                if child.label:
                    print("IF %s is %s, "%(child.cls_feature, child.cls_value))
                    print("It is: %s.\n"%child.label)
            except AttributeError:
                # Internal node: print only the branching condition.
                print("IF %s is %s, "%(child.cls_feature, child.cls_value))
            child.traverse()

    def convert_foliage(self, label):
        """Return a leaf (Foliage) carrying *label*; does not mutate self."""
        return Foliage(label=label)

    def most_label(self):
        """Return the majority class label of this node's samples.

        Ties resolve to '坏瓜', matching the original strict comparison.
        """
        class_labels = [row[-1] for row in self.data]
        if class_labels.count('好瓜') > class_labels.count('坏瓜'):
            return '好瓜'
        return '坏瓜'

    def get_subset(self, feature):
        """Return the samples whose row contains the value *feature*.

        NOTE(review): this matches the value in ANY column, not a specific
        feature's column — it is only safe while the value domains of
        different features do not overlap (true for this data set).
        """
        return [row for row in self.data if feature in row]
 

class Foliage(TreeNode):
    """A leaf of the decision tree: carries only the predicted label."""

    def __init__(self, label):
        # A leaf holds no split information, so every TreeNode field is None.
        super().__init__(cls_feature=None, cls_value=None, data=None, features=None)
        self.label = label

class ID3():
    def __init__(self):
        self.name = 's'
        self.get_data()
        self.build_decision_tree(self.dataset,self.features)
        
    def get_data(self):
        self.dataset,self.features,self.dict = createDataSet()
        
    def build_decision_tree(self,data,features):
        #print(features)
        t = TreeNode(cls_feature=None,cls_value=None, data=data, features=features)
        if self.is_same_label(t.data):#若所有数据都为一类
            print("这是:",t.data[0][-1])
            t = t.convert_foliage(t.data[0][-1])
            return t
        if t.features is None or self.is_same_feature(t.data):#若没有数据 或 所有数据的属性都相同
            t = t.convert_foliage(t.most_label())
            print("这是:",t.most_label())
            return t
        a_ = self.get_best_branch(t)#划分属性a_
        print("if:",a_)
        for feature in self.dict[a_]:#划分属性取值feature
            print(feature)
            _features = features[:]
            D_v = t.get_subset(feature)
            if len(D_v) == 0:
                print("这是:",t.most_label())
                t.generate_childs(Foliage(t.most_label()))
            else:
                t.cls_feature = a_
                t.cls_value = feature
                _features.remove(a_)
                child = self.build_decision_tree(D_v, _features)
                if child is not None:
                    t.generate_childs(child)
        return t
        
            
            
    def is_same_label(self,Data):#样本数据是否为同一种类
        a = []
        for data in Data:
            a.append(data[-1])
        if len(set(a)) == 1:
            return True
        else: 
            return False
        
    def is_same_feature(self,Data):#判断属性是否为同一中
        for i in range(len(Data[0])-1):
            a = []
            for data in Data:
                a.append(data[i])
            if len(set(a)) == 1:
                continue
            else: 
                return False
        return True
  
    def get_best_branch(self,treenode):
        l = list(map(lambda x: self.get_Gain(x, treenode.data), treenode.features))
        return treenode.features[l.index(max(l))]
        
    def get_Gain(self, feature, data):
        gain = self.get_Ent(data)
        total = len(data)
        for eachfeature in self.dict[feature]:
            cnt = 0
            D_v = []
            for eachdata in data:
                if eachfeature in eachdata:
                    cnt += 1
                    D_v.append(eachdata)
            gain += -cnt/total * self.get_Ent(D_v)
        return gain
  
    def get_Ent(self,data):
        a = []
        for each in data:
            a.append(each[-1])
        cnt = len(data)
        good = a.count('好瓜')
        if good == 0 or good == cnt:
            return 0
        return -(good/cnt * math.log(good/cnt,2) + (1-good/cnt) * math.log(1-good/cnt,2))
        
id3 = ID3()

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值