【解决】决策树过拟合怎么办

本文介绍了决策树模型及其在过拟合问题上的处理,重点讨论了预剪枝和后剪枝两种策略,以及如何在sklearn中通过参数调整实现这些方法。同时提到了减少特征维度和增加样本量作为防止过拟合的手段。

目录

一、决策树简介

1.1 决策树模型介绍

1.2 决策树例子

二、决策树过拟合时的解决方案

2.1 决策树过拟合的解决方案概述

三、关于决策树剪枝

3.1  决策树的预剪枝

3.2 决策树的后剪枝



决策树是一种容易过拟合的模型,在使用决策树时,往往会发生过拟合,这时候,就需要采取相关策略解决决策树的过拟合问题。本文讲解决策树过拟合时,可以采取的策略。

一、决策树简介

1.1 决策树模型介绍

决策树模型是一种非参数的有监督学习方法,它通过树形结构(可以是二叉树或非二叉树)表示决策过程。在决策树中,每个非叶节点表示一个特征属性上的测试,每个分支代表这个特征属性在某个值域上的输出,而每个叶节点存放一个类别。决策树模型可以用于分类和回归问题。

下面是一个决策树模型示例:

1.2 决策树例子

在python中可以通过sklearn来构建一个决策树

from sklearn.datasets import load_iris
from sklearn import tree

#----------------数据准备----------------------------
iris = load_iris()                          # 加载数据

#---------------模型训练----------------------------------
clf = tree.DecisionTreeClassifier()         # sk-learn的决策树模型
clf = clf.fit(iris.data, iris.target)        # 用数据训练树模型构建()
r = tree.export_text(clf, feature_names=iris['feature_names'])

#---------------模型预测结果------------------------
text_x = iris.data[[0,1,50,51,100,101], :]
pred_target_prob = clf.predict_proba(text_x)        # 预测类别概率
pred_target = clf.predict(text_x)              # 预测类别

#---------------打印结果---------------------------
print("\n===模型======")
print(r)
print("\n===测试数据:=====")
print(text_x)
print("\n===预测所属类别概率:=====")
print(pred_target_prob)
print("\n===预测所属类别:======")
print(pred_target)

运行后显示如下结果:

如果需要将决策树可视化,可参考文章:

怎么画出决策树-两种决策树的可视化方法-优快云博客

如果不想通过sklearn软件包,而自己构建一个决策树,可参考

老饼讲解|决策树建模完整代码

二、决策树过拟合时的解决方案

2.1 决策树过拟合的解决方案概述

决策树过拟合的解决方法有以下几种:

剪枝:剪枝是一种常用的解决决策树过拟合问题的方法。它通过去掉一些不必要的节点来降低模型复杂度,从而避免过拟合。具体来说,剪枝分为预剪枝和后剪枝两种方式。

预剪枝是在构建决策树时,在每个节点处判断是否应该继续分裂。如果当前节点无法提高模型性能,则停止分裂,将当前节点标记为叶子节点。这样可以有效地减少模型复杂度,避免过拟合。

后剪枝则是在构建完整个决策树之后,对树进行修剪。具体来说,它通过递归地考虑每个非叶子节点是否可以被替换成叶子节点来达到降低模型复杂度、避免过拟合的目的。
增加样本量:过拟合通常是由于训练数据量太少导致的。因此,增加样本量可以有效地缓解过拟合问题。在实际应用中,可以通过收集更多的数据来增加样本量。


减少特征维度:过拟合可能是因为模型学习了太多的无关特征,因此可以考虑减少特征的维度,从而降低模型的复杂度。


综上所述,解决决策树过拟合问题可以从剪枝、增加样本量、减少特征维度等方面入手。

三、关于决策树剪枝

3.1  决策树的预剪枝

决策树的预剪枝是在构建决策树的过程中,对每个结点在划分前进行估计,如果当前结点的划分不能带来决策树模型泛化性能的提升,则不对当前结点进行划分并且将当前结点标记为叶结点。这种方法是为了简化决策树模型,避免过拟合。预剪枝的方法包括设定阈值和到达某个结点的实例个数小于某个阈值时停止树的生长。

常用的预剪枝有如下:

1.设定每个叶子节点的最小样本数,以避免某个特征类别只适用于极少数的样本。
2.设定每个节点的最小样本数,从根节点开始避免过度拟合。
3.设定树的最大深度,避免无限往下划分。
4.设定叶子节点的最大数量,避免出现无限多次划分类别。
5.设定评估分割数据是的最大特征数量,避免每次都考虑所有特征为求“最佳”,而采取随机选择的方式避免过度拟合。


预剪枝是树构建过程,达到一定条件就停止生长,在sklearn中,实际就是调参,通过设置树的生长参数,来达到预剪枝的效果。

在skelarn中,与预剪枝相关的参数有:

min_samples_leaf           :叶子节点最小样本数       
min_samples_split          :节点分枝最小样本个数     
max_depth                  :树分枝的最大深度            
min_weight_fraction_leaf   :叶子节点最小权重和         
min_impurity_decrease      :节点分枝最小纯度增长量   
max_leaf_nodes             :最大叶子节点数        

3.2 决策树的后剪枝

决策树的剪枝,除了预剪枝,还有后剪枝。后剪枝是在决策树生长完成之后,对树进行剪枝,得到简化版的决策树。后剪枝通过对拥有同样父节点的一组节点进行检查,判断如果将其合并,熵的增加量是否小于某一阈值。如果确实小,则这一组节点可以合并一个节点,其中包含了所有可能的结果。后剪枝一般采用的是CCP后剪枝。

CCP的后剪枝公式如下:

在sklearn中,就是通过设置alpha参数来达到后剪枝的效果。

后剪枝的具体例子可以参考《老饼讲解|sklearn决策树后剪枝


写文不易,点赞收藏吧~!

### 三、决策树过拟合的原因 决策树过拟合的主要原因在于其模型复杂度与训练数据之间的不匹配。当决策树的结构过于复杂时,它不仅会学习到数据中的真实模式,还会捕捉到训练数据中的噪声和异常值,导致在新数据上的泛化能力下降。具体原因包括以下几点: 1. **训练数据集较小**:当训练数据集较小时,模型难以学习到数据的整体分布,容易在训练数据上过度拟合。 2. **决策树过于复杂**:决策树的生长过程是基于递归分裂的策略,如果不对树的深度或节点数量进行限制,树可能会无限生长,导致每个叶子节点只包含少量甚至单一的样本,从而过度拟合训练数据。 3. **特征选择不当**:如果在构建决策树时选择了不相关的或冗余的特征,模型可能会过度依赖某些特定特征,从而影响模型的泛化能力。 ### 三、解决决策树过拟合的方法 为了解决决策树过拟合问题,通常采用以下几种方法: 1. **剪枝技术**:剪枝是解决决策树过拟合的关键手段,分为预剪枝和后剪枝两种方式。 - **预剪枝**是在构建决策树的过程中,对每个节点在划分前进行估计,如果当前节点的划分不能带来模型泛化性能的提升,则不对当前节点进行划分,并将当前节点标记为叶子节点。这种方法可以有效简化模型,避免过拟合[^4]。 - **后剪枝**则是在决策树完全生成之后,通过自底向上的方式对树进行剪枝,将某些子树替换为叶子节点,以提高模型的泛化能力。 2. **限制树的深度**:通过设定树的最大深度,可以防止树过度生长。通常,较浅的树具有更强的泛化能力,而较深的树则容易过拟合训练数据。 3. **设置叶子节点的最小样本数**:在构建决策树时,可以设定每个叶子节点至少需要包含的样本数。如果某个分裂操作会导致叶子节点的样本数低于设定的阈值,则不进行该分裂。这种方法可以避免树过于复杂,减少过拟合的风险。 4. **使用集成方法**:集成方法如随机森林和梯度提升树(GBDT)可以通过组合多个决策树的结果来提高模型的泛化能力。随机森林通过引入随机性(如随机选择样本和特征)来降低模型的方差,从而减少过拟合的可能性;而GBDT则通过逐步修正前一个模型的错误来提高模型的准确性,同时也可以通过早停等技术防止过拟合。 5. **正则化**:在决策树的构建过程中,可以引入正则化项来惩罚复杂的模型。例如,在选择分裂特征时,除了考虑信息增益或基尼不纯度等指标外,还可以引入对特征数量或树深度的惩罚项,从而控制模型的复杂度。 6. **交叉验证**:通过交叉验证可以评估模型在不同数据子集上的表现,帮助选择最优的超参数(如树的深度、叶子节点的最小样本数等),从而避免过拟合。 ### 示例代码 以下是一个使用Scikit-learn库构建决策树并应用预剪枝的示例代码: ```python from sklearn.datasets import load_iris from sklearn.model_selection import train_test_split from sklearn.tree import DecisionTreeClassifier from sklearn.metrics import accuracy_score # 加载数据集 iris = load_iris() X, y = iris.data, iris.target # 划分训练集和测试集 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) # 构建决策树分类器,并设置预剪枝参数 # max_depth 控制树的最大深度 # min_samples_split 设置内部节点再划分所需最小样本数 # min_samples_leaf 设置叶子节点最少样本数 clf = DecisionTreeClassifier(max_depth=3, min_samples_split=5, min_samples_leaf=2, random_state=42) # 训练模型 clf.fit(X_train, y_train) # 预测 y_pred = clf.predict(X_test) # 评估模型 accuracy = accuracy_score(y_test, y_pred) print(f"Model Accuracy: {accuracy}") ``` ###
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值