# 非递归建树和查询(决策树,C4.5 增益率)。
# 按理应该用一个 class 栈存储每一层结点的全部信息,
# 但这里为简单起见开了三个并行栈。
# 数据用的是原博客“点击打开链接”中的数据集。
# Python 用起来很省事,然而代码还是写了这么长。
# -*- coding: utf-8 -*-
"""
Created on Wed Mar 7 18:54:11 2018
@author: jkrs
"""
from collections import Counter
from math import log
import operator
import random
import struct

import numpy as np
from sklearn.model_selection import train_test_split
def ent(data):
    """Return the Shannon entropy (in bits) of the class labels in *data*.

    Each row of *data* is a feature vector whose last element is the
    class label.  Returns 0.0 for an empty dataset or a pure one.
    """
    total = len(data)
    if total == 0:
        # Empty split carries no information (matches the original, whose
        # loop over an empty count dict simply returned 0.0).
        return 0.0
    counts = Counter(row[-1] for row in data)
    # H = -sum(p * log2(p)) over the empirical label distribution.
    return -sum((c / total) * log(c / total, 2) for c in counts.values())
def split_data(data, i, value):
    """Return the rows of *data* whose i-th field equals *value*,
    with that field removed from each returned row."""
    return [row[:i] + row[i + 1:] for row in data if row[i] == value]
def choosebestfeature(data):
    """Return the column index with the highest gain ratio (C4.5-style).

    Columns whose values are all identical are skipped — they cannot
    split the data.  Returns -1 when no column can split the data.
    """
    base_entropy = ent(data)
    n_features = len(data[0]) - 1  # last column is the class label
    best_index = -1
    best_ratio = -1.0
    for col in range(n_features):
        distinct = {row[col] for row in data}
        if len(distinct) == 1:
            continue  # constant column: no possible split
        cond_entropy = 0.0
        split_info = 0.0
        for v in distinct:
            subset = split_data(data, col, v)
            p = len(subset) * 1.0 / len(data)
            cond_entropy += p * ent(subset)
            split_info -= p * log(p, 2)
        if split_info == 0:
            continue
        ratio = (base_entropy - cond_entropy) / split_info
        if ratio > best_ratio:
            best_ratio = ratio
            best_index = col
    return best_index
def classify(tree, feature, value):
    """Walk the decision tree iteratively and return the predicted label.

    tree    -- nested dict: {feature_name: {feature_value: subtree_or_label}}
    feature -- list of feature names, indexed like the original data rows
    value   -- one unsplit sample (feature values in original column order)

    Returns -1 when the sample's value for some node's feature has no
    matching branch in the tree.
    """
    pending = [tree]
    while pending:
        node = pending.pop()
        split_name = next(iter(node))      # each node dict has a single key
        branches = node[split_name]
        sample_value = value[feature.index(split_name)]
        if sample_value in branches:
            child = branches[sample_value]
            if isinstance(child, dict):
                pending.append(child)      # descend into the subtree
            else:
                return child               # leaf: the predicted label
    return -1
def majorityCnt(classList):
    """Return the most common label in *classList*.

    Ties resolve to the first-encountered label, matching the original
    implementation (stable sort over insertion-ordered counts is
    equivalent to ``Counter.most_common``'s tie ordering).
    """
    return Counter(classList).most_common(1)[0][0]
def build(data, feature):
    """Build a decision tree iteratively (no recursion).

    Instead of one stack of per-level node records, three parallel stacks
    are used: the subtree dict to fill in (nodeStack), the remaining
    feature names for that subtree (featureStack), and the data subset
    (dataStack).  The three stacks are pushed/popped in lockstep.

    Returns the tree as nested dicts:
        {feature_name: {feature_value: subtree_or_label, ...}}
    """
    tree = {}
    nodeStack = []
    nodeStack.append(tree)
    featureStack = []
    featureStack.append(feature)
    dataStack = []
    dataStack.append(data)
    while len(dataStack) > 0:
        data = dataStack.pop()
        # Copy so deleting the chosen feature does not affect siblings.
        curfeature = (featureStack.pop())[:]
        # Index of the best split feature (gain ratio).
        # NOTE(review): choosebestfeature can return -1 (no splittable
        # column), in which case curfeature[-1] silently picks the last
        # remaining feature — probably unintended; confirm.
        i = choosebestfeature(data)
        bestfeature = curfeature[i]
        curnode = nodeStack.pop()
        curnode[bestfeature] = {}
        del curfeature[i]
        value = [it[i] for it in data]
        value = set(value)
        for eachvalue in value:
            newdata = split_data(data, i, eachvalue)
            curlabel = [it[-1] for it in newdata]
            if curlabel.count(curlabel[0]) == len(curlabel):
                # Pure branch: every sample carries the same label.
                label = curlabel[0]
                curnode[bestfeature][eachvalue] = label
            elif len(curlabel) <= 10 or len (data[0]) == 1:
                # Pre-pruning: stop when the branch holds <= 10 samples.
                # NOTE(review): `len(data[0]) == 1` tests the PARENT's row
                # width, which cannot be 1 here (choosebestfeature was just
                # called on it); presumably `len(newdata[0]) == 1` (no
                # features left after the split) was intended — confirm.
                label = majorityCnt(curlabel)
                curnode[bestfeature][eachvalue] = label
            else:
                # Internal node: push an empty dict to be filled later,
                # together with its data subset and remaining features.
                newfeature = curfeature[:]
                featureStack.append(newfeature)
                curnode[bestfeature][eachvalue] = {}
                dataStack.append(newdata)
                nodeStack.append(curnode[bestfeature][eachvalue])
    return tree
# Measure the depth of the tree.
def dfs(tree, deep, sample):
    """Return the maximum depth of *tree*.

    tree   -- nested dict node, or a leaf label of any other type
    deep   -- depth accumulated so far (pass 0 at the root)
    sample -- the type object that marks internal nodes (i.e. ``dict``)
    """
    if type(tree) != sample:
        return deep  # leaf: report the accumulated depth
    deepest = 0
    for subtree in tree.values():
        deepest = max(deepest, dfs(subtree, deep + 1, sample))
    return deepest
def read_file():
    """Load feature names and the labelled dataset from the two text files.

    Returns (feature, data, label):
      feature -- comma-split first line of 'features.txt' (the trailing
                 newline is deliberately NOT stripped, preserving the
                 original behaviour on the last name)
      data    -- per-sample feature values (label column removed)
      label   -- the class column removed from each row

    Fix over the original: files are opened with context managers so the
    handles are closed instead of leaking.
    """
    with open('features.txt') as fh:
        feature = fh.readline().split(',')
    with open('dataset(label including).txt') as fh:
        data = [line.strip().strip('\n').split(',') for line in fh]
    # pop() both records and removes the trailing label in one step.
    label = [row.pop() for row in data]
    return feature, data, label
# Load data, train a tree, and evaluate it.
def main():
    """Train a decision tree on 70% of the data and print test accuracy."""
    feature, data, label = read_file()
    m = len(data)            # NOTE(review): unused
    n = len(data[0]) - 1     # NOTE(review): unused
    # 70/30 train/test split.
    x_train, x_test, y_train, y_test = train_test_split(data, label, test_size = 0.3)
    # Re-attach labels to the training rows: build() expects each row to
    # end with its class label.
    train_data = x_train[:]
    for row, y in zip(train_data, y_train):
        row.append(y)
    # Train.
    tree = build(train_data, feature)
    # Predict every test sample.
    res = [classify(tree, feature, sample) for sample in x_test]
    num = len(y_test)
    cnt = sum(1 for truth, pred in zip(y_test, res) if truth == pred)
    print('precise = ', cnt * 1.0 / num)

if __name__ == '__main__':
    main()