决策树及python实现

本文详细介绍了决策树的学习过程,包括特征选择、ID3、C4.5和CART算法,以及决策树的剪枝方法。重点讨论了信息增益和基尼指数在特征选择中的作用,并提供了Python代码实现。最后,讨论了预剪枝和后剪枝两种剪枝策略,以提高决策树的泛化能力。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >


决策树是一种基本的分类和回归方法。决策树学习通常包含三个部分:特征选择,决策树的生成和决策树的修剪。

1 决策树模型与学习

定义1(决策树) 分类决策树模型是一种描述对实例及进行分类的树形结构。决策树由节点和有向边组成。结点有两种类型:内部节点和叶节点。内部节点表示一个特征或属性,叶节点表示一个类。

在这里插入图片描述
决策树学习本质上是从训练数据集中归纳出一组分类规则。与训练数据集不相矛盾的决策树(即能对训练数据进行正确分类的决策树)可能有多个,也可能一个都没有。我们需要的是一个与训练数据矛盾较小的决策树,同时具有很好的泛化能力。决策树学习的损失函数一般是正则化的极大似然函数。决策树学习的策略是以损失函数为目标函数的最小化。

决策树的生成是一个递归过程。在决策树基本算法中,有三种情形会导致递归返回:
(1)当前结点包含的样本全属于同一类别,无需划分;
(2)当前属性集为空,或是所有样本在所有属性上取值均相同,无法划分;
解决:把当前节点标记为叶节点,并将其类别设定为该节点所含样本最多的类别;利用了当前结点的后验分布。
(3)当前结点包含的样本集合为空,不能划分。 
解决:把当前节点标记为叶节点,但其类别设定为其父节点所含样本最多的类别。(把父节点的样本分布作为当前节点的先验分布)。

决策树学习算法包括特征选择,决策树的生成与决策树的剪枝过程。决策树学习常用的算法有ID3,C4.5,CART。

2 特征选择

直观上,如果一个特征具有更好的分类能力,或者说,按照这一特征将训练数据集分割成子集,使得各个子集在当前条件下有最好的分类,那么就更应该选择这个特征。信息增益就能够很好的表示这一直观准则。
:表示随机变量不确定性的度量。设 X X X是一个取有限个值的离散随机变量。其概率分布为:
P ( X = x i ) = p i , i = 1 , 2 , … , n P(X=x_i )=p_i,i=1,2,…,n P(X=xi)=pi,i=1,2,,n
则随机变量X的熵定义为:
H ( X ) = − ∑ i = 1 n p i l o g ⁡ p i H(X)=-∑_{i=1}^{n} p_i log⁡p_i H(X)=i=1npilogpi
通常式(2)中的对数以2为底或以e为底(自然对数)。有定义可知,熵只依赖于 X X X的分布,而与 X X X的取值无关,所以也可将 X X X的熵记作 H ( p ) H(p) H(p),即:
H ( p ) = − ∑ i = 1 n p i l o g ⁡ p i H(p)= -∑_{i=1}^{n}p_i log⁡p_i H(p)=i=1npilogpi

条件熵 H ( Y ∣ X ) H(Y|X) H(YX)表示在已知随机变量 X X X的条件下随机变量 Y Y Y的不确定性。随机变量 X X X给定的条件下随机变量 Y Y Y的条件熵 H ( Y ∣ X ) H(Y|X) H(YX),定义为 X X X给定条件下 Y Y Y的条件概率分布的熵对 X X X的数学期望
H ( Y │ X ) = ∑ i = 1 n p i H ( Y │ X = x i ) H(Y│X)=∑_{i=1}^{n}p_i H(Y│X=x_i ) H(YX)=i=1npiH(YX=xi)
这里, p i = P ( X = x i ) , i = 1 , 2 , … , n p_i=P(X=x_i ),i=1,2,…,n pi=P(X=xi),i=1,2,,n
当熵和条件熵中的概率由数据估计得到时,所对应的熵与条件熵分别成为经验熵和经验条件熵,此时,如果有0概率,令 0 l o g 0 = 0 0log0=0 0log0=0
信息增益表示得知特征 X X X的信息而使得类 Y Y Y的信息的不确定性减少的程度。

定义2(信息增益)
特征 A A A对训练数据集 D D D的信息增益 g ( D , A ) g(D,A) g(D,A),定义为集合 D D D的经验熵 H ( D ) H(D) H(D)与特征 A A A给定条件下 D D D的经验条件熵 H ( D ∣ A ) H(D|A) H(DA)之差,即: g ( D , A ) = H ( D ) − H ( D │ A ) g(D,A)=H(D)-H(D│A) g(D,A)=H(D)H(DA)

一般的,熵 H ( Y ) H(Y) H(Y)与条件熵 H ( Y ∣ X ) H(Y|X) H(YX)之差称为互信息。决策树学习中的信息增益等价于训练数据集中类与特征的互信息。
信息增益大的特征具有更强的分类能力。

设训练数据集为 D D D ∣ D ∣ |D| D表示其样本容量,即样本个数。设有 K K K个类 C k , k = 1 , 2 , … , K , ∣ C k ∣ C_k,k=1,2,…,K,|C_k | Ckk=1,2,,KCk为属于类 C k C_k Ck的样本个数, ∑ k = 1 K ∣ C k ∣ = ∣ D ∣ ∑_{k=1}^K|C_k | =|D| k=1KCk=D。设特征 A A A n n n个不同的取值 a 1 , a 2 , … , a n {a_1,a_2,…,a_n } a1,a2,,an,根据特征 A A A的取值将 D D D划分为 n n n个子集 D 1 , D 2 , … , D n , ∣ D i ∣ D_1,D_2,…,D_n,|D_i | D1,D2,,DnDi D i D_i Di的样本个数, ∑ i = 1 n ∣ D i ∣ = ∣ D ∣ ∑_{i=1}^{n}|D_i |=|D| i=1nDi=D。记子集 D i D_i Di中属于类 C k C_k Ck的样本的集合为 D i k D_{ik} Dik,即 D i k = D i ∩ C k , ∣ D i k ∣ D_{ik}=D_i∩C_k,|D_{ik} | Dik=DiCkDik D i k D_{ik} Dik的样本个数。
在这里插入图片描述

# entropy.py
from math import log
# 熵
def calc_ent(datasets):
    data_length = len(datasets)
    label_count = {}
    for i in range(data_length):
        label = datasets[i][-1]
        if label not in label_count:
            label_count[label] = 0
        label_count[label] += 1
ent = -sum([(p / data_length) * log(p / data_length, 2) 
for p in label_count.values()])
    return ent

# 经验条件熵
def cond_ent(datasets, axis=0):
    data_length = len(datasets)
    feature_sets = {}
    for i in range(data_length):
        feature = datasets[i][axis]
        if feature not in feature_sets:
            feature_sets[feature] = []
        feature_sets[feature].append(datasets[i])
    cond_ent = sum([(len(p) / data_length) * calc_ent(p) for p in feature_sets.values()])
    return cond_ent

# 信息增益
def info_gain(ent, cond_ent):
    return ent - cond_ent

def info_gain_train(datasets, labels):
    count = len(datasets[0]) - 1
    ent = calc_ent(datasets)
    best_feature = []
    for c in range(count):
        c_info_gain = info_gain(ent, cond_ent(datasets, axis=c))
        best_feature.append((c, c_info_gain))
        print('特征({})-info_gain = {:.3f}'.format(labels[c], c_info_gain))
    # 比较大小
    best_ = max(best_feature, key=lambda x: x[-1])
    print("----------信息增益熵------------")
    print('特征({})的信息增益最大,选择为根节点特征'.format(labels[best_[0]]))
    return best_

信息增益准则对可取值数目较多的属性有所偏好,为减少这种偏好可能带来的不利影响,使用信息增益比来选择划分最优属性。

定义3(信息增益比) 特征 A A A对训练数据集 D D D的信息增益比 g R ( D , A ) g_R (D,A) gR(D,A)定义为其信息增益 g ( D , A ) g(D,A) g(D,A)与训练数据集 D D D关于特征 A A A的值的熵 H A ( D ) H_A (D) HA(D)之比,即 g R ( D , A ) = g ( D , A ) / ( H A ( D ) g_R (D,A)=g(D,A)/(H_A (D) gR(D,A)=g(D,A)/(HA(D) 其中, H A ( D ) = − ∑ i = 1 n ∣ D i ∣ ∣ D ∣ l o g 2 ⁡ ∣ D i ∣ ∣ D ∣ H_A (D)=-∑_{i=1}^{n}\frac{ |D_i |}{|D|}log_2⁡\frac{ |D_i |}{|D|} HA(D)=i=1nDDilog2DDi n n n是特征 A A A取值的个数。

def HA(datasets, axis=0):
    data_length = len(datasets)
    feature_sets = {}
    for i in range(data_length):
        feature = datasets[i][axis]
        if feature not in feature_sets:
            feature_sets[feature] = []
        feature_sets[feature].append(datasets[i])
    h_a = -sum([(len(p) / data_length) * log(len(p) / data_length, 2) for p in feature_sets.values()])
    return h_a

def info_gain_train(datasets, labels):
    count = len(datasets[0]) - 1
    ent = calc_ent(datasets)
    best_feature = []
    for c in range(count):
        c_info_gain = info_gain(ent, cond_ent(datasets, axis=c)) / HA(datasets, axis=c)  # 信息增益比
        best_feature.append((c, c_info_gain))
    # 比较大小
    best_ = max(best_feature, key=lambda x: x[-1])
    print("----------信息增益比------------")
    print('特征({})的信息增益比最大,选择为根节点特征'.format(labels[best_[0]]))
    return best_

3. 决策树的生成

3.1 ID3算法

ID3算法的核心是在决策树各个节点上应用信息增益准则选择特征,递归地构建决策树。具体方法是:从根节点开始,对结点计算所有可能的特征的信息增益,选择信息增益最大的特征作为结点的特征,由该特征的不同取值建立子节点;再对子节点递归地调用以上方法,构建决策树;直到所有特征的信息增益军很小或没有特征可以选择为止。
在这里插入图片描述

# ID3.py
# 树节点
import entropy
import numpy as np

class Node:
    def __init__(self, root=True, label=None, feature_name=None, feature=None):
        self.root = root
        self.label = label
        self.feature_name = feature_name
        self.feature = feature
        self.tree = {}
        self.result = {
            'label': self.label,
            'feature': self.feature,
            'tree': self.tree
        }

    def __repr__(self):
        # 显示属性 更直观点:输出的格式
        return '{}'.format(self.result)

    def add_node(self, val, node):
        self.tree[val] = node

    def predict(self, features):
        if self.root is True:
            return self.label
        return self.tree[features[self.feature]].predict(features)

class DTree:
    def __init__(self, epsilon=0.1):
        self.epsilon = epsilon
        self._tree = {}

    def train(self, train_data, labels):
        """
        Input: 数据集D(DataFrame格式),特征集A,阈值eta
        Output: 决策树T
        """
        # _ 代表出去类别外的所有特征,y_train 代表类别; feature 代表特征标签
        _, y_train, features = train_data.iloc[:, :-1], train_data.iloc[:, -1], train_data.columns[:-1]

        # 1. 若D中实例属于同一类Ck,则T为单节点树,并将类Ck作为节点的类标记,返回T
        # value_counts() 查看表格某列中有多少个不同值的快捷方法
        if len(y_train.value_counts()) == 1:
            return Node(root=True, label=y_train.iloc[0])

        # 2. 若A为空,则T为单节点树,将D中实例树最大的类Ck作为该节点的类标记,返回T
        if len(features) == 0:
            return Node(root=True, label=y_train.value_counts().sort_values(ascending=False).index[0])

        # 3. 计算最大信息增益,同5.1,Ag为信息增益的最大特征
        max_feature, max_info_gain = entropy.info_gain_train(np.array(train_data), labels)
        max_feature_name = features[max_feature]

        # 4. Ag的信息增益小于阈值eta,则置节点T为单节点树,并将D中是实例数最大的类Ck作为该结点的类标记,返回T
        if max_info_gain < self.epsilon:
            return Node(root=True, label=y_train.value_counts().sort_values(ascending=False).index[0])

        # 构建Ag子集
        node_tree = Node(root=False, feature_name=max_feature_name, feature=max_feature)
        feature_list = train_data[max_feature_name].value_counts().index
        for f in feature_list:
            sub_train_df = train_data.loc[train_data[max_feature_name] == f].drop([max_feature_name], axis=1)
            # 生成递归树
            sub_tree = self.train(sub_train_df, labels)
            node_tree.add_node(f, sub_tree)

        return node_tree

    def fit(self, train_data, labels):
        self._tree = self.train(train_data, labels)
        return self._tree

    def predict(self, x_test):
        return self._tree.predict(x_test)

3.2 C4.5算法

使用信息增益比代替信息增益
在这里插入图片描述

4 CART算法

分类与回归树(CART)模型是应用广泛的决策树学习方法。CART同样由特征选择、树的生成及剪枝组成,即可分类也可回归。

4.1 CART生成

4.1.1 回归树的生成

在这里插入图片描述

4.1.2 分类树的生成

分类树用基尼指数选择最优特征,同时决定该特征的最优二值切分点。

定义4(基尼指数) 分类问题中,假设有K个类,样本点属于第k类的概率为p_k,则概率分布的基尼指数定义为
G i n i ( p ) = ∑ k = 1 K p k ( 1 − p k ) = 1 − ∑ k = 1 K p k 2 Gini(p)=∑_{k=1}^Kp_k (1-p_k )=1-∑_{k=1}^Kp_k^2 Gini(p)=k=1Kpk(1pk)=1k=1Kpk2

对于二分类问题,若样本点属于第1各类的概率为p,则概率分布的基尼指数定义为
G i n i ( p ) = 2 p ( 1 − p ) Gini(p)=2p(1-p) Gini(p)=2p(1p)
对于给定的样本集合D,其基尼指数为
G i n i ( D ) = 1 − ∑ k = 1 K ( ∣ C k ∣ / ∣ D ∣ ) 2 Gini(D)=1-∑_{k=1}^K(|C_k |/|D| )^2 Gini(D)=1k=1K(Ck/D)2
这里, C k C_k Ck D D D中属于第 k k k类的样本子集, K K K是类的个数。
基尼指数 G i n i ( D ) Gini(D) Gini(D)表示集合 D D D的不确定性,基尼指数 G i n i ( D , A ) Gini(D,A) Gini(D,A)表示经 A = a A=a A=a分割后集合 D D D的不确定性。基尼指数值越大,样本集合的不确定性也就越大,这一点与熵相似。
G i n i ( D , A ) = ∣ D 1 ∣ / ∣ D ∣ G i n i ( D 1 ) + ∣ D 2 ∣ / ∣ D ∣ G i n i ( D 2 ) Gini(D,A)=|D_1 |/|D| Gini(D_1 )+|D_2 |/|D| Gini(D_2 ) Gini(D,A)=D1/DGini(D1)+D2/DGini(D2)
在这里插入图片描述

5 代码实现

5.1 Main.py

# CART决策树,使用基尼指数(Gini index)来选择划分属性
import Ex2
import CART

train_all_data, test_all_data, title = Ex2.daikuan()

''' 处理数据 '''
train_data, train_label = Ex2.classify_data_1(train_all_data, title)
test_data, test_label = Ex2.classify_data_1(test_all_data, title)

''' 训练 '''
decision_tree = CART.cart_tree(train_data, title, train_label)
print('训练的决策树是:')
CART.print_tree(decision_tree)
print('\n')

''' 预测 '''
answer = []
for x in test_data:
    answer.append(CART.predict(decision_tree, x))
print('决策树在测试数据集上的分类结果是:', answer)
print('测试数据集的正确类别信息应该是:', test_label)

''' 准确率 '''
accuracy = 0
for i in range(0, len(test_label)):
    if test_label[i] == answer[i]:
        accuracy += 1
accuracy /= len(test_label)
print('决策树在测试数据集上的分类正确率为:' + str(accuracy * 100) + '%')

5.2 Gini.py

def gini(labels=[]):
    """
    计算数据集的基尼值
    :param labels: 数据集的类别标签
    :return:
    """
    data_count = {}
    for label in labels:
        if data_count.__contains__(label):
            data_count[label] += 1
        else:
            data_count[label] = 1

    n = len(labels)
    if n == 0:
        return 0

    gini_value = 1
    for key, value in data_count.items():
        gini_value = gini_value - (value / n) * (value / n)

    return gini_value

def gini_index_basic(n, attr_labels={}):
    gini_value = 0
    for attribute, labels in attr_labels.items():
        gini_value = gini_value + len(labels) / n * gini(labels)

    return gini_value

def gini_index(attributes=[], labels=[], is_value=False):
    """
    计算一个属性的基尼指数
    :param attributes: 当前数据在该属性上的属性值列表
    :param labels:
    :param is_value:
    :return:
    """
    n = len(labels)
    attr_labels = {}
    gini_value = 0  # 最终要返回的结果
    split = None

    if is_value:  # 属性值是连续的数值
        sorted_attributes = attributes.copy()
        sorted_attributes.sort()
        split_points = []
        for i in range(0, n - 1):
            split_points.append((sorted_attributes[i + 1] + sorted_attributes[i]) / 2)

        split_evaluation = []
        for current_split in split_points:
            low_labels = []
            up_labels = []
            for i in range(0, n):
                if attributes[i] <= current_split:
                    low_labels.append(labels[i])
                else:
                    up_labels.append(labels[i])
            attr_labels = {'small': low_labels, 'large': up_labels}
            split_evaluation.append(gini_index_basic(n, attr_labels=attr_labels))
        gini_value = min(split_evaluation)
        split = split_points[split_evaluation.index(gini_value)]
    else:  # 属性值是离散的词
        for i in range(0, n):
            if attr_labels.__contains__(attributes[i]):
                temp_list = attr_labels[attributes[i]]
                temp_list.append(labels[i])
                attr_labels[attributes[i]] = temp_list
            else:
                temp_list = []
                temp_list.append(labels[i])
                attr_labels[attributes[i]] = temp_list
        gini_value = gini_index_basic(n, attr_labels=attr_labels)
    return gini_value, split

5.3 TreeNode.py

class TreeNode:
    """
    决策树结点类
    """
    current_index = 0

    def __init__(self, parent=None, attr_name=None, children=None, judge=None, split=None, data_index=None,
                 attr_value=None, rest_attribute=None):
        '''
        决策树节点类初始化方法
        :param parent: 父节点
        '''
        self.parent = parent  # 父节点,根节点的父节点为None
        self.attribute_name = attr_name  # 本节点上进行划分的属性名
        self.attribute_value = attr_value  # 本节点上划分属性的值,是与父节点的划分属性名相对应的
        self.children = children  # 孩子节点列表
        self.judge = judge  # 如果是叶子节点,需要给出判断
        self.split = split  # 如果是使用连续属性进行划分,需给出分割点
        self.data_index = data_index  # 对应训练数据集的训练索引号
        self.index = TreeNode.current_index  # 当前结点的索引号,方便输出时查看
        self.rest_attribute = rest_attribute  # 尚未使用的属性列表
        TreeNode.current_index += 1

    def to_string(self):
        '''
        用一个字符串来描述当前节点信息
        '''
        this_string = 'current index:' + str(self.index) + ';\n'
        if not (self.parent is None):
            parent_node = self.parent
            this_string = this_string + 'parent index:' + str(parent_node.index) + ';\n'
            this_string = this_string + str(parent_node.attribute_name) + ': ' + str(self.attribute_value) + ';\n'
        this_string = this_string + 'data: ' + str(self.data_index) + ';\n'
        if not (self.children is None):
            this_string = this_string + 'select attribute is: ' + str(self.attribute_name) + ';\n'
            child_list = []
            for child in self.children:
                child_list.append(child.index)
            this_string = this_string + 'children: ' + str(child_list)
        if not (self.judge is None):
            this_string = this_string + 'label: ' + self.judge
        return this_string

5.4 CART.py

# CART决策树,使用基尼指数(Gini index)来选择划分属性
import TreeNode
import Gini

def cart_tree(Data, title, label):
    """
    生成一棵 CART 决策树
    :param Data: 数据集,每个样本是一个 dict(属性名:属性值),整个Data是个大的list
    :param title: 每个属性的名字,如:色泽、含糖率等
    :param label: 存储的是每个样本的类别
    :return:
    """
    n = len(Data)
    rest_title = title.copy()
    root_data = [i for i in range(0, n)]
    root_node = TreeNode.TreeNode(data_index=root_data, rest_attribute=rest_title)
    finish_node(root_node, Data, label)

    return root_node

def is_number(s):
    """
    判断一个字符串是否为数字,如果为数字我们认为这个属性的值是连续的
    """
    try:
        float(s)
        return True
    except ValueError:
        pass
    return False

def finish_node(current_node=TreeNode.TreeNode(), data=[], label=[]):
    """
    完成一个节点上的计算
    :param current_node: 当前计算的节点
    :param data: 数据集
    :param label: 数据集的label
    :return:
    """
    n = len(label)

    # 判断当前节点中的数据是否属于同一类
    one_class = True
    this_data_index = current_node.data_index

    for i in this_data_index:
        for j in this_data_index:
            if label[i] != label[j]:
                one_class = False
                break
        if not one_class:
            break
    if one_class:
        current_node.judge = label[this_data_index[0]]
        return

    rest_title = current_node.rest_attribute  # 侯选属性
    if len(rest_title) == 0:
        # 如果候选属性为空,则是个叶子结点。需要选最多的那个类作为该节点的类
        label_count = {}
        temp_data = current_node.data_index
        for index in temp_data:
            if label in temp_data:
                if label_count.__contains__(label[index]):
                    label_count[label[index]] += 1
                else:
                    label_count[label[index]] = 1
        final_label = max(label_count)
        current_node.judge = final_label
        return

    title_gini = {}  # 记录每个属性的基尼指数
    title_spilt_value = {}  # 记录每个属性的分隔值,如果是连续属性则为分隔值,如果是离散属性则为None
    for title in rest_title:
        attr_values = []
        current_label = []
        for index in current_node.data_index:
            this_data = data[index]
            attr_values.append(this_data[title])
            current_label.append(label[index])
        temp_data = data[0]
        this_gain, this_split_value = Gini.gini_index(attr_values, current_label,
                                                 is_number(temp_data[title]))  # 如果属性值为数字,则认为是连续的
        title_gini[title] = this_gain
        title_spilt_value[title] = this_split_value

    best_attr = min(title_gini, key=title_gini.get)  # 基尼指数最小的属性名
    current_node.attribute_name = best_attr
    current_node.split = title_spilt_value[best_attr]
    rest_title.remove(best_attr)

    a_data = data[0]
    if is_number(a_data[best_attr]):  # 如果是该属性的值为连续值
        split_value = title_spilt_value[best_attr]
        small_data = []
        large_data = []
        for index in current_node.data_index:
            this_data = data[index]
            if this_data[best_attr] <= split_value:
                small_data.append(index)
            else:
                large_data.append(index)
        small_str = ' <= ' + str(split_value)
        large_str = ' > ' + str(split_value)
        small_child = TreeNode.TreeNode(parent=current_node, data_index=small_data, attr_value=small_str,
                                        rest_attribute=rest_title.copy())
        large_child = TreeNode.TreeNode(parent=current_node, data_index=large_data, attr_value=large_str,
                                        rest_attribute=rest_title.copy())
        current_node.children = [small_child, large_child]

    else:  # 如果属性的值是离散的
        best_titlevalue_dict = {}  # key是属性的取值,value是个list记录所包含的样本序号
        for index in current_node.data_index:
            this_data = data[index]
            if best_titlevalue_dict.__contains__(this_data[best_attr]):
                temp_list = best_titlevalue_dict[this_data[best_attr]]
                temp_list.append(index)
            else:
                temp_list = [index]
                best_titlevalue_dict[this_data[best_attr]] = temp_list

        children_list = []
        for key, index_list in best_titlevalue_dict.items():
            a_child = TreeNode.TreeNode(parent=current_node, data_index=index_list, attr_value=key,
                                        rest_attribute=rest_title.copy())
            children_list.append(a_child)
        current_node.children = children_list

    # print(current_node,to_string())
    for child in current_node.children:  # 递归
        finish_node(child, data, label)




def print_tree(root=TreeNode.TreeNode()):
    """
    打印输出一棵树
    :param root:  根节点
    :return:
    """
    node_list = [root]
    while (len(node_list) > 0):
        current_node = node_list[0]
        print('------------------------------------------------')
        print(current_node.to_string())
        print('------------------------------------------------')
        children_list = current_node.children
        if not (children_list is None):
            for child in children_list:
                node_list.append(child)
        node_list.remove(current_node)

def predict(decision_tree=TreeNode.TreeNode(), x={}):
    """
    使用决策树判断一个样本数据的类别标签
    :param decision_tree: 训练好的决策树的根节点
    :param x: 要进行判断的样本
    :return:
    """
    current_node = decision_tree
    while current_node.judge is None:  # 是否为叶子节点
        if current_node.split is None:  # 离散属性
            can_judge = False  # 如果训练数据集不够大,测试数据集中可能会有在训练数据集中没有出现过的属性值
            for child in current_node.children:
                if child.attribute_value == x[current_node.attribute_name]:
                    current_node = child
                    can_judge = True
                    break
            if not can_judge:
                return None
        else:
            child_list = current_node.children
            if x[current_node.attribute_name] <= current_node.split:
                current_node = child_list[0]
            else:
                current_node = child_list[1]
    return current_node.judge

6 决策树的剪枝

在决策树学习中将以生成的树进行简化的过程称为剪枝。具体地,剪枝从已生成的树上裁掉一些子树或叶节点,并将其根节点或父节点作为新的叶节点,从而简化分类树的模型。避免形成“过拟合”的现象。
决策树的剪枝往往通过极小化决策树整体的损失函数或代价函数来实现。设树 T T T的叶结点个数为 ∣ T ∣ |T| T t t t是树 T T T的叶结点,该叶节点有 N t N_t Nt个样本点,其中 k k k类样本点有 N t k N_{tk} Ntk个, k = 1 , 2 , … , K , H t ( T ) k=1,2,…,K,H_t (T) k=1,2,,KHt(T)为叶节点 t t t上的经验熵, α ≥ 0 α≥0 α0为参数,则决策树学习的损失函数可以定义为:
C α ( T ) = ∑ t = 1 ∣ T ∣ N t H t ( T ) + α ∣ T ∣ C_α (T)=∑_{t=1}^{|T|}N_t H_t (T)+α|T| Cα(T)=t=1TNtHt(T)+αT
其中经验熵为:
H t ( T ) = − ∑ k ⁡ N t k N t l o g ⁡ N t k N t H_t (T)=-∑_k\frac{⁡N_{tk}}{N_t} log\frac{⁡N_{tk}}{N_t} Ht(T)=kNtNtklogNtNtk
C ( T ) = ∑ t = 1 ∣ T ∣ N t H t ( T ) = − ∑ t = 1 ∣ T ∣ ∑ k = 1 K N t k l o g ⁡ N t k N t C(T)=∑_{t=1}^{|T|}N_t H_t (T)=-∑_{t=1}^{|T|}∑_{k=1}^KN_{tk} log \frac{⁡N_{tk}}{N_t} C(T)=t=1TNtHt(T)=t=1Tk=1KNtklogNtNtk,有:
C α ( T ) = C ( T ) + α ∣ T ∣ C_α (T)=C(T)+α|T| Cα(T)=C(T)+αT
其中, C ( T ) C(T) C(T)表示模型队训练数据的预测误差,及模型与训练数据的拟合程度, ∣ T ∣ |T| T表示模型复杂度。
在这里插入图片描述
决策树的剪枝又分为预剪枝和后剪枝。

6.1 预剪枝

预剪枝是指在决策树生成的过程中,对每个结点在画分前先进行估计,若当前节点的划分不能带来决策树泛化性能提升,则停止划分并将当前结点标记为叶节点。

def finish_node(current_node=TreeNode.TreeNode(), data=[], label=[], test_data=[], test_label=[], flag=0, ):
    """
    完成一个节点上的计算
    :param current_node: 当前计算的节点
    :param data: 数据集
    :param label: 数据集的label
    :param flag: "0"--不进行剪枝  “1”--进行预剪枝    默认不剪枝
    :return:
    """
    n = len(label)

    # 判断当前节点中的数据是否属于同一类
    one_class = True
    this_data_index = current_node.data_index

    for i in this_data_index:
        for j in this_data_index:
            if label[i] != label[j]:
                one_class = False
                break
        if not one_class:
            break
    if one_class:
        current_node.judge = label[this_data_index[0]]
        return

    rest_title = current_node.rest_attribute  # 侯选属性
    if len(rest_title) == 0:
        # 如果候选属性为空,则是个叶子结点。需要选最多的那个类作为该节点的类
        label_count = {}
        temp_data = current_node.data_index
        for index in temp_data:
            if label in temp_data:
                if label_count.__contains__(label[index]):
                    label_count[label[index]] += 1
                else:
                    label_count[label[index]] = 1
        final_label = max(label_count)
        current_node.judge = final_label
        return

    # 预剪枝
    if flag == 1:
        data_count = {}
        for index in current_node.data_index:
            if data_count.__contains__(label[index]):
                data_count[label[index]] += 1
            else:
                data_count[label[index]] = 1
        before_judge = max(data_count, key=data_count.get)
        current_node.judge = before_judge
        before_accuracy = Predictions_results.current_accuracy_1(current_node, test_data, test_label)

    title_gini = {}  # 记录每个属性的基尼指数
    title_spilt_value = {}  # 记录每个属性的分隔值,如果是连续属性则为分隔值,如果是离散属性则为None
    for title in rest_title:
        attr_values = []
        current_label = []
        for index in current_node.data_index:
            this_data = data[index]
            attr_values.append(this_data[title])
            current_label.append(label[index])
        temp_data = data[0]
        this_gain, this_split_value = Gini.gini_index(attr_values, current_label,
                                                      is_number(temp_data[title]))  # 如果属性值为数字,则认为是连续的
        title_gini[title] = this_gain
        title_spilt_value[title] = this_split_value

    best_attr = min(title_gini, key=title_gini.get)  # 基尼指数最小的属性名
    current_node.attribute_name = best_attr
    current_node.split = title_spilt_value[best_attr]
    rest_title.remove(best_attr)

    a_data = data[0]
    if is_number(a_data[best_attr]):  # 如果是该属性的值为连续值
        split_value = title_spilt_value[best_attr]
        small_data = []
        large_data = []
        for index in current_node.data_index:
            this_data = data[index]
            if this_data[best_attr] <= split_value:
                small_data.append(index)
            else:
                large_data.append(index)
        small_str = ' <= ' + str(split_value)
        large_str = ' > ' + str(split_value)
        small_child = TreeNode.TreeNode(parent=current_node, data_index=small_data, attr_value=small_str,
                                        rest_attribute=rest_title.copy())
        large_child = TreeNode.TreeNode(parent=current_node, data_index=large_data, attr_value=large_str,
                                        rest_attribute=rest_title.copy())

        # 预剪枝
        if flag == 1:
            small_data_count = {}
            for index in small_child.data_index:
                if small_data_count.__contains__(label[index]):
                    small_data_count[label[index]] += 1
                else:
                    small_data_count[label[index]] = 1
            small_child_judge = max(small_data_count, key=small_data_count.get)
            small_child.judge = small_child_judge  # 临时添加的一个判断
            large_data_count = {}
            for index in large_child.data_index:
                if large_data_count.__contains__(label[index]):
                    large_data_count[label[index]] += 1
                else:
                    large_data_count[label[index]] = 1
            large_child_judge = max(large_data_count, key=large_data_count.get)
            large_child.judge = large_child_judge  # 临时添加的一个判断

        current_node.children = [small_child, large_child]

    else:  # 如果属性的值是离散的
        best_titlevalue_dict = {}  # key是属性的取值,value是个list记录所包含的样本序号
        for index in current_node.data_index:
            this_data = data[index]
            if best_titlevalue_dict.__contains__(this_data[best_attr]):
                temp_list = best_titlevalue_dict[this_data[best_attr]]
                temp_list.append(index)
            else:
                temp_list = [index]
                best_titlevalue_dict[this_data[best_attr]] = temp_list

        children_list = []
        for key, index_list in best_titlevalue_dict.items():
            a_child = TreeNode.TreeNode(parent=current_node, data_index=index_list, attr_value=key,
                                        rest_attribute=rest_title.copy())

            if flag == 0:
                children_list.append(a_child)
            elif flag == 1:  # 预剪枝
                temp_data_count = {}
                for index in index_list:
                    if temp_data_count.__contains__(label[index]):
                        temp_data_count[label[index]] += 1
                    else:
                        temp_data_count[label[index]] = 1
                temp_child_judge = max(temp_data_count, key=temp_data_count.get)
                a_child.judge = temp_child_judge
                children_list.append(a_child)
        current_node.children = children_list

    if flag == 0:
        for child in current_node.children:  # 递归
            finish_node(child, data, label)
    elif flag == 1:
        current_node.judge = None
        later_accuracy = Predictions_results.current_accuracy_1(current_node, test_data, test_label)
        print(str(current_node.index)+"处,不剪枝的正确率是 "+str(later_accuracy) +",剪枝的正确率是 "+str(before_accuracy))
        if before_accuracy > later_accuracy:
            current_node.children = None
            current_node.judge = before_judge
            # print(str(current_node.index)+"处进行剪枝")
            return
        else:
            # print(current_node.to_string())
            for child in current_node.children:  # 递归
                finish_node(child, data, label, test_data, test_label)

6.2 后剪枝

后剪枝则是先从训练集生成一棵完成的决策树,然后自底向上地对非叶结点进行考察,若将该节点对应的子树替换为叶结点能带来决策树泛化性能提升,则将该子树替换为叶结点。

def post_pruning(decision_tree=TreeNode.TreeNode(), test_data=[], test_label=[], train_label=[]):
    """
    对决策树进行后剪枝
    :param decision_tree: 决策树根节点
    :param test_data: 测试数据集
    :param test_label: 测试数据集的标签
    :param train_label: 训练数据集的标签
    :return:
    """
    leaf_father = [] # 所有的孩子都是叶结点的结点集合

    bianli_list = []
    bianli_list.append(decision_tree)
    while len(bianli_list) > 0:
        current_node = bianli_list[0]
        children = current_node.children
        wanted = True
        if not (children is None):
            for child in children:
                bianli_list.append(child)
                temp_bool = (child.children is None)
                wanted = (wanted and temp_bool)
        else:
            wanted = False

        if wanted:
            leaf_father.append(current_node)
        bianli_list.remove(current_node)

    while len(leaf_father)>0:
        # 如果父结点为空,则剪枝完成。对于不需要进行剪枝操作的叶父节点,我们将其从leaf_father中删去
        current_node = leaf_father.pop()
        # 不进行剪枝在测试集上的正确率
        before_accuracy = Predictions_results.current_accuracy_1(root_node=decision_tree, test_data=test_data, test_label=test_label)

        data_index = current_node.data_index
        label_count = {}
        for index in data_index:
            if label_count.__contains__(index):
                label_count[train_label[index]] += 1
            else:
                label_count[train_label[index]] = 1
        current_node.judge = max(label_count, key=label_count.get)  # 如果进行剪枝当前结点应该做出的判断
        later_accuracy = Predictions_results.current_accuracy_1(root_node=decision_tree, test_data=test_data, test_label=test_label)

        if before_accuracy > later_accuracy:  # 不进行剪枝
            current_node.judge = None
        else:  # 进行剪枝
            current_node.children = None
            # 还需要检查是否需要对它的父节点进行判断
            parent_node = current_node.parent
            if not (parent_node is None):
                children_list = parent_node.children
                temp_bool = True
                for child in children_list:
                    if not (child.children is None):
                        temp_bool = False
                        break
                if temp_bool:
                    leaf_father.append(parent_node)
    return decision_tree

注:代码有参考大佬:麦克斯韦的妖精的一部分,原链接:https://blog.youkuaiyun.com/john_bian/article/details/100586245
稍后会将代码打包上传,稍等ing

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值