机器学习 第四章 决策树
4.1 基本流程
1. 决策树定义
- 树形结构:根节点→内部节点→叶节点
- 核心功能:通过属性判断对样本分类/回归
2. 构建三步骤
步骤 | 操作 | 关键方法 |
---|---|---|
特征选择 | 选择最优划分属性 | 信息增益/基尼指数/增益率 |
节点分裂 | 按属性值分割数据 | 离散属性分支/连续属性二分法 |
递归终止 | 停止分裂并标记叶节点 | 样本全同类/无属性可用/样本过少 |
3. 伪代码描述
def build_tree(data):
if 满足停止条件:
return 叶节点(多数类别)
最优属性 = 选择最佳划分属性(data)
树节点 = 创建节点(最优属性)
for 每个属性取值:
子数据 = 按取值划分数据集(data)
树节点.add_child(build_tree(子数据))
return 树节点
4.2 划分选择
1. 划分指标对比
指标 | 算法 | 公式 | 优化目标 |
---|---|---|---|
信息增益 | ID3 | Gain=Ent(D)-Σ(∣Dᵛ∣/∣D∣*Ent(Dᵛ)) | 最大化增益 |
信息增益率 | C4.5 | Gain_ratio=Gain/IV(a) | 最大化增益率 |
基尼指数 | CART | Gini_index=Σ(∣Dᵛ∣/∣D∣*Gini(Dᵛ)) | 最小化基尼指数 |
2. 连续值处理
- 二分法步骤:
- 排序连续属性值
- 生成候选划分点(相邻值中点)
- 计算各划分点指标,选择最优
示例代码:
sorted_values = np.sort(X[:, feature])
split_points = (sorted_values[:-1] + sorted_values[1:]) / 2
best_point = max(split_points, key=lambda t: calc_gain(X, y, t))
4.3 剪枝处理
1. 剪枝类型对比
类型 | 操作时机 | 典型方法 | 核心参数 |
---|---|---|---|
预剪枝 | 树构建过程中 | 限制最大深度/最小样本分裂数 | max_depth , min_samples_split |
后剪枝 | 树构建完成后 | REP/PEP/CCP | ccp_alpha |
2. CCP剪枝步骤
- 计算子树序列的复杂度参数 α \alpha α
- 通过交叉验证选择最优 α \alpha α
- 剪去对应子树
Scikit-learn实现:
from sklearn.tree import DecisionTreeClassifier
path = DecisionTreeClassifier.cost_complexity_pruning_path(X, y)
ccp_alphas = path.ccp_alphas # 候选α值列表
4.4 连续与缺失值
1. 连续值处理(二分法)
步骤:
- 对属性值排序(如含糖率:[0.15, 0.20, 0.25])
- 生成候选划分点(0.175, 0.225)
- 选择最优划分点:
Gain ( D , a , t ) = max t ∈ T a Gain ( D , a , t ) \text{Gain}(D, a, t) = \max_{t \in T_a} \text{Gain}(D, a, t) Gain(D,a,t)=t∈TamaxGain(D,a,t)
2. 缺失值处理(C4.5方法)
- 权重分配:
- 初始权重 w x = 1 w_x=1 wx=1
- 划分时按无缺失样本比例分配权重
- 增益计算修正:
ρ = ∑ x ∈ D ~ w x ∑ x ∈ D w x \rho = \frac{\sum_{x \in \tilde{D}} w_x}{\sum_{x \in D} w_x} ρ=∑x∈Dwx∑x∈D~wx
Gain ( D , a ) = ρ × Gain ( D ~ , a ) \text{Gain}(D, a) = \rho \times \text{Gain}(\tilde{D}, a) Gain(D,a)=ρ×Gain(D~,a)
4.5 多变量决策树
1. 核心改进
- 划分条件:使用多个属性的线性组合(如 0.6 X 1 + 0.4 X 2 ≤ 0.5 0.6X_1 + 0.4X_2 \leq 0.5 0.6X1+0.4X2≤0.5)
- 决策边界:支持斜线或曲线分割,解决单变量决策树轴平行分割的局限性
2. 构建流程
步骤 | 操作 |
---|---|
候选组合生成 | 随机生成或通过LDA计算权重向量 |
最优组合选择 | 最大化信息增益/最小化基尼指数 |
递归分裂 | 与传统决策树相同,生成子树 |
伪代码片段:
class MultivariateNode:
def __init__(self, weights, threshold):
self.weights = weights # 线性组合权重
self.threshold = threshold
self.left = None # 左子树(满足条件)
self.right = None # 右子树(不满足条件)
def find_best_combination(data):
# 通过启发式搜索或优化算法寻找最优权重和阈值
return best_weights, best_threshold