代码来源
机器学习之决策树ID3(python3实现)
ID3决策树的编程实验
# -*- coding: utf-8 -*-
# @Time : 2018/10/11 17:54
# @Author : squabbySheep
# @Email : squabbySheep@163.com
# @File : 实验一.py
# @Software: PyCharm Community Edition
# ID3决策树的编程实验
'''
目的:理解ID3算法原理
内容:根据以上数据集建立决策树,可以自己找一个公开数据集或者实际数据集来做实验。
'''
import math
import pickle
from collections import Counter
def createDataSet():
    """Return the toy "bought a computer" training set and its column names.

    Every row is ``[age, income, student, credit, bought]``; the last
    column is the class label.

    Returns:
        tuple: ``(dataSet, labels)`` where ``dataSet`` is a list of 14 rows
        and ``labels`` names all five columns (class column included).
    """
    labels = ['年龄', '收入', '学生', '信用', '买了电脑']
    # The middle income level is spelled '中等' in every row. Mixing '中' and
    # '中等' would make the learner treat one category as two distinct values.
    dataSet = [
        ['<30', '高', '否', '一般', '否'],
        ['<30', '高', '否', '好', '否'],
        ['30-40', '高', '否', '一般', '是'],
        ['>40', '中等', '否', '一般', '是'],
        ['>40', '低', '是', '一般', '是'],
        ['>40', '低', '是', '好', '否'],
        ['30-40', '低', '是', '好', '是'],
        ['<30', '中等', '否', '一般', '否'],
        ['<30', '低', '是', '一般', '是'],
        ['>40', '中等', '是', '一般', '是'],
        ['<30', '中等', '是', '好', '是'],
        ['30-40', '中等', '否', '好', '是'],
        ['30-40', '高', '是', '一般', '是'],
        ['>40', '中等', '否', '好', '否'],
    ]
    return dataSet, labels
def calcShannonEnt(dataSet):
    """Compute the base-2 Shannon entropy of the class labels in *dataSet*.

    Args:
        dataSet: Sequence of rows; the last element of each row is its class.

    Returns:
        float: ``sum(p * log2(1 / p))`` over the relative frequency ``p`` of
        each class; ``0.0`` when *dataSet* is empty.
    """
    numEntries = len(dataSet)
    # Tally how many rows fall into each class.
    labelCounts = Counter(featVec[-1] for featVec in dataSet)
    shannonEnt = 0.0
    for count in labelCounts.values():
        prob = count / numEntries
        shannonEnt += prob * math.log2(1 / prob)
    return shannonEnt
def splitDataSet(dataSet, axis, value):
    """Select the rows whose column *axis* equals *value*, dropping that column.

    Args:
        dataSet: Sequence of rows (lists or tuples) to partition.
        axis: Index of the feature column to test and remove. Negative
            indices count from the end of the row.
        value: Feature value a row must have in column *axis* to be kept.

    Returns:
        list: New list of new row lists without column *axis*; the input
        rows are not modified.
    """
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            # Normalise a negative index first: with axis == -1 the slice
            # start ``axis + 1`` would be 0 and re-append the whole row.
            index = axis if axis >= 0 else len(featVec) + axis
            retDataSet.append(list(featVec[:index]) + list(featVec[index + 1:]))
    return retDataSet
def chooseBestFeatureToSplit(dataSet):
    """Pick the feature column whose split yields the largest information gain.

    Args:
        dataSet: Non-empty list of rows; every column except the last is a
            feature and the last column is the class label.

    Returns:
        int: Index of the best feature column, or ``-1`` when no feature has
        an information gain strictly greater than zero.
    """
    feature_count = len(dataSet[0]) - 1
    row_total = float(len(dataSet))
    base_entropy = calcShannonEnt(dataSet)
    best_gain, best_index = 0, -1
    for index in range(feature_count):
        # Weighted entropy of the partitions induced by this feature.
        conditional_entropy = 0
        for val in {row[index] for row in dataSet}:
            subset = splitDataSet(dataSet, index, val)
            conditional_entropy += len(subset) / row_total * calcShannonEnt(subset)
        gain = base_entropy - conditional_entropy
        # Strict comparison: the earliest feature wins a tie.
        if gain > best_gain:
            best_gain, best_index = gain, index
    return best_index
def majorityCnt(classList):
    """Return the most frequent value in *classList* (majority vote).

    When several values share the highest count, the one that occurs first
    in *classList* is returned.

    Args:
        classList: Non-empty sequence of hashable class values.

    Returns:
        The value with the highest number of occurrences.

    Raises:
        IndexError: If *classList* is empty.
    """
    return Counter(classList).most_common(1)[0][0]
def createTree(dataSet, labels):
    """Recursively build an ID3 decision tree as nested dictionaries.

    Args:
        dataSet: Non-empty list of rows; the last element of each row is the
            class label, the preceding elements are feature values.
        labels: Column names, positionally aligned with the feature columns
            of *dataSet*. The sequence is not modified.

    Returns:
        Either a class value (leaf) or a dict of the form
        ``{feature_label: {feature_value: subtree, ...}}``.
    """
    classList = [example[-1] for example in dataSet]
    # All rows share one class: nothing left to split.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Only the class column remains: fall back to the majority class.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    # chooseBestFeatureToSplit returns -1 when no feature has positive
    # information gain. Using -1 as an index would split on the class column
    # itself, so stop here and emit a majority-vote leaf instead.
    if bestFeat < 0:
        return majorityCnt(classList)
    bestFeatLabel = labels[bestFeat]
    # Build a new label list rather than deleting from the caller's.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    myTree = {bestFeatLabel: {}}
    for value in {example[bestFeat] for example in dataSet}:
        # Recurse on the rows having this value, with the feature removed.
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels
        )
    return myTree
def classify(inputTree, featLabels, testVec):
    """Classify *testVec* by walking the decision tree *inputTree*.

    Args:
        inputTree: Decision tree of the form
            ``{feature_label: {feature_value: subtree_or_class}}``.
        featLabels: Feature names; ``featLabels[i]`` names ``testVec[i]``.
        testVec: Feature values of the sample to classify.

    Returns:
        The class stored at the leaf that is reached, or ``None`` when the
        sample's value for a tested feature has no branch in the tree.

    Raises:
        ValueError: If a feature tested by the tree is not in *featLabels*.
    """
    firstStr = list(inputTree)[0]  # feature tested at this node
    branches = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    subtree = branches.get(testVec[featIndex])
    # isinstance also recognises dict subclasses (e.g. OrderedDict) as
    # internal nodes instead of mistaking them for class labels.
    if isinstance(subtree, dict):
        return classify(subtree, featLabels, testVec)
    return subtree
def storeTree(inputTree, filename):
    """Serialize *inputTree* to the file *filename* using pickle.

    Args:
        inputTree: The decision tree (any picklable object) to store.
        filename: Path of the file to create or overwrite.
    """
    # pickle writes bytes, so the file must be opened in binary mode; the
    # context manager closes it even if pickling fails.
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)
def grabTree(filename):
    """Load and return a pickled decision tree from the file *filename*.

    Warning:
        Unpickling can execute arbitrary code. Only load files that this
        program (or another trusted source) wrote.

    Args:
        filename: Path of a file containing a pickled object.

    Returns:
        The unpickled object.
    """
    # Binary mode is required for pickle; the context manager guarantees the
    # handle is closed once the object has been read.
    with open(filename, 'rb') as fr:
        return pickle.load(fr)
if __name__ == '__main__':
    # Train on the built-in data set, then classify one sample typed in by
    # the user.
    dataSet, labels = createDataSet()
    myTree = createTree(dataSet, labels)
    # classify() needs the feature names in column order. A separate list is
    # used because `labels` also contains the class column name.
    testLabels = ['年龄', '收入', '学生', '信用']
    print("例子:['<30'|'30-40'|'>40', '高'|'中等'|'低', '是'|'否', '一般'|'好']")
    # Prompt for one value per feature, in the same order as testLabels.
    testVec = [input('请输入' + label + '的值:') for label in testLabels]
    result = classify(inputTree=myTree, featLabels=testLabels, testVec=testVec)
    print('买了电脑吗?', result)
编程实验结果截图
