决策树的 Python 代码实现

上一篇博客我们介绍了决策树的基本原理,今天我们来动手实现一个简单的决策树分类器。本文将基于 ID3 算法,使用 Python 手动实现决策树,并通过实例展示其用法。

决策树实现的核心步骤

实现一个决策树主要包括以下几个核心部分:

  1. 信息熵计算:衡量数据集的不确定性
  2. 信息增益计算:选择最佳分裂特征的依据
  3. 递归构建树:通过不断分裂节点构建完整的决策树
  4. 预测函数:使用构建好的树进行新样本预测

手动代码解析

1. 信息熵计算

信息熵是衡量数据集纯度的指标,熵越小表示数据集越纯:

def entropy(self, y):
    """计算信息熵"""
    counts = Counter(y)
    probabilities = [count / len(y) for count in counts.values()]
    return -sum(p * np.log2(p) for p in probabilities if p > 0)

2. 信息增益计算

信息增益表示使用某个特征分裂后,数据集不确定性减少的程度:

def information_gain(self, X, y, feature_idx):
    """计算信息增益"""
    original_entropy = self.entropy(y)
    feature_values = X[:, feature_idx]
    unique_values = np.unique(feature_values)
    
    conditional_entropy = 0
    for value in unique_values:
        mask = (feature_values == value)
        y_subset = y[mask]
        conditional_entropy += (len(y_subset) / len(y)) * self.entropy(y_subset)
    
    return original_entropy - conditional_entropy

3. 构建决策树

构建过程采用递归方式,主要步骤是:

  • 检查停止条件(所有样本同类别或达到最大深度)
  • 选择最佳分裂特征
  • 按照特征值分裂数据集
  • 递归构建子树
def build_tree(self, X, y, depth=0):
    # 停止条件:所有样本属于同一类别
    if len(np.unique(y)) == 1:
        return {'leaf': True, 'class': y[0]}
    
    # 停止条件:达到最大深度
    if self.max_depth is not None and depth >= self.max_depth:
        most_common = Counter(y).most_common(1)[0][0]
        return {'leaf': True, 'class': most_common}
    
    # 选择最佳分裂特征
    best_feature = self.best_feature_to_split(X, y)
    # 构建子树(代码省略)

4. 预测功能

预测时从根节点开始,根据样本的特征值逐层向下,直到到达叶节点得到预测结果:

def predict_sample(self, sample, tree):
    if tree['leaf']:
        return tree['class']
    
    feature_value = sample[tree['feature']]
    return self.predict_sample(np.delete(sample, tree['feature']), 
                              tree['children'][feature_value])

示例应用

在代码的if __name__ == "__main__":部分,我们创建了一个简单的数据集来判断生物是否为鱼类,使用了三个特征:有鳍、有鳃、水生。

运行代码后,决策树会学习这些特征与 "是否为鱼类" 之间的关系,并对新样本进行预测。

进一步改进方向

这个实现很基础,实际应用中还可以进行以下改进:

  1. 增加剪枝功能,进一步防止过拟合
  2. 支持连续值特征
  3. 处理缺失值
  4. 实现 C4.5 或 CART 算法

使用 scikit-learn 库实现决策树

除了手动实现决策树,我们还可以使用 Python 的 scikit-learn 库来快速构建和使用决策树模型。scikit-learn 提供了高效、稳定的决策树实现,支持分类和回归任务,并且包含了许多实用的功能。

scikit-learn 决策树实现示例

下面我们使用 scikit-learn 库来实现一个决策树分类器,并与我们手动实现的版本进行对比:

使用scikit-learn实现决策树

import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_text
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report

# 1. 准备数据
# 我们使用与手动实现相同的数据集:判断是否为鱼类
# 特征:[有鳍, 有鳃, 水生],1表示是,0表示否
# 标签:1表示是鱼类,0表示不是
data = np.array([
    [1, 1, 1, 1],
    [1, 1, 1, 1],
    [1, 1, 0, 0],
    [0, 1, 1, 0],
    [0, 1, 1, 0],
    [0, 0, 1, 0],
    [1, 0, 1, 0],
    [1, 0, 0, 0]
])

X = data[:, :-1]  # 特征
y = data[:, -1]   # 标签

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# 2. 创建并训练决策树模型
# 使用ID3算法的思路( criterion='entropy')
clf = DecisionTreeClassifier(criterion='entropy', max_depth=3, random_state=42)
clf.fit(X_train, y_train)

# 3. 模型评估
y_pred = clf.predict(X_test)
print("测试集预测结果:", y_pred)
print("测试集真实标签:", y_test)
print(f"准确率: {accuracy_score(y_test, y_pred):.2f}")
print("\n分类报告:")
print(classification_report(y_test, y_pred))

# 4. 可视化决策树结构
feature_names = ['有鳍', '有鳃', '水生']
tree_rules = export_text(clf, feature_names=feature_names)
print("\n决策树规则:")
print(tree_rules)

# 5. 使用模型进行预测
test_samples = np.array([
    [1, 1, 1],  # 应该是鱼类
    [0, 1, 0],  # 应该不是鱼类
    [1, 0, 1]   # 应该不是鱼类
])

predictions = clf.predict(test_samples)
print("\n新样本预测结果:", predictions)
print("预测结果解释:", ["是鱼类" if p == 1 else "不是鱼类" for p in predictions])
    

scikit-learn 决策树代码解析

1. 主要参数说明

DecisionTreeClassifier类的关键参数:

  • criterion:特征选择标准,'entropy' 表示使用信息熵(ID3 算法思路),'gini' 表示使用基尼系数(CART 算法)
  • max_depth:树的最大深度,用于防止过拟合
  • min_samples_split:分裂内部节点所需的最小样本数
  • min_samples_leaf:叶节点所需的最小样本数
  • random_state:随机数种子,保证结果可重现

2. 模型训练与评估

scikit-learn 的 API 设计非常一致,主要步骤为:

  1. 创建模型对象(DecisionTreeClassifier
  2. 使用fit()方法训练模型
  3. 使用predict()方法进行预测
  4. 使用评估指标(如准确率、分类报告)评估模型性能

3. 决策树可视化

export_text()函数可以将决策树以文本形式展示,便于我们理解模型的决策过程。对于更复杂的树,还可以使用plot_tree()函数进行图形化展示(需要 matplotlib 支持)。

与手动实现的对比

scikit-learn 实现的决策树具有以下优势:

  1. 功能完善:支持更多参数调整,如剪枝、处理缺失值等
  2. 效率更高:经过优化的 C 语言底层实现,处理大数据集速度更快
  3. 鲁棒性更好:对异常值和噪声数据有更好的处理能力
  4. 扩展性强:可以方便地与 scikit-learn 的其他模块(如交叉验证、网格搜索)结合使用

实际应用建议

在实际项目中,推荐使用 scikit-learn 的决策树实现,主要原因是:

  1. 代码简洁,开发效率高
  2. 经过了充分的测试和优化,稳定性好
  3. 提供了丰富的参数,可以精细调整模型
  4. 与 scikit-learn 的其他工具(如模型选择、预处理)无缝集成
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值