在前面的文章中,我们介绍了 SVM、随机森林、神经网络等模型。
今天要回到一个经典而直观的模型:决策树(Decision Tree)。
🧩 决策树的基本思想
决策树是一种 树形结构的分类与回归方法:
- 每个 节点 表示一个特征条件(例如“第5个波段 > 0.23”)。
- 每个 分支 表示决策结果(是/否)。
- 直到叶子节点,给出分类结果(类别ID)。
优点:
- 结构简单,结果易解释;
- 训练速度快,能处理非线性关系;
- 不需要大量预处理。
缺点:
- 单棵树容易过拟合,对噪声敏感;
- 一般在实际应用中更常作为 集成方法(随机森林、梯度提升) 的基础模型。
💻 代码示例:KSC 数据 + 决策树分类
# -*- coding: utf-8 -*-
"""
Sklearn案例⑰:决策树分类(KSC 真实数据)
- 无泄露预处理:仅在训练像素上 fit 标准化与 PCA
- 评估:OA/Kappa/分类报告 + 数字标注混淆矩阵
- 整图预测:逐像素分类可视化(离散色标)
- 决策树结构:仅展示前1层,放大字体、按比例显示
"""
import os
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
import matplotlib
from matplotlib.colors import ListedColormap, BoundaryNorm
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import (confusion_matrix, classification_report,
accuracy_score, cohen_kappa_score)
# ===== 中文显示 =====
matplotlib.rcParams['font.family'] = 'SimHei'
matplotlib.rcParams['axes.unicode_minus'] = False
# ===== 参数区(仅需修改 DATA_DIR)=====
DATA_DIR = r"your_path" # ←← 修改为你的 KSC 数据路径(包含 KSC.mat / KSC_gt.mat)
PCA_DIM = 30
TRAIN_SIZE = 0.3
SEED = 42
# ===== 1) 读取 KSC 数据(仅取有标签像素)=====
X_cube = sio.loadmat(os.path.join(DATA_DIR, "KSC.mat"))["KSC"].astype(np.float32) # (H,W,B)
Y_map = sio.loadmat(os.path.join(DATA_DIR, "KSC_gt.mat"))["KSC_gt"].astype(int) # (H,W)
h, w, b = X_cube.shape
coords = np.argwhere(Y_map != 0) # 有标签像素坐标
X_all = X_cube[coords[:, 0], coords[:, 1]] # (N,B)
y_all = Y_map[coords[:, 0], coords[:, 1]] - 1
num_classes = int(y_all.max() + 1)
print(f"[INFO] 有标签像素: {len(y_all)}, 类别数: {num_classes}")
# ===== 2) 分层划分(仅在有标签像素上)=====
X_tr, X_te, y_tr, y_te = train_test_split(
X_all, y_all,
train_size=TRAIN_SIZE,
stratify=y_all,
random_state=SEED
)
# ===== 3) 无泄露预处理:仅用训练像素拟合 =====
scaler = StandardScaler().fit(X_tr)
pca = PCA(n_components=PCA_DIM, random_state=SEED).fit(scaler.transform(X_tr))
X_train = pca.transform(scaler.transform(X_tr))
X_test = pca.transform(scaler.transform(X_te))
# ===== 4) 决策树模型 =====
clf = DecisionTreeClassifier(
criterion="gini", # 或 "entropy"
max_depth=20, # 适度限制深度防过拟合
min_samples_split=5,
random_state=SEED
)
clf.fit(X_train, y_tr)
y_pred = clf.predict(X_test)
# ===== 5) 评估与报告 =====
oa = accuracy_score(y_te, y_pred)
kappa = cohen_kappa_score(y_te, y_pred)
cm = confusion_matrix(y_te, y_pred, labels=np.arange(num_classes))
cmn = cm / np.maximum(cm.sum(axis=1, keepdims=True), 1)
print(f"OA = {oa*100:.2f}% Kappa = {kappa:.4f}")
print(classification_report(y_te, y_pred, digits=4, zero_division=0))
# ===== 6) 混淆矩阵(计数 & 归一化,含数字)=====
fig, axes = plt.subplots(1, 2, figsize=(11, 4.2), constrained_layout=True)
im0 = axes[0].imshow(cm, cmap=plt.cm.YlGnBu, interpolation='nearest', alpha=0.9)
axes[0].set_title("混淆矩阵(计数)")
axes[0].set_xlabel("预测"); axes[0].set_ylabel("真实")
for i in range(num_classes):
for j in range(num_classes):
axes[0].text(j, i, str(cm[i, j]),
ha='center', va='center',
color='black' if cm[i, j] < cm.max()/2 else 'white', fontsize=9)
fig.colorbar(im0, ax=axes[0], fraction=0.046, pad=0.04)
im1 = axes[1].imshow(cmn, vmin=0, vmax=1, cmap=plt.cm.OrRd, interpolation='nearest', alpha=0.9)
axes[1].set_title("混淆矩阵(归一化)")
axes[1].set_xlabel("预测"); axes[1].set_ylabel("真实")
for i in range(num_classes):
for j in range(num_classes):
axes[1].text(j, i, f"{cmn[i, j]*100:.1f}%",
ha='center', va='center', color='black', fontsize=9)
fig.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04)
plt.show()
# ===== 7) 整图像素级预测(逐像素分类)=====
X_flat = X_cube.reshape(-1, b)
X_flat_pca = pca.transform(scaler.transform(X_flat))
pred_map = clf.predict(X_flat_pca).reshape(h, w) + 1 # 转 1..C 便于显示
# 离散色表(1..C),背景本步骤不单独处理
base_cmap = plt.get_cmap('tab20')
colors = [base_cmap(i % 20) for i in range(num_classes)]
cmap = ListedColormap(colors)
boundaries = np.arange(0.5, num_classes + 1.5, 1)
norm = BoundaryNorm(boundaries, cmap.N)
plt.figure(figsize=(8.6, 6.4))
im = plt.imshow(pred_map, cmap=cmap, norm=norm, interpolation='nearest')
plt.title("KSC 整图预测结果(决策树)")
plt.axis('off')
cbar = plt.colorbar(im, boundaries=boundaries,
ticks=np.arange(1, num_classes + 1, max(1, num_classes // 12)),
fraction=0.046, pad=0.04)
cbar.set_label("类别ID", rotation=90)
plt.show()
# ===== 8) 决策树结构可视化(前两层,放大字体、按比例显示)=====
plt.figure(figsize=(16, 8))
plot_tree(
clf,
filled=True,
feature_names=[f"PC{i+1}" for i in range(PCA_DIM)],
class_names=[str(i+1) for i in range(num_classes)],
max_depth=1, # 仅展示前两层
fontsize=10, # 放大字体
proportion=True # 节点以类别比例显示,信息更简洁
)
plt.title("决策树前1层结构示意(按比例显示)", fontsize=14)
plt.show()
🔍 结果解读



-
性能指标
- 控制台输出:OA、Kappa、精确率、召回率和 F1。
- 混淆矩阵:可以直观看出哪些类别容易混淆。
-
整图预测
- 每个像素都会被分配一个类别,直观呈现分类效果。
- 由于单棵树容易过拟合,分类图可能存在块状噪声。
-
树结构可视化
- 决策树可以直接画出来(示例只画前两层),展示分类规则,非常适合教学与解释。
✅ 总结
-
决策树是一个直观、易解释的模型,在遥感分类中可作为 入门基线。
-
它的主要超参数包括:
criterion(划分标准):gini或entropy;max_depth(最大深度):防止过拟合的重要参数;min_samples_split(最小分裂样本数):控制节点继续划分的条件。
-
实际应用中,单棵树的泛化能力有限,更推荐使用 集成方法(随机森林、梯度提升树)。
欢迎大家关注下方公众号获取更多内容!
1163

被折叠的 条评论
为什么被折叠?



