CART算法的数学详细计算步骤,以及CART算法生成决策树的过程,可以详见:
【机器学习】【决策树】CART算法,用样本集一步一步详解如何求:基尼指数,最优特征,最优切分点
本章仅提供CART算法的python实现代码,以python类面向对象方式实现~
代码写的好艰难,完全一行一行码出来了,还是写出来了,现在已经凌晨0:13了。
1.样本数据集
def create_samples():
'''''
提供训练样本集
每个example由多个特征值+1个分类标签值组成
比如第一个example=['youth', 'no', 'no', '1', 'refuse'],此样本的含义可以解读为:
如果一个人的条件是:youth age,no working, no house, 信誉值credit为1
则此类人会被分类到refuse一类中,即在相亲中被拒绝(也可以理解为银行拒绝为此人贷款)
每个example的特征值类型为:
['age', 'working', 'house', 'credit']
每个example的分类标签class_label取值范围为:'refuse'或者'agree'
'''
data_list = [['youth', 'no', 'no', '1', 'refuse'],
['youth', 'no', 'no', '2', 'refuse'],
['youth', 'yes', 'no', '2', 'agree'],
['youth', 'yes', 'yes', '1', 'agree'],
['youth', 'no', 'no', '1', 'refuse'],
['mid', 'no', 'no', '1', 'refuse'],
['mid', 'no', 'no', '2', 'refuse'],
['mid', 'yes', 'yes', '2', 'agree'],
['mid', 'no', 'yes', '3', 'agree'],
['mid', 'no', 'yes', '3', 'agree'],
['elder', 'no', 'yes', '3', 'agree'],
['elder', 'no', 'yes', '2', 'agree'],
['elder', 'yes', 'no', '2', 'agree'],
['elder', 'yes', 'no', '3', 'agree'],
['elder', 'no', 'no', '1', 'refuse']]
feat_list = ['age', 'working', 'house', 'credit']
return data_list, feat_list
2.运行结果-cart决策树的字典
max_n_feats = 3时
tree_dict = {
house :{
yes : agree
no :{
working : {'yes': 'agree', 'no': 'refuse'}
}
}
}
3.运行结果-决策树的绘制图形
max_n_feats = 3时
4.核心代码讲解
核心代码是:类class CCartTree(object)中的work()接口和create_tree()接口
work()是cart算法,生成最优特征,最优切分点,最优叶节点等等
create_tree()是递归生成cart决策树字典
树的限制递归生成阈值:
max_n_feats,当剩下的样本集的特征数少于max_n_feats,将不再进行继续生成。
也可以提供gini阈值~
5.代码
# -*- coding: utf-8 -*-
"""
@author: 蔚蓝的天空Tom
Aim:给定样本集和特征列表,计算每个特征的基尼指数,选取最优特征,选取最优分切点
Aim:生成CART决策树的字典形式
Aim:根据决策树字典绘制CART决策树图形
cart_dtree.py
"""
import matplotlib.pyplot as plt
def print_dict(src_dict, level, src_dict_namestr=''):
'''
逐行打印dict
:param self:类实例自身
:param src_dict:被打印的dict
:param level:递归level,初次调用为level=0
:param src_dict_namestr:对象变量名称字符串
'''
if isinstance(src_dict, dict):
tab_str = '\t'
for i in range(level):
tab_str += '\t'
if 0 == level:
print(src_dict_namestr,'= {')
for key, value in src_dict.items():
if isinstance(value, dict):
has_dict = False
for k,v in value.items():
if isinstance(v, dict):
has_dict = True
if has_dict:
print(tab_str,key,":{")
print_dict(value, level + 1)
else:
print(tab_str,key,':',value)
else:
print(tab_str,key,': '