引用数据集获取:
程序:
# -*- coding: utf-8 -*-
"""
Created on Sun Jan 6 23:02:02 2019
@author: Jack Lee
"""
import math
def createDataSet():
    """Build the watermelon training set used by the ID3 demo.

    Returns:
        tuple: ``(dataSet, labels, labels_full)`` where

        * ``dataSet`` is a list of 17 samples; each sample is a list of six
          attribute values followed by the class label ('好瓜' or '坏瓜');
        * ``labels`` lists the six attribute names in column order;
        * ``labels_full`` maps every attribute name to the distinct values
          found in its column, ordered by first appearance in ``dataSet``.
    """
    dataSet = [
        # 1
        ['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],
        # 2
        ['乌黑', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '好瓜'],
        # 3
        ['乌黑', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],
        # 4
        ['青绿', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '好瓜'],
        # 5
        ['浅白', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],
        # 6
        ['青绿', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '好瓜'],
        # 7
        ['乌黑', '稍蜷', '浊响', '稍糊', '稍凹', '软粘', '好瓜'],
        # 8
        ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '硬滑', '好瓜'],
        # ----------------------------------------------------
        # 9
        ['乌黑', '稍蜷', '沉闷', '稍糊', '稍凹', '硬滑', '坏瓜'],
        # 10
        ['青绿', '硬挺', '清脆', '清晰', '平坦', '软粘', '坏瓜'],
        # 11
        ['浅白', '硬挺', '清脆', '模糊', '平坦', '硬滑', '坏瓜'],
        # 12
        ['浅白', '蜷缩', '浊响', '模糊', '平坦', '软粘', '坏瓜'],
        # 13
        ['青绿', '稍蜷', '浊响', '稍糊', '凹陷', '硬滑', '坏瓜'],
        # 14
        ['浅白', '稍蜷', '沉闷', '稍糊', '凹陷', '硬滑', '坏瓜'],
        # 15
        ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '坏瓜'],
        # 16
        ['浅白', '蜷缩', '浊响', '模糊', '平坦', '硬滑', '坏瓜'],
        # 17
        ['青绿', '蜷缩', '沉闷', '稍糊', '稍凹', '硬滑', '坏瓜']
    ]
    # Attribute names, one per data column (the last column is the class).
    labels = ['色泽', '根蒂', '敲击', '纹理', '脐部', '触感']
    # Every value each attribute takes in the data set.
    labels_full = {}
    for column, name in enumerate(labels):
        values = [sample[column] for sample in dataSet]
        # A bare list(set(...)) would order the values by string hash, which
        # changes from run to run; sorting the distinct values by the index
        # of their first occurrence keeps the result reproducible.
        labels_full[name] = sorted(set(values), key=values.index)
    return dataSet, labels, labels_full
class TreeNode():
    """A node of the decision tree.

    Attributes:
        childs: child nodes attached to this node.
        cls_feature: splitting attribute recorded on the node.
        cls_value: value of the splitting attribute recorded on the node.
        data: samples held by the node, a list of lists whose last item is
            the class label.
        features: attributes associated with the node.
    """

    def __init__(self, cls_feature, cls_value, data, features):
        self.childs = []
        self.cls_feature = cls_feature
        self.cls_value = cls_value
        self.data = data
        self.features = features

    def generate_childs(self, TreeNodes):
        """Attach ``TreeNodes`` as children.

        An iterable contributes each of its items; anything that cannot be
        iterated is attached as a single child.
        """
        try:
            self.childs.extend(TreeNodes)
        except TypeError:
            self.childs.append(TreeNodes)

    def traverse(self):
        """Print the subtree below this node as nested IF rules.

        A child that has a truthy ``label`` attribute is printed as a
        conclusion; a child without a ``label`` attribute is printed as a
        condition and then traversed recursively.
        """
        for node in self.childs:
            try:
                verdict = node.label
            except AttributeError:
                print("IF %s is %s, " % (node.cls_feature, node.cls_value))
                node.traverse()
            else:
                if verdict:
                    print("IF %s is %s, " % (node.cls_feature, node.cls_value))
                    print("It is: %s.\n" % verdict)

    def convert_foliage(self, label):
        """Return a new leaf node carrying ``label``."""
        return Foliage(label=label)

    def most_label(self):
        """Return the majority class of ``self.data``.

        '好瓜' wins only with strictly more samples than '坏瓜'; a tie
        yields '坏瓜'.
        """
        outcomes = [sample[-1] for sample in self.data]
        if outcomes.count('好瓜') > outcomes.count('坏瓜'):
            return '好瓜'
        return '坏瓜'

    def get_subset(self, feature):
        """Return the samples of ``self.data`` that contain ``feature``.

        The test is plain membership in the sample, not a comparison on one
        particular column.
        """
        return [sample for sample in self.data if feature in sample]
class Foliage(TreeNode):
    """Leaf node: a TreeNode with no split information or data, only a class label."""

    def __init__(self, label):
        # A leaf carries nothing but its label; every inherited field is None.
        super().__init__(cls_feature=None, cls_value=None, data=None, features=None)
        self.label = label
class ID3():
    """ID3 decision-tree learner driven by information gain.

    Constructing an instance loads the data through ``createDataSet`` and
    builds the tree immediately, printing a trace of the chosen splits.

    Attributes:
        name: set to ``'s'`` by the constructor.
        dataset: training samples; the last item of each sample is its class.
        features: attribute names.
        dict: mapping from attribute name to the values it can take.
        tree: root node returned by ``build_decision_tree``.
    """

    def __init__(self):
        self.name = 's'
        self.get_data()
        # Keep the root: without it the finished tree would be unreachable.
        self.tree = self.build_decision_tree(self.dataset, self.features)

    def get_data(self):
        """Load samples, attribute names and attribute values onto the instance."""
        self.dataset, self.features, self.dict = createDataSet()

    def build_decision_tree(self, data, features):
        """Recursively build the (sub)tree for ``data``.

        Each child gets ``cls_feature`` / ``cls_value`` set to the attribute
        and value of the branch that leads to it, so the root keeps
        ``None`` for both.

        Args:
            data: non-empty list of samples, class label last.
            features: attribute names still available for splitting.

        Returns:
            A ``Foliage`` leaf when no further split is possible, otherwise
            a ``TreeNode`` whose ``childs`` hold one subtree per value of
            the chosen attribute.
        """
        t = TreeNode(cls_feature=None, cls_value=None, data=data, features=features)
        if self.is_same_label(t.data):  # every sample belongs to one class
            print("这是:", t.data[0][-1])
            return t.convert_foliage(t.data[0][-1])
        # No attribute left (None or empty list), or the samples cannot be
        # told apart by their attributes: fall back to the majority class.
        if not t.features or self.is_same_feature(t.data):
            # Read the majority label before building the leaf; the leaf
            # itself holds no data to count.
            majority = t.most_label()
            print("这是:", majority)
            return t.convert_foliage(majority)
        a_ = self.get_best_branch(t)  # attribute chosen for this split
        print("if:", a_)
        remaining = features[:]
        remaining.remove(a_)
        for feature in self.dict[a_]:  # one branch per value of a_
            print(feature)
            D_v = t.get_subset(feature)
            if not D_v:
                # No sample takes this value: predict the parent's majority.
                print("这是:", t.most_label())
                child = Foliage(t.most_label())
            else:
                child = self.build_decision_tree(D_v, remaining[:])
            # Record the branch condition on the child, not on the parent,
            # so each subtree and leaf knows which test leads to it.
            child.cls_feature = a_
            child.cls_value = feature
            t.generate_childs(child)
        return t

    def is_same_label(self, Data):
        """Return True when all samples in ``Data`` share one class label."""
        return len({sample[-1] for sample in Data}) == 1

    def is_same_feature(self, Data):
        """Return True when all samples agree on every attribute column.

        The last column (the class) is ignored. Raises ``IndexError`` when
        ``Data`` is empty.
        """
        width = len(Data[0]) - 1
        reference = Data[0][:width]
        return all(sample[:width] == reference for sample in Data)

    def get_best_branch(self, treenode):
        """Return the attribute of ``treenode.features`` with the highest gain.

        Ties go to the attribute that comes first in ``treenode.features``.
        """
        return max(treenode.features,
                   key=lambda name: self.get_Gain(name, treenode.data))

    def get_Gain(self, feature, data):
        """Return the information gain of splitting ``data`` on ``feature``."""
        gain = self.get_Ent(data)
        total = len(data)
        for value in self.dict[feature]:
            # NOTE(review): membership is tested against the whole sample,
            # so this assumes a value never appears under two different
            # attributes (or as a class label) - confirm for new data sets.
            D_v = [sample for sample in data if value in sample]
            gain -= len(D_v) / total * self.get_Ent(D_v)
        return gain

    def get_Ent(self, data):
        """Return the base-2 entropy of the class labels in ``data``.

        Samples are counted as '好瓜' or not-'好瓜'; an empty or pure set
        yields 0.
        """
        cnt = len(data)
        good = sum(1 for sample in data if sample[-1] == '好瓜')
        if good == 0 or good == cnt:
            return 0
        p = good / cnt
        return -(p * math.log(p, 2) + (1 - p) * math.log(1 - p, 2))
# Script entry point: constructing ID3 loads the data set and builds the
# decision tree right away, printing the split trace as it goes.
id3 = ID3()