首先,我们使用 treelib 库来显示树结构 :
ps : 如果 treelib 输出一堆乱码, 可以点进Tree修改 tree.py 大概 930 行左右的部分(去掉encode就行了)
if stdout:
print(self._reader) # print(self._reader.encode("utf-8"))
else:
return self._reader
将字典转换为treelib 中可显示的 map的python 函数编写如下:
import copy
from treelib import Tree
def dict2map(dic):
if not isinstance(dic, dict):
raise TypeError("input should be dict")
map = {
}
_dict2map_cb(dic if len(dic) <= 1 else {
'root':dic}, map, parent = [])
return map
def _dict2map_cb(dic, map, parent=[]):
"""
Create map object in treelib by dict
:param dic: Python dict
:param map: Use {} as map
:param node_name: If None, use the first key of dic as root
but when multiple items in json, pass "root"
:param parent: Parent node array
:return:
"""
for key, val in dic.items():
node_name_new = '-'.join(parent)
root_name = '-'.join(parent + [key])
if isinstance(val, dict):
map[root_name] = node_name_new if parent!=[] else None # when root node,use None
_dict2map_cb(val, map, parent=parent + [key])
else:
map[root_name + " : " + str(val)] = node_name_new if parent!=[] else None
if __name__ == "__main__":
a = {
"hello": {
"word": 2}}
b = {
'decision 0': {
'target 1': 256, 'decision 3': {
'target 0': 128, 'target 1': 256}, 'decision 2': {
'target 0': 256, 'target 1': 128}}}
c = {
'hi': {
"w": 3}, 'this':{
'e':4}}
Tree.from_map(dict2map(a)).show()
Tree.from_map(dict2map(b)).show(line_type="ascii-em")
Tree.from_map(dict2map(c)).show(line_type="ascii-em")
决策树的参考文章是 《机器学习苏娜发原理与编程实践》郑捷著, 具体是分类如下的问题 :
计数 | 年龄 | 收入 | 学生 | 信誉 | 是否购买 |
---|---|---|---|---|---|
64 | 青 | 高 | 否 | 良 | 不买 |
64 | 青 | 高 | 否 | 优 | 不买 |
128 | 中 | 高 | 否 | 良 | 买 |
60 | 老 | 中 | 否 | 良 | 买 |
64 | 老 | 低 | 是 | 良 | 买 |
64 | 老 | 低 | 是 | 优 | 不买 |
64 | 中 | 低 | 是 | 优 | 买 |
128 | 青 | 中 | 否 | 良 | 不买 |
64 | 青 | 低 | 是 | 良 | 买 |
132 | 老 | 中 | 是 | 良 | 买 |
64 | 青 | 中 | 是 | 优 | 买 |
32 | 中 | 中 | 否 | 优 | 买 |
32 | 中 | 高 | 是 | 良 | 买 |
64 | 老 | 中 | 否 | 优 | 不买 |
显然容易构建出如下的决策树, 但这个决策树不是最优的
决策树的原理在这里不进行讲解, 将上面的表格保存为seals_data.xlsx, 并将转换的脚本保存为 dict_to_map.py, 即可直接运行下面的ID3决策树代码:
import numpy
import numpy as np
import copy
import pandas as pd
from sklearn.preprocessing import LabelEncoder # encoder labels
from treelib import Tree, Node
from sklearn.datasets._base import Bunch
from dict_to_map import dict2map
class ID3_Tree():
""" ID3 decision Tree Algorithm """
def __init__(self, counts = None, data =