遥感&机器学习入门实战教程|Sklearn案例⑰:决策树与遥感分类(tree 模块)

部署运行你感兴趣的模型镜像

在前面的文章中,我们介绍了 SVM、随机森林、神经网络等模型。
今天要回到一个经典而直观的模型:决策树(Decision Tree)

🧩 决策树的基本思想

决策树是一种 树形结构的分类与回归方法

  • 每个 节点 表示一个特征条件(例如“第5个波段 > 0.23”)。
  • 每个 分支 表示决策结果(是/否)。
  • 直到叶子节点,给出分类结果(类别ID)。

优点:

  • 结构简单,结果易解释;
  • 训练速度快,能处理非线性关系;
  • 不需要大量预处理。

缺点:

  • 单棵树容易过拟合,对噪声敏感;
  • 一般在实际应用中更常作为 集成方法(随机森林、梯度提升) 的基础模型。

💻 代码示例:KSC 数据 + 决策树分类

# -*- coding: utf-8 -*-
"""
Sklearn案例⑰:决策树分类(KSC 真实数据)
- 无泄露预处理:仅在训练像素上 fit 标准化与 PCA
- 评估:OA/Kappa/分类报告 + 数字标注混淆矩阵
- 整图预测:逐像素分类可视化(离散色标)
- 决策树结构:仅展示前1层,放大字体、按比例显示
"""

import os
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
import matplotlib
from matplotlib.colors import ListedColormap, BoundaryNorm

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import (confusion_matrix, classification_report,
                             accuracy_score, cohen_kappa_score)

# ===== 中文显示 =====
matplotlib.rcParams['font.family'] = 'SimHei'
matplotlib.rcParams['axes.unicode_minus'] = False

# ===== 参数区(仅需修改 DATA_DIR)=====
DATA_DIR   = r"your_path"   # ←← 修改为你的 KSC 数据路径(包含 KSC.mat / KSC_gt.mat)
PCA_DIM    = 30
TRAIN_SIZE = 0.3
SEED       = 42

# ===== 1) 读取 KSC 数据(仅取有标签像素)=====
X_cube = sio.loadmat(os.path.join(DATA_DIR, "KSC.mat"))["KSC"].astype(np.float32)   # (H,W,B)
Y_map  = sio.loadmat(os.path.join(DATA_DIR, "KSC_gt.mat"))["KSC_gt"].astype(int)     # (H,W)
h, w, b = X_cube.shape

coords = np.argwhere(Y_map != 0)              # 有标签像素坐标
X_all  = X_cube[coords[:, 0], coords[:, 1]]   # (N,B)
y_all  = Y_map[coords[:, 0], coords[:, 1]] - 1
num_classes = int(y_all.max() + 1)
print(f"[INFO] 有标签像素: {len(y_all)}, 类别数: {num_classes}")

# ===== 2) 分层划分(仅在有标签像素上)=====
X_tr, X_te, y_tr, y_te = train_test_split(
    X_all, y_all,
    train_size=TRAIN_SIZE,
    stratify=y_all,
    random_state=SEED
)

# ===== 3) 无泄露预处理:仅用训练像素拟合 =====
scaler = StandardScaler().fit(X_tr)
pca    = PCA(n_components=PCA_DIM, random_state=SEED).fit(scaler.transform(X_tr))

X_train = pca.transform(scaler.transform(X_tr))
X_test  = pca.transform(scaler.transform(X_te))

# ===== 4) 决策树模型 =====
clf = DecisionTreeClassifier(
    criterion="gini",       # 或 "entropy"
    max_depth=20,           # 适度限制深度防过拟合
    min_samples_split=5,
    random_state=SEED
)
clf.fit(X_train, y_tr)
y_pred = clf.predict(X_test)

# ===== 5) 评估与报告 =====
oa    = accuracy_score(y_te, y_pred)
kappa = cohen_kappa_score(y_te, y_pred)
cm    = confusion_matrix(y_te, y_pred, labels=np.arange(num_classes))
cmn   = cm / np.maximum(cm.sum(axis=1, keepdims=True), 1)

print(f"OA = {oa*100:.2f}%   Kappa = {kappa:.4f}")
print(classification_report(y_te, y_pred, digits=4, zero_division=0))

# ===== 6) 混淆矩阵(计数 & 归一化,含数字)=====
fig, axes = plt.subplots(1, 2, figsize=(11, 4.2), constrained_layout=True)

im0 = axes[0].imshow(cm, cmap=plt.cm.YlGnBu, interpolation='nearest', alpha=0.9)
axes[0].set_title("混淆矩阵(计数)")
axes[0].set_xlabel("预测"); axes[0].set_ylabel("真实")
for i in range(num_classes):
    for j in range(num_classes):
        axes[0].text(j, i, str(cm[i, j]),
                     ha='center', va='center',
                     color='black' if cm[i, j] < cm.max()/2 else 'white', fontsize=9)
fig.colorbar(im0, ax=axes[0], fraction=0.046, pad=0.04)

im1 = axes[1].imshow(cmn, vmin=0, vmax=1, cmap=plt.cm.OrRd, interpolation='nearest', alpha=0.9)
axes[1].set_title("混淆矩阵(归一化)")
axes[1].set_xlabel("预测"); axes[1].set_ylabel("真实")
for i in range(num_classes):
    for j in range(num_classes):
        axes[1].text(j, i, f"{cmn[i, j]*100:.1f}%",
                     ha='center', va='center', color='black', fontsize=9)
fig.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04)

plt.show()

# ===== 7) 整图像素级预测(逐像素分类)=====
X_flat      = X_cube.reshape(-1, b)
X_flat_pca  = pca.transform(scaler.transform(X_flat))
pred_map    = clf.predict(X_flat_pca).reshape(h, w) + 1   # 转 1..C 便于显示

# 离散色表(1..C),背景本步骤不单独处理
base_cmap   = plt.get_cmap('tab20')
colors      = [base_cmap(i % 20) for i in range(num_classes)]
cmap        = ListedColormap(colors)
boundaries  = np.arange(0.5, num_classes + 1.5, 1)
norm        = BoundaryNorm(boundaries, cmap.N)

plt.figure(figsize=(8.6, 6.4))
im = plt.imshow(pred_map, cmap=cmap, norm=norm, interpolation='nearest')
plt.title("KSC 整图预测结果(决策树)")
plt.axis('off')
cbar = plt.colorbar(im, boundaries=boundaries,
                    ticks=np.arange(1, num_classes + 1, max(1, num_classes // 12)),
                    fraction=0.046, pad=0.04)
cbar.set_label("类别ID", rotation=90)
plt.show()

# ===== 8) 决策树结构可视化(前两层,放大字体、按比例显示)=====
plt.figure(figsize=(16, 8))
plot_tree(
    clf,
    filled=True,
    feature_names=[f"PC{i+1}" for i in range(PCA_DIM)],
    class_names=[str(i+1) for i in range(num_classes)],
    max_depth=1,        # 仅展示前两层
    fontsize=10,        # 放大字体
    proportion=True     # 节点以类别比例显示,信息更简洁
)
plt.title("决策树前1层结构示意(按比例显示)", fontsize=14)
plt.show()

🔍 结果解读

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

  1. 性能指标

    • 控制台输出:OA、Kappa、精确率、召回率和 F1。
    • 混淆矩阵:可以直观看出哪些类别容易混淆。
  2. 整图预测

    • 每个像素都会被分配一个类别,直观呈现分类效果。
    • 由于单棵树容易过拟合,分类图可能存在块状噪声。
  3. 树结构可视化

    • 决策树可以直接画出来(示例只画前两层),展示分类规则,非常适合教学与解释。

✅ 总结

  • 决策树是一个直观、易解释的模型,在遥感分类中可作为 入门基线

  • 它的主要超参数包括:

    • criterion(划分标准):ginientropy
    • max_depth(最大深度):防止过拟合的重要参数;
    • min_samples_split(最小分裂样本数):控制节点继续划分的条件。
  • 实际应用中,单棵树的泛化能力有限,更推荐使用 集成方法(随机森林、梯度提升树)

欢迎大家关注下方公众号获取更多内容!

您可能感兴趣的与本文相关的镜像

Python3.9

Python3.9

Conda
Python

Python 是一种高级、解释型、通用的编程语言,以其简洁易读的语法而闻名,适用于广泛的应用,包括Web开发、数据分析、人工智能和自动化脚本

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

遥感AI实战

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值