决策树的底层原理
决策树是一种常用的分类和回归算法,其基本原理是通过一系列的简单决策,将数据集划分为多个子集,从而实现分类。决策树的核心思想是通过树形结构表示决策过程,节点代表特征,边代表决策,叶子节点代表类别。
下面是一个决策树例子(用挑选好西瓜来举例,最终结果为判断是好瓜还是坏瓜):
1.1 决策树的基本结构基本思想
- 根节点:树的起点,表示整个数据集。
- 内部节点/分支节点:表示根据某一特征进行的决策,根据此特征条件将数据集划分为更小的子集。
- 叶子节点:表示最终的分类结果(如“推荐动作片”或“不推荐动作片”)或回归值(预测值)。
1.2 决策树的流程
- 从数据集开始,选择一个最能区分样本的属性(例如“年龄”)。
- 根据该属性的取值,将数据分成多个子集。
- 在每个子集上重复上述步骤,直到满足停止条件(如数据不可再分,或者已经达到期望的预测准确度)。
- 将每个分组的最终结果作为叶节点。
决策树的构建
决策树的构建过程通常采用递归的方式,核心步骤包括特征选择、数据划分和停止条件。
以“推荐电影”为例,假设我们有以下数据集:
年龄段 | 是否喜欢动作片 | 是否喜欢喜剧片 | 是否喜欢动画片 | 推荐电影类型 |
---|---|---|---|---|
年轻人 | 是 | 是 | 是 | 动作片 |
年轻人 | 是 | 否 | 是 | 动作片 |
中年人 | 否 | 是 | 否 | 喜剧片 |
中年人 | 否 | 否 | 是 | 动画片 |
老年人 | 否 | 是 | 否 | 喜剧片 |
目标是通过决策树预测“推荐电影类型”。
2.1 第一步:选择分裂的属性
决策树的核心是找到一个最优的属性来分裂数据集。例如,在我们的例子中,候选属性有“年龄段”、“是否喜欢动作片”、“是否喜欢喜剧片”和“是否喜欢动画片”。
如何选择最优属性?特征选择
在每个节点上,需要选择一个特征来划分数据集,选择的常用标准是信息增益或基尼指数:
-
信息增益:
- 定义:衡量一个属性划分数据后,数据的“纯度”提高了多少,基于香农信息论,信息增益是划分前后信息的不确定性减少量。
- 纯度越高,越容易预测目标(推荐电影类型)。
公式:
信息增益=原始熵−划分后的熵熵的计算公式:
其中参数含义:
H(S):数据集 S 的信息熵,表示划分前数据集的纯度或不确定性。
n 为分类的类别数
是数据集中属于第 i 类的样本所占的比例。
其中:
信息增益的公式为:
参数含义
-
H(S):数据集 S 的信息熵,表示划分前数据集的纯度或不确定性
-
S:当前数据集(未划分之前的整体数据集)。
-
:特征划分后生成的第 i 个子数据集。
-
∣S∣:数据集 S 中样本的总数。
-
:第 i 个子数据集中的样本数。
-
:子数据集
的信息熵。
-
:第 i 个子数据集样本数占总样本数的比例,表示权重。
公式解读
- 第一项 H(S):表示数据集未划分前的混乱程度(信息熵)。
- 第二项
:表示特征划分后各子数据集的不确定性总和。
- 通过权重
平均计算每个子数据集的不确定性。
- 通过权重
- 信息增益:划分前的不确定性减去划分后的不确定性,即使用某个特征划分数据所带来的不确定性减少量。
示例
假设数据集 S 有 10 个样本,其中 6 个属于类 A,4 个属于类 B:
-
未划分时的信息熵:
= -0.5394
= -0.5487
S1 中有 4 个样本,均为类 A,没有类 B。我们按以下步骤计算:
-
确定类别分布:
- 类 A 的样本数为 4,比例
=4/4=1。
- 类 B 的样本数为 0,比例
=0/4=0。
- 类 A 的样本数为 4,比例
-
代入熵公式:
(因为
)
(因为
=0)
如果
中样本属于同一类(如都为类 A),则
=0,表示完全纯度。
如果 中样本属于多类(如 A 和 B 都有),则
会大于 0,表示不确定性增加
-
-
划分后,数据分为两组:
:包含 4 个样本,均为类 A。
:包含 6 个样本,3 个为类 A,3 个为类 B。
则计算划分后熵的加权和:
-
信息增益:
= -0.5394 + 0.3292 = -0.2102
通过比较不同特征的信息增益,选择信息增益最大的特征进行划分。
示例结束
2. 信息增益率:为了解决信息增益偏向于选择取值较多的特征的问题,信息增益率在信息增益的基础上进行归一化:
-
基尼指数:主要用于 CART(Classification and Regression Trees)算法,计算某个特征的基尼指数,公式为:
其中,
为类
在数据集 D 中的比例。
2.2 第二步:分裂数据集
根据计算的熵或基尼指数,选择最优属性来分裂数据集。假设通过计算发现“年龄段”是最优属性。
分裂结果:
- 年轻人:{1, 2}
- 中年人:{3, 4}
- 老年人:{5}
2.3 第三步:数据划分,递归构建子树
根据选择的特征,将数据集划分为多个子集。对于连续特征,通常会选取一个阈值,将数据集分为小于阈值和大于阈值两部分;对于分类特征,则根据每个取值进行划分。
对子集递归重复上述过程。例如:
- 对“年轻人”子集,再看“是否喜欢动作片”能否进一步提高纯度。
- 如果某个子集的样本已经属于同一类别(如全是“推荐动作片”),则停止分裂。
2.4 停止条件
- 所有样本都属于同一类别(纯度100%)。
- 达到最大深度,没有更多的属性可分裂。
- 样本数量太少,无法继续分裂。
- 节点样本数低于某一阈值。
- 信息增益或基尼指数的减少低于某一阈值。
决策树的剪枝
为了解决过拟合问题,决策树通常会进行剪枝,分为预剪枝和后剪枝:
- 预剪枝:在树的构建过程中,实时评估当前分裂的效果,决定是否继续分裂。
- 后剪枝:先构建完整的树,再从叶子节点向上进行剪枝,去掉一些不必要的分支。
决策树的算法
决策树的构建算法主要有 ID3、C4.5、CART 等。
- ID3:使用信息增益作为特征选择的标准,适用于分类任务。
- C4.5:改进了 ID3,使用信息增益率作为标准,支持连续特征和缺失值。
- CART:使用基尼指数进行特征选择,支持分类和回归任务。
决策树的优缺点
优点:
- 直观易懂:决策树模型易于理解和可视化。
- 无需特征缩放:对特征的缩放和归一化不敏感。
- 适用性广,可处理非线性关系:能够自动捕捉复杂的非线性特征。可以处理分类和回归问题,且对数据类型没有强要求。
缺点:
- 过拟合:决策树容易在训练数据上过拟合,尤其是深度较大的树。
- 不稳定性,对噪声敏感:对训练数据的微小变化敏感,少量噪声数据可能显著改变树的结构,导致树的结构有较大差异。
- 偏向于某些特征:使用信息增益时,可能偏向于选择取值较多的特征。
- 不能很好处理连续变量:需要将连续变量离散化。
决策树的具体实现
3.1 分类与回归树
决策树主要有两种常见形式:
- 分类树:用于分类任务(如推荐电影类型)。
- 回归树:用于数值预测任务(如预测房价)。
3.2 伪代码
构建决策树的核心流程如下:
- 输入:数据集 S,候选属性集合 A,停止条件。
- 如果 S 中样本属于同一类别,则创建叶节点。
- 如果 A 为空,则创建叶节点,类别为多数样本类别。
- 选择最优划分属性
。
- 根据
的取值,将 S 划分为多个子集。
- 对每个子集递归调用,构建子树。
在 Python 中,使用 scikit-learn
库可以非常方便地实现决策树。以下是一个基本的实现示例:
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn import tree
import matplotlib.pyplot as plt
# 加载数据集
iris = load_iris()
X, y = iris.data, iris.target
# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 构建决策树模型
clf = DecisionTreeClassifier(criterion='gini', max_depth=3)
clf.fit(X_train, y_train)
# 预测
y_pred = clf.predict(X_test)
# 可视化决策树
plt.figure(figsize=(12, 8))
tree.plot_tree(clf, filled=True, feature_names=iris.feature_names, class_names=iris.target_names)
plt.show()
以下是基于 Python 的决策树实现(使用信息增益):
import numpy as np
from collections import Counter
# 计算熵
def entropy(data):
total = len(data)
label_counts = Counter(data)
return -sum((count / total) * np.log2(count / total) for count in label_counts.values())
# 计算信息增益
def information_gain(data, attribute_index, target_index):
total_entropy = entropy([row[target_index] for row in data])
values = set(row[attribute_index] for row in data)
weighted_entropy = 0
for value in values:
subset = [row for row in data if row[attribute_index] == value]
weighted_entropy += len(subset) / len(data) * entropy([row[target_index] for row in subset])
return total_entropy - weighted_entropy
# 构建决策树
def build_tree(data, attributes, target_index):
target_values = [row[target_index] for row in data]
if len(set(target_values)) == 1:
return target_values[0] # 纯叶节点
if not attributes:
return Counter(target_values).most_common(1)[0][0] # 多数类叶节点
# 找到最优划分属性
gains = [information_gain(data, attr_index, target_index) for attr_index in attributes]
best_attr_index = attributes[np.argmax(gains)]
tree = {best_attr_index: {}}
for value in set(row[best_attr_index] for row in data):
subset = [row for row in data if row[best_attr_index] == value]
subtree = build_tree(subset, [attr for attr in attributes if attr != best_attr_index], target_index)
tree[best_attr_index][value] = subtree
return tree
# 示例数据集
data = [
['年轻人', '是', '是', '是', '动作片'],
['年轻人', '是', '否', '是', '动作片'],
['中年人', '否', '是', '否', '喜剧片'],
['中年人', '否', '否', '是', '动画片'],
['老年人', '否', '是', '否', '喜剧片']
]
attributes = [0, 1, 2, 3] # 属性索引(0: 年龄段, 1: 动作片, 2: 喜剧片, 3: 动画片)
target_index = 4 # 目标列索引
# 构建决策树
decision_tree = build_tree(data, attributes, target_index)
print("决策树结构:", decision_tree)
决策树的应用
决策树广泛应用于金融、医疗、市场分析等多个领域,如:
- 信用评分:评估客户的信用风险。
- 医学诊断:帮助医生进行疾病预测和诊断。
- 客户分类:根据客户特征进行市场细分。
决策树改进方法
- 剪枝:减少过拟合问题,包括预剪枝和后剪枝。
- 随机森林:集成多棵决策树,降低单棵树的偏差。
- 梯度提升树(GBDT):利用多棵决策树提升性能。
总结
决策树是一种强大的分类和回归模型,通过树形结构进行决策。其构建过程包括特征选择、数据划分、剪枝等步骤,易于理解和实现,但需注意过拟合和模型稳定性的问题。在实际应用中,可以根据具体场景选择合适的决策树算法和参数设置。