上一篇博客我们介绍了决策树的基本原理,今天我们来动手实现一个简单的决策树分类器。本文将基于 ID3 算法,使用 Python 手动实现决策树,并通过实例展示其用法。
决策树实现的核心步骤
实现一个决策树主要包括以下几个核心部分:
- 信息熵计算:衡量数据集的不确定性
- 信息增益计算:选择最佳分裂特征的依据
- 递归构建树:通过不断分裂节点构建完整的决策树
- 预测函数:使用构建好的树进行新样本预测
手动代码解析
1. 信息熵计算
信息熵是衡量数据集纯度的指标,熵越小表示数据集越纯:
def entropy(self, y):
"""计算信息熵"""
counts = Counter(y)
probabilities = [count / len(y) for count in counts.values()]
return -sum(p * np.log2(p) for p in probabilities if p > 0)
2. 信息增益计算
信息增益表示使用某个特征分裂后,数据集不确定性减少的程度:
def information_gain(self, X, y, feature_idx):
"""计算信息增益"""
original_entropy = self.entropy(y)
feature_values = X[:, feature_idx]
unique_values = np.unique(feature_values)
conditional_entropy = 0
for value in unique_values:
mask = (feature_values == value)
y_subset = y[mask]
conditional_entropy += (len(y_subset) / len(y)) * self.entropy(y_subset)
return original_entropy - conditional_entropy
3. 构建决策树
构建过程采用递归方式,主要步骤是:
- 检查停止条件(所有样本同类别或达到最大深度)
- 选择最佳分裂特征
- 按照特征值分裂数据集
- 递归构建子树
def build_tree(self, X, y, depth=0):
# 停止条件:所有样本属于同一类别
if len(np.unique(y)) == 1:
return {'leaf': True, 'class': y[0]}
# 停止条件:达到最大深度
if self.max_depth is not None and depth >= self.max_depth:
most_common = Counter(y).most_common(1)[0][0]
return {'leaf': True, 'class': most_common}
# 选择最佳分裂特征
best_feature = self.best_feature_to_split(X, y)
# 构建子树(代码省略)
4. 预测功能
预测时从根节点开始,根据样本的特征值逐层向下,直到到达叶节点得到预测结果:
def predict_sample(self, sample, tree):
if tree['leaf']:
return tree['class']
feature_value = sample[tree['feature']]
return self.predict_sample(np.delete(sample, tree['feature']),
tree['children'][feature_value])
示例应用
在代码的if __name__ == "__main__":部分,我们创建了一个简单的数据集来判断生物是否为鱼类,使用了三个特征:有鳍、有鳃、水生。
运行代码后,决策树会学习这些特征与 "是否为鱼类" 之间的关系,并对新样本进行预测。
进一步改进方向
这个实现很基础,实际应用中还可以进行以下改进:
- 增加剪枝功能,进一步防止过拟合
- 支持连续值特征
- 处理缺失值
- 实现 C4.5 或 CART 算法
使用 scikit-learn 库实现决策树
除了手动实现决策树,我们还可以使用 Python 的 scikit-learn 库来快速构建和使用决策树模型。scikit-learn 提供了高效、稳定的决策树实现,支持分类和回归任务,并且包含了许多实用的功能。
scikit-learn 决策树实现示例
下面我们使用 scikit-learn 库来实现一个决策树分类器,并与我们手动实现的版本进行对比:
使用scikit-learn实现决策树
import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_text
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
# 1. 准备数据
# 我们使用与手动实现相同的数据集:判断是否为鱼类
# 特征:[有鳍, 有鳃, 水生],1表示是,0表示否
# 标签:1表示是鱼类,0表示不是
data = np.array([
[1, 1, 1, 1],
[1, 1, 1, 1],
[1, 1, 0, 0],
[0, 1, 1, 0],
[0, 1, 1, 0],
[0, 0, 1, 0],
[1, 0, 1, 0],
[1, 0, 0, 0]
])
X = data[:, :-1] # 特征
y = data[:, -1] # 标签
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42
)
# 2. 创建并训练决策树模型
# 使用ID3算法的思路( criterion='entropy')
clf = DecisionTreeClassifier(criterion='entropy', max_depth=3, random_state=42)
clf.fit(X_train, y_train)
# 3. 模型评估
y_pred = clf.predict(X_test)
print("测试集预测结果:", y_pred)
print("测试集真实标签:", y_test)
print(f"准确率: {accuracy_score(y_test, y_pred):.2f}")
print("\n分类报告:")
print(classification_report(y_test, y_pred))
# 4. 可视化决策树结构
feature_names = ['有鳍', '有鳃', '水生']
tree_rules = export_text(clf, feature_names=feature_names)
print("\n决策树规则:")
print(tree_rules)
# 5. 使用模型进行预测
test_samples = np.array([
[1, 1, 1], # 应该是鱼类
[0, 1, 0], # 应该不是鱼类
[1, 0, 1] # 应该不是鱼类
])
predictions = clf.predict(test_samples)
print("\n新样本预测结果:", predictions)
print("预测结果解释:", ["是鱼类" if p == 1 else "不是鱼类" for p in predictions])
scikit-learn 决策树代码解析
1. 主要参数说明
DecisionTreeClassifier类的关键参数:
criterion:特征选择标准,'entropy' 表示使用信息熵(ID3 算法思路),'gini' 表示使用基尼系数(CART 算法)max_depth:树的最大深度,用于防止过拟合min_samples_split:分裂内部节点所需的最小样本数min_samples_leaf:叶节点所需的最小样本数random_state:随机数种子,保证结果可重现
2. 模型训练与评估
scikit-learn 的 API 设计非常一致,主要步骤为:
- 创建模型对象(
DecisionTreeClassifier) - 使用
fit()方法训练模型 - 使用
predict()方法进行预测 - 使用评估指标(如准确率、分类报告)评估模型性能
3. 决策树可视化
export_text()函数可以将决策树以文本形式展示,便于我们理解模型的决策过程。对于更复杂的树,还可以使用plot_tree()函数进行图形化展示(需要 matplotlib 支持)。
与手动实现的对比
scikit-learn 实现的决策树具有以下优势:
- 功能完善:支持更多参数调整,如剪枝、处理缺失值等
- 效率更高:经过优化的 C 语言底层实现,处理大数据集速度更快
- 鲁棒性更好:对异常值和噪声数据有更好的处理能力
- 扩展性强:可以方便地与 scikit-learn 的其他模块(如交叉验证、网格搜索)结合使用
实际应用建议
在实际项目中,推荐使用 scikit-learn 的决策树实现,主要原因是:
- 代码简洁,开发效率高
- 经过了充分的测试和优化,稳定性好
- 提供了丰富的参数,可以精细调整模型
- 与 scikit-learn 的其他工具(如模型选择、预处理)无缝集成
2973

被折叠的 条评论
为什么被折叠?



