在机器学习中,决策树被认为是最直观且可解释的算法之一。由于其模仿人类决策过程的能力,它成为解决广泛问题的流行选择。让我们深入了解并揭开决策树的神秘面纱。
文章目录
1 介绍
决策树是一种灵活且可解释的监督学习算法,是许多经典机器学习算法(如随机森林、Bagging 和梯度提升决策树)的基础算法,核心思想是将数据表示为一棵树:
-
根节点:包含初始的决策规则,例如“是否运动”
-
内部节点:表示针对某个属性进行判断的条件测试
-
分支:根据条件分裂出后续的决策路径
-
叶节点(终端节点):表示分类标签或预测值
-
分裂:将节点分成多个子节点的过程,基于特定条件或规则
-
剪枝:与分裂相反的过程,通过移除不必要的决策规则简化树的结构,从而降低算法的复杂性并提高泛化能力。
如今,决策树被广泛用于分类和回归任务的预测建模。它们有时也被称为 CART(Classification and Regression Trees
),即分类与回归树。决策树的直观性和灵活性使其成为机器学习中的重要工具。
决策树的类型
根据目标变量的类型,决策树分为两类:
-
分类变量决策树(Categorical Variable Decision Trees)
- 用于分类目标变量。例如,预测计算机价格为“低、中、高”三类。
- 特征可以包括显示器类型、音响质量、RAM 和 SSD。决策树会根据这些特征构建规则,每个数据点通过决策路径到达叶节点后,最终预测为三类之一。
-
连续变量决策树(Continuous Variable Decision Trees)
- 用于回归目标变量。例如,根据房屋特征(如面积、位置等)预测房价。
- 决策树根据输入特征生成一组规则,用于预测连续值。
决策树的剪枝
决策树容易过拟合,尤其是树的深度过大时,会导致对训练数据拟合过度、泛化能力差。剪枝是控制树复杂度的关键技术,包括:
- 前剪枝:在树构建过程中设置限制(如最大深度、最小样本数),提前停止分裂。
- 后剪枝:先生成完整的决策树,再通过移除不重要的分支简化结构,提高模型的泛化能力。
优缺点
决策树的优点
- 易于理解和解释:决策树的规则直观清晰,树状结构可以被可视化。
- 对数据准备需求较低:决策树不需要数据归一化、创建虚拟变量或删除空值(但在一些库中,如scikit-learn不支持处理缺失值)。
- 同时支持数值型和类别型数据:能够灵活处理不同类型的特征。
- 支持多输出问题:适用于多目标预测。
- 白盒模型:结果易于解释,可以直接看到每一步的决策依据。
- 可通过统计检验验证模型:能够评估模型的可靠性。
决策树的缺点
- 容易生成过于复杂的树:这会导致模型无法很好地泛化数据(过拟合)。需要通过剪枝、设置叶节点所需的最小样本数或限制最大树深度来避免该问题。
- 对数据变化敏感:数据中小的变化可能会生成完全不同的决策树。通过在集成方法中使用决策树可以缓解这一问题。
- 对不平衡数据敏感:如果某些类别占主导地位,决策树会生成偏向性的结果。因此,在拟合决策树之前,建议对数据集进行平衡处理。
2 构建决策树
决策树的目标是通过节点之间的最佳分割来优化数据划分,使其能被正确分类。这些决策规则直接影响算法的性能。
构建决策树的假设
- 初始节点:将整个数据视为根节点,之后使用算法进行分割。
- 特征值类型:特征值被认为是分类的。如果是连续值,则需要离散化。
- 记录分布:数据根据属性值递归分布。
- 属性排序:通过统计方法确定根节点或内部节点的属性顺序。
Gini 指数、信息增益 和 卡方检验 是决策树中常用的属性选择标准,用于确定最佳分裂属性,提高决策树的准确性和泛化能力。
2.1 Gini不纯度
Gini不纯度(Gini Impurity
)用于衡量一个节点中随机选取样本被错误分类的概率,它提供了模型分割与理想纯净分割的差异程度。 数学公式如下:
G i n i = 1 − ∑ i = 1 n ( p i ) 2 Gini = 1 - \sum_{i=1}^{n} (p_i)^2 Gini=1−i=1∑n(pi)2
- 其中 p i p_i pi 是某类别元素在节点中的比例。
Gini不纯度的值介于 0 0 0 和 1 1 1 之间:
- 0表示所有元素属于同一个类别,分割纯净
- 0.5表示元素平均分布到不同类别中
- 1表示元素完全随机分布到各个类别中
例子
假设有以下数据表,每行代表两个变量
V
a
r
1
Var1
Var1 和
V
a
r
2
Var2
Var2 以及分类标签
C
l
a
s
s
Class
Class。
Class | Var 1 | Var 2 |
---|---|---|
A | 0 | 33 |
A | 0 | 54 |
A | 0 | 56 |
A | 1 | 50 |
B | 1 | 55 |
B | 1 | 31 |
B | 0 | -4 |
B | 1 | 77 |
B | 0 | 49 |
现在评估特征 V a r 1 Var1 Var1 对数据分类的纯度:
-
分割基准:对于 V a r 1 Var1 Var1,其中 4 / 10 4/10 4/10 的实例等于 1 1 1, 6 / 10 6/10 6/10 的实例等于 0 0 0。
-
计算子集的 Gini 指数:
-
当 V a r 1 = = 1 Var1 == 1 Var1==1, C l a s s = = A Class == A Class==A 占 1 / 4 1/4 1/4, C l a s s = = B Class == B Class==B 占 3 / 4 3/4 3/4:
G i n i = 1 − ( ( 1 / 4 ) 2 + ( 3 / 4 ) 2 ) = 0.375 Gini = 1 - ((1/4)^2 + (3/4)^2) = 0.375 Gini=1−((1/4)2+(3/4)2)=0.375
-
当 V a r 1 = = 0 Var1 == 0 Var1==0, C l a s s = = A Class == A Class==A 占 4 / 6 4/6 4/6, C l a s s = = B Class == B Class==B 占 2 / 6 2/6 2/6:
G i n i = 1 − ( ( 4 / 6 ) 2 + ( 2 / 6 ) 2 ) = 0.4444 Gini = 1 - ((4/6)^2 + (2/6)^2) = 0.4444 Gini=1−((4/6)2+(2/6)2)=0.4444
-
-
加权求和: 按照基准比例计算整体 Gini 指数
4 / 10 ∗ 0.375 + 6 / 10 ∗ 0.4444 = 0.41667 4/10 * 0.375 + 6/10 * 0.4444 = 0.41667 4/10∗0.375+6/10∗0.4444=0.41667
通俗解释:假设你要把一堆球按照颜色分成两堆,分得越干净越好。
- V a r 1 Var1 Var1 是你分球的规则,比如 “ V a r 1 = 1 Var1=1 Var1=1 放到一堆, V a r 1 = 0 Var1=0 Var1=0 放到另一堆”。
- Gini 指数告诉你,这两堆球中“混杂的程度”有多大:
- Gini 小:同一堆球的颜色越统一,分得越纯净。
- Gini 大:同一堆球的颜色越混杂,分得不纯净。
通过计算,发现 V a r 1 Var1 Var1 的整体 Gini 指数是 0.41667 0.41667 0.41667,说明它有一定的分割能力,可以帮助决策树进行分类。 下一步可以继续评估其他特征(比如 V a r 2 Var2 Var2),选择 Gini 指数最小的特征进行下一步分裂。
- 决策树的构建是逐步选择最佳的特征来分割数据,这里只是评估 V a r 1 Var1 Var1 的效果。后续构建树的过程中,也可以计算其他特征(例如 V a r 2 Var2 Var2)的 Gini 指数,比较不同特征的效果,选择最优特征作为分裂节点。
2.2 信息增益
信息增益(Information Gain
)表示通过某个属性分割数据后所获得的信息量。它衡量了该属性的重要性。在决策树构建过程中,关键在于找到最佳分裂节点以提高分类准确度,而信息增益正是用于识别返回最高信息增益的节点。这一过程基于熵(Entropy),熵表示系统的混乱程度:
- 熵越高,表示系统越混乱。
- 熵越低,表示系统越有序。当数据完全同质化时,熵为 0;当数据部分有序时,熵的值介于 0 和 1 之间。
熵与信息增益共同用于构建决策树,这种算法称为 ID3。
信息增益
信息增益表示通过某个属性分割数据后所获得的信息量,衡量了该属性的重要性。在决策树构建过程中,关键在于找到能够确保高准确率的最佳分裂节点,而信息增益则用于识别能返回最高信息增益的节点。这一过程通过熵(Entropy) 来计算,熵表示系统的混乱程度:
- 熵越高,系统越混乱。
- 熵越低,系统越有序。当样本完全同质时,熵为 0;而当样本部分有序时(例如 50% 有序),熵的值为 1。
熵与信息增益共同用于构建决策树,这种算法称为 ID3。
计算信息增益的步骤
(1)计算分割前输出属性的熵
使用以下公式计算:
E ( s ) = − ∑ i = 1 c p i log 2 p i E(s) = -\sum_{i=1}^c p_i \log_2 p_i E(s)=−i=1∑cpilog2pi
其中, p p p 表示成功的概率, c c c 表示类别的总数。
-
q = 1 − p q = 1 - p q=1−p 表示负类的概率
-
例如,10 个数据点中,5 个属于 True,5 个属于 False,则 c = 2 c = 2 c=2, p 1 p_1 p1 和 p 2 p_2 p2 均为 1 / 2 1/2 1/2。
(2)计算输入属性的熵
使用以下公式:
E ( T , x ) = ∑ c ∈ X P ( c ) E ( c ) E(T, x) = \sum_{c \in X} P(c)E(c) E(T,x)=c∈X∑P(c)E(c)
其中:
- T T T 是输出属性,
- X X X 是输入属性,
- P ( c ) P(c) P(c) 是数据点出现在 X X X 的概率,
- E ( c ) E(c) E(c) 是针对特定数据子集的熵。
假设某个输入属性(优先级)有两个可能的值 low 和 high:
- 对于 low,5 个数据点中有 2 个属于 True,3 个属于 False。
- 对于 high,剩下的 5 个数据点中有 4 个属于 True,1 个属于 False。
此时:
E ( T , x ) = 5 10 ∗ E ( 2 , 3 ) + 5 10 ∗ E ( 4 , 1 ) E(T, x) = \frac{5}{10} * E(2, 3) + \frac{5}{10} * E(4, 1) E(T,x)=105∗E(2,3)+105∗E(4,1)
- 在 E ( 2 , 3 ) E(2, 3) E(2,3) 中, p = 2 p = 2 p=2, q = 3 q = 3 q=3。
- 在 E ( 4 , 1 ) E(4, 1) E(4,1) 中, p = 4 p = 4 p=4, q = 1 q = 1 q=1。
(3)计算信息增益
使用输出属性的总熵与输入属性熵之差来计算:
G a i n ( T , X ) = E n t r o p y ( T ) − E n t r o p y ( T , X ) Gain(T, X) = Entropy(T) - Entropy(T, X) Gain(T,X)=Entropy(T)−Entropy(T,X)
选择信息增益最高的属性作为分裂节点。
(4)重复上述步骤
根据分裂结果继续执行步骤 1-4,将数据集根据分裂结果划分,直到所有数据被完全分类。
要点
- 当熵为 0 时,表示节点完全纯净,不需要进一步分裂。
- 只有熵大于 0(数据存在不纯度)的分支才需要继续分裂。
这张图对比了信息增益(熵)和Gini不纯度在概率 P P P 变化时的行为,主要用于解释决策树中两种常见的分裂评估指标的性质和趋势。
横轴 P P P 表示某个类别出现的概率,纵轴表示数据的不纯度。在二分类问题中,当 P P P 接近 0.5 0.5 0.5 时,不纯度最高,熵和Gini不纯度都达到最大值。相比之下,熵的数值稍高,更侧重于数据的不确定性;而 Gini 更直观地衡量类别的错误分类概率。
2.3 卡方检验
卡方检验(Chi-Square
)适用于目标变量是分类变量(例如成功-失败、高-低)的情况。该方法的核心思想是检验子节点与父节点之间的变化是否具有统计显著性。
卡方的数学公式为:
χ = ( A c t u a l − E x p e c t e d ) 2 E x p e c t e d \chi = \sqrt{\frac{(Actual - Expected)^2}{Expected}} χ=Expected(Actual−Expected)2
- 其中,实际值(Actual)与期望值(Expected)的标准化差异的平方和被用来衡量数据的变化程度。
- 卡方检验的优势之一是,它可以在单个节点上执行多次分割,从而提高分类的准确性和精度。
2.4 属性选择标准总结
- Gini 指数:通过衡量数据的不纯度(即分类错误的概率),选择 Gini 指数最小的属性进行分裂,适用于二分类和多分类问题,计算效率较高(如 CART 算法)。
- 信息增益:基于熵的概念,衡量分裂前后数据集纯度的提高程度,选择信息增益最高的属性进行分裂,常用于分类问题(如 ID3 算法)。
- 卡方检验:基于统计显著性原理,衡量属性与目标变量之间的关联程度,适用于目标变量为分类数据的情况,能同时处理多个类别,确保分裂结果的统计可靠性。
三者各有侧重:Gini 指数关注分类错误率,信息增益关注熵的变化,卡方检验强调统计学显著性。结合实际应用需求选择合适的标准,有助于构建更加高效和准确的决策树模型。
3 分类决策树代码实例
3.1 数据集
我们使用 Scikit-learn 提供的 Iris 数据集。目标是根据萼片和花瓣的长度和宽度对鸢尾花进行分类,分为三种类别:Setosa、Versicolor 和 Virginica。
下面来看一下这个数据集的属性:
属性名称 | 描述 | 示例值 |
---|---|---|
iris.DESCR | 数据集的完整描述信息 | 数据集的描述文件,提供详细信息 |
iris.data | 用于训练的数据,每个样本是一个包含 4 个特征的数组,共有 150 个样本 | [[5.1, 3.5, 1.4, 0.2], ...] |
iris.feature_names | 特征名称数组,包含 4 个特征 | ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)'] |
iris.filename | CSV 文件的名称 | iris.csv |
iris.target | 分类标签,每个训练样本对应一个标签,值为 0、1 或 2 | [0, 1, 2, ...] |
iris.target_names | 标签的含义数组,对应三个类别 | ['setosa', 'versicolor', 'virginica'] |
通过表格中的信息可以清楚地得出:
- 特征矩阵:
X = iris.data
(样本的 4 个特征)。 - 目标变量:
y = iris.target
(分类标签)。
3.2 代码实现
3.2.1 导入库和数据处理
导入库
import pandas as pd
import numpy as np
from sklearn import datasets
from sklearn import model_selection
from sklearn import tree
import graphviz
- pandas:用于数据的操作和分析。
- numpy:Python 的核心科学计算库,用于处理数组和矩阵操作。
- datasets:提供数据集支持,这里将使用
iris
和boston house prices
数据集。 - model_selection:包含
train_test_split()
方法,用于将数据划分为训练集和测试集。 - tree:用于实现决策树分类器和回归器。
- graphviz:用于将决策树导出为 Graphviz 格式,方便可视化树结构。
加载数据
iris = datasets.load_iris()
print('Dataset structure= ', dir(iris))
df = pd.DataFrame(iris.data, columns = iris.feature_names)
df['target'] = iris.target
df['flower_species'] = df.target.apply(lambda x : iris.target_names[x]) # Each value from 'target' is used as index to get corresponding value from 'target_names'
print('Unique target values=',df['target'].unique())
df.sample(5)
这里随机列了五个样本,输出如下:
数据划分
首先我们将四个特性和结果提取一下:
#Lets create feature matrix X and y labels
X = df[['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']]
y = df[['target']]
print('X shape=', X.shape) # X shape= (150, 4)
print('y shape=', y.shape) # y shape= (150, 1)
然后我们划分数据集,其中80%用来训练,20%用来测试:
X_train,X_test, y_train, y_test = model_selection.train_test_split(X, y, test_size= 0.2, random_state= 1)
3.2.2 训练和测试模型
# 初始化决策树分类器,指定随机种子 random_state=1
cls = tree.DecisionTreeClassifier(random_state=1)
# 用训练数据集 X_train 和标签 y_train 拟合决策树模型
cls.fit(X_train ,y_train)
对于DecisionTreeClassifier
来说,它有很多默认参数:
DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',
max_depth=None, max_features=None, max_leaf_nodes=None,
min_impurity_decrease=0.0, min_impurity_split=None,
min_samples_leaf=1, min_samples_split=2,
min_weight_fraction_leaf=0.0, presort='deprecated',
random_state=1, splitter='best')
参数如下:
参数名称 | 默认值 | 描述 |
---|---|---|
ccp_alpha | 0.0 | 剪枝的复杂度参数,用于后剪枝(通过控制模型复杂度减少过拟合)。 |
class_weight | None | 用于调整类别权重(适用于不平衡数据集)。 |
criterion | 'gini' | 划分节点的标准(默认为 Gini 不纯度,可选择 'entropy' )。 |
max_depth | None | 树的最大深度,默认值 None 表示直到所有叶节点纯净或叶节点样本数少于 min_samples_split 时停止分裂。 |
max_features | None | 分裂时考虑的最大特征数。 |
max_leaf_nodes | None | 树的最大叶节点数,用于控制树的复杂度。 |
min_impurity_decrease | 0.0 | 最小的不纯度减少量,低于此值的分裂将被忽略。 |
min_samples_leaf | 1 | 每个叶节点的最小样本数。 |
min_samples_split | 2 | 分裂内部节点所需的最小样本数。 |
min_weight_fraction_leaf | 0.0 | 一个叶节点的最小加权样本比例。 |
presort | 'deprecated' | 已弃用参数,用于加速树的构建。 |
random_state | 1 | 随机种子,用于确保树的可重复性。 |
splitter | 'best' | 分裂策略,'best' 表示优先选择分裂效果最好的特征,'random' 表示随机选择分裂特征。 |
测试模型
现在我们使用测试数据对模型进行验证。下面预测测试数据中第 10 个、第 20 个和第 30 个样本的鸢尾花种类。
print('Actual value of species for 10th training example=',iris.target_names[y_test.iloc[10]][0])
print('Predicted value of species for 10th training example=', iris.target_names[cls.predict([X_test.iloc[10]])][0])
print('\nActual value of species for 20th training example=',iris.target_names[y_test.iloc[20]][0])
print('Predicted value of species for 20th training example=', iris.target_names[cls.predict([X_test.iloc[20]])][0])
print('\nActual value of species for 30th training example=',iris.target_names[y_test.iloc[29]][0])
print('Predicted value of species for 30th training example=', iris.target_names[cls.predict([X_test.iloc[29]])][0])
输出如下,可以看到预测全部正确。:
Actual value of species for 10th training example= versicolor
Predicted value of species for 10th training example= versicolor
Actual value of species for 20th training example= versicolor
Predicted value of species for 20th training example= versicolor
Actual value of species for 30th training example= virginica
Predicted value of species for 30th training example= virginica
现在来看一下模型得分:
cls.score(X_test, y_test)
输出:
0.9666666666666667
3.2.3 可视化
我们将使用 Scikit-learn 提供的 plot_tree()
函数来绘制决策树,并使用 export_graphviz
将决策树导出为 Graphviz 格式,并渲染为 iris_decision_tree.pdf
文件,然后进一步优化可视化效果,包括特征名、类别名、颜色填充和圆角,使决策树更加直观、清晰。
tree.plot_tree(cls)
dot_data = tree.export_graphviz(cls, out_file=None)
graph = graphviz.Source(dot_data)
graph.render("iris_decision_tree")
dot_data = tree.export_graphviz(cls, out_file=None,
feature_names=iris.feature_names,
class_names=iris.target_names,
filled=True, rounded=True,
special_characters=True)
graph = graphviz.Source(dot_data)
输出:
决策树从根节点开始,根据每个特征的条件依次划分数据(例如 petal width (cm) ≤ 0.8
)。每个节点显示:
- Gini值:衡量节点的不纯度,值越小表示分类越纯。
- Samples:该节点包含的样本总数。
- Value:每个类别的样本数量(按顺序对应
setosa
、versicolor
、virginica
)。 - Class:当前节点的主要类别。
4 回归决策树代码实例
对于回归模型,我们将使用sklearn波士顿房价数据集,目的是根据可用数据预测房价。
4.1 数据集
数据集中的内容如下:
属性名称 | 描述 |
---|---|
boston.DESCR | 数据集的完整描述信息 |
boston.data | 用于训练的数据,共包含 13 个特征,目标变量通常为第 14 个属性(房价中位数)。共 506 个样本。 |
CRIM | 每个城镇的人均犯罪率 |
ZN | 占地面积超过 25,000 平方英尺的住宅用地比例 |
INDUS | 每个城镇的非零售商业用地比例 |
CHAS | 是否靠近查尔斯河(虚拟变量,=1 表示靠近河流,=0 表示不靠近) |
NOX | 氮氧化物浓度(每千万分之一) |
RM | 每栋住宅的平均房间数 |
AGE | 1940 年之前建造的自住单位比例 |
DIS | 到波士顿 5 个就业中心的加权距离 |
RAD | 辐射公路可达性指数 |
TAX | 每 10,000 美元的全值财产税率 |
PTRATIO | 每个城镇的学生-教师比率 |
B | 计算公式:1000(Bk - 0.63)^2,其中 Bk 是城镇中黑人比例 |
LSTAT | 低收入人口的百分比 |
MEDV | 自住房的房价中位数(以千美元为单位) |
boston.feature_names | 包含所有 13 个特征的数组:['CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS', 'RAD', 'TAX', 'PTRATIO', 'B', 'LSTAT'] |
boston.filename | 数据集的 CSV 文件名 |
boston.target | 房价的中位数(以千美元为单位) |
结论:
- 特征矩阵:
X = boston.data
- 目标变量:
y = boston.target
4.2 代码实现
4.2.1 数据处理
加载数据集
boston = datasets.load_boston()
print('Dataset structure= ', dir(boston))
df = pd.DataFrame(boston.data, columns = boston.feature_names)
df['target'] = boston.target
df.sample(5)
输出:
数据划分
#Lets create feature matrix X and y labels
X = df[['CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS', 'RAD', 'TAX', 'PTRATIO', 'B', 'LSTAT']]
y = df[['target']]
print('X shape=', X.shape) # X shape= (506, 13)
print('y shape=', y.shape) # y shape= (506, 1)
同样地,我们将80%的数据集用于训练,20%的数据集用于测试:
X_train,X_test, y_train, y_test = model_selection.train_test_split(X, y, test_size= 0.2, random_state= 1)
4.2.2 训练和测试模型
这里使用DecisionTreeRegressor
创建回归模型:
dtr = tree.DecisionTreeRegressor(max_depth= 3,random_state= 1)
dtr.fit(X_train ,y_train)
为了保持决策树的简洁性,此处设置 max_depth = 3
。需要注意的是,回归模型的默认准则是 'mse'
(均方误差)。mse
等价于以方差减少作为特征选择的准则,同时通过使用每个终端节点的均值最小化 L2 损失。所有默认参数如下:
DecisionTreeRegressor(ccp_alpha=0.0, criterion='mse', max_depth=3,
max_features=None, max_leaf_nodes=None,
min_impurity_decrease=0.0, min_impurity_split=None,
min_samples_leaf=1, min_samples_split=2,
min_weight_fraction_leaf=0.0, presort='deprecated',
random_state=1, splitter='best')
测试模型
这里我们输出对比一下实际的房价和我们预测的房价
predicted_price= pd.DataFrame(dtr.predict(X_test), columns=['Predicted Price'])
actual_price = pd.DataFrame(y_test, columns=['target'])
actual_price = actual_price.reset_index(drop=True) # Drop the index so that we can concat it, to create new dataframe
df_actual_vs_predicted = pd.concat([actual_price,predicted_price],axis =1)
df_actual_vs_predicted.T
输出如下:
模型得分
现在通过 score()
方法计算模型的决定系数
R
2
R^2
R2。测试数据测试模型的得分:
dtr.score(X_test, y_test)
输出:
0.6647121948539862
这表示模型能解释 66.47% 的测试数据方差。
4.2.3 可视化
和前面的分类模型一样,我们可视化一下这个决策树:
tree.plot_tree(dtr)
dot_data = tree.export_graphviz(dtr, out_file=None)
graph = graphviz.Source(dot_data)
graph.render("boston_decision_tree")
dot_data = tree.export_graphviz(dtr, out_file=None,
feature_names=boston.feature_names,
filled=True, rounded=True,
special_characters=True)
graph = graphviz.Source(dot_data)
输出如下:
这是一个回归决策树,它预测的是房价的中位数(MEDV
),输出是一个具体的数值。决策树的每个节点中的参数如下:
- 分裂条件:每个非叶节点有一个条件(例如
LSTAT ≤ 9.725
),表示如何划分数据。 - 均方误差(
mse
):衡量当前节点目标值的分布,mse
越小表示节点内样本值越接近当前value
。 - 样本数量(
samples
):当前节点包含的样本数量。 - 预测值(
value
):当前节点中所有样本目标值的平均值,表示模型的预测输出。
在回归决策树中,叶节点的 value
是该节点中所有样本目标值(房价中位数)的均值。也就是说:
- 当树分裂到某一节点时,会根据条件将样本划分到左子节点或右子节点。
- 当到达叶节点时,该节点的
value
是所有分裂到该节点的样本的目标值的平均值。
公式如下:
value
=
1
n
∑
i
=
1
n
y
i
\text{value} = \frac{1}{n} \sum_{i=1}^{n} y_i
value=n1i=1∑nyi
- n n n:叶节点中的样本数量。
- y i y_i yi:每个样本的目标值(房价中位数)。
5 总结
本文系统地介绍了决策树的基础概念、构建过程以及常用的分裂准则。决策树作为一种直观且易于解释的算法,广泛应用于分类和回归任务,如客户行为分析、医疗诊断以及销售预测等。
它的优点是直观易解释,能处理非线性关系,对特征缩放不敏感;缺点是容易过拟合、对数据波动敏感,需通过剪枝或设置参数(如深度)控制复杂度。回归决策树更适合处理少量特征的回归问题,但在复杂场景中可能需要结合集成方法(如随机森林)提升性能。