# -*- coding: utf-8 -*-
"""
Created on Tue Jan 30 09:48:53 2018
Email: Eric2014_Lv@sjtu.edu.cn
@author: DidiLv
Python version: 3.5
"""
from math import log
import operator
def createDataSet():
dataSet = [[1, 1, "yes"],
[1, 1, "yes"],
[1, 0, "no"],
[0, 1, "no"],
[0, 1, "no"]]
labels = ["no surfacing", "flippers"]
return dataSet, labels
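# A quick look at the returned structure (comment-only sketch; values come
# straight from the lists above). Each record is [feature 0, feature 1, class],
# and "labels" names the two feature columns:
# >>> myDat, labels = createDataSet()
# >>> myDat[0], labels
# ([1, 1, 'yes'], ['no surfacing', 'flippers'])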
def calShannonEnt(dataSet):
numEntries = len(dataSet)
labelCounts = {}
for featVec in dataSet:
        currentLabel = featVec[-1] # the class label is stored in the last column
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key])/numEntries
shannonEnt -= log(prob, 2) * prob
return shannonEnt
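# Worked example on the sample data: the five records split into 2 "yes" and
# 3 "no", so the entropy is -(2/5)*log2(2/5) - (3/5)*log2(3/5), about 0.971:
# >>> calShannonEnt(createDataSet()[0])
# 0.9709505944546686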
def splitDataSet(dataSet, axis, value):
    # axis: index of the feature to split on
    # value: keep only the records whose 'axis'th feature equals this value
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value:
            # slicing copies the record so the original dataSet is left untouched
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            # extend merges the remaining elements; append would nest the whole list
            retDataSet.append(reducedFeatVec)
return retDataSet
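# Worked example on the sample data: splitting on feature 0 with value 1 keeps
# the three records whose first feature is 1 and strips that column out:
# >>> splitDataSet(createDataSet()[0], 0, 1)
# [[1, 'yes'], [1, 'yes'], [0, 'no']]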
# Core routine: pick the feature whose split gives the largest information gain.
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1 # the last column holds the class label
    baseEntropy = calShannonEnt(dataSet) # the base entropy for comparison
    bestInfoGain = 0.0
    bestFeature = -1
    # for each feature, split on it and compute the weighted entropy of the subsets
for i in range(numFeatures):
# step 1: extract the ith feature
featList = [example[i] for example in dataSet]
# step 2: "set" the related feature values for "classification"
uniqueVals = set(featList)
newEntropy = 0.0
# step 3: calculate the shannon entropy for subdataSet
for value in uniqueVals:
# step 3.1: classification
subDataSet = splitDataSet(dataSet, i, value)
# step 3.2: calculation
prob = float(len(subDataSet)) / len(dataSet)
newEntropy += prob * calShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy
if (infoGain > bestInfoGain):
bestInfoGain = infoGain
bestFeature = i
return bestFeature
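# Worked example on the sample data: feature 0 gives a weighted entropy of
# (3/5)*0.918 + (2/5)*0.0, about 0.551 (gain about 0.420), while feature 1
# gives (4/5)*1.0 + (1/5)*0.0 = 0.8 (gain about 0.171), so feature 0 is chosen:
# >>> chooseBestFeatureToSplit(createDataSet()[0])
# 0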
def majorityCnt(classList):
classCount = {}
    for vote in classList: # tally one vote per record's class label
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
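# Small example: with two "no" votes against one "yes", the majority wins:
# >>> majorityCnt(["yes", "no", "no"])
# 'no'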
def createTree(dataSet, labels):
    # stopping criterion 1: all records share the same class
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # stopping criterion 2: no features left to split on, so take a majority vote
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    # grow the decision tree recursively
    bestFeatIndex = chooseBestFeatureToSplit(dataSet) # index of the best feature
    bestFeatLabel = labels[bestFeatIndex]
    myTree = {bestFeatLabel: {}}
    del labels[bestFeatIndex] # the chosen feature is consumed at this level
    bestFeatValues = [example[bestFeatIndex] for example in dataSet]
    uniqueVals = set(bestFeatValues)
    for value in uniqueVals:
        subLabels = labels[:] # copy so each recursive branch gets its own label list
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeatIndex, value), subLabels)
return myTree
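# On the sample data the recursion bottoms out after two splits and yields this
# nested dict (keys alternate feature label / feature value; printed key order
# may differ on Python 3.5, where dicts are unordered):
# >>> createTree(*createDataSet())
# {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}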
def main():
myDat, labels = createDataSet()
print("Project1: calShannonEnt: -->")
shannonEnt = calShannonEnt(myDat)
print("ShannonEntropy =", shannonEnt)
print("Project2: splitDataSet: -->")
splitData = splitDataSet(myDat, 0, 1)
print(splitData)
print("Project3: chooseBestFeatureToSplit: -->")
bestFeature = chooseBestFeatureToSplit(myDat)
print("The best feature for myDat is: %d" %bestFeature)
print("Project4: createTree: -->")
myTree = createTree(myDat, labels)
print("My_Tree is: ", myTree)
if __name__ == "__main__":
main()
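# Expected output of main() (computed from the sample data above; float
# formatting and dict key order may vary slightly):
#   Project1: calShannonEnt: -->
#   ShannonEntropy = 0.9709505944546686
#   Project2: splitDataSet: -->
#   [[1, 'yes'], [1, 'yes'], [0, 'no']]
#   Project3: chooseBestFeatureToSplit: -->
#   The best feature for myDat is: 0
#   Project4: createTree: -->
#   My_Tree is:  {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}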