深度学习与遥感入门(四)|空间–光谱联合构图的 GCN:更稳更准的高光谱分类与全图预测(PyTorch Geometric)

系列回顾:

(一)CNN 基础:高光谱图像分类可视化全流程,链接:https://mp.weixin.qq.com/s/P4IOG0WTDuoBEprfGSWMTQ

(二)HybridNet(CNN+Transformer):提升全局感受野,链接:https://mp.weixin.qq.com/s/Zlev4Z0b3VE7a6jOOpzaAA

(三)GCN 入门实战:基于光谱 KNN 的图卷积分类与全图预测,链接:https://mp.weixin.qq.com/s/Vo5QNA7gkqbYhYg10krbnQ

本篇(四):在第(三)篇的基础上,升级构图与训练策略 —— 融合空间+光谱信息,给边赋“强弱”,并加入早停与可复现实验配置,得到更稳更自然的全图分类结果。

常用数据:https://mp.weixin.qq.com/s/IJRh3HZWTVpJ4v322YmOuA

一、这篇要解决什么问题?

上一版(系列三)的 GCN 只用光谱 KNN构图,边是无权的(0/1)。在复杂地物边界或光谱相近却空间分散的场景,会出现:

  • 邻接关系不够稳定(只靠光谱相似);
  • 信息传播“强弱”一视同仁(无权边),细节不足;
  • 训练需要更多轮次,泛化不够稳。

这篇改进

  1. 空间–光谱联合构图:光谱 KNN(PCA 特征)+ 空间 KNN(像素坐标);
  2. RBF 边权重 + 自环:让相似度成为连续权重,并为每个节点加自环稳定训练;
  3. 早停(Early Stopping)+ 固定随机种子:更可复现、更省时间;
  4. 集中可配置超参:一处改动、全局生效,便于消融与复现实验。

注:既然是改进,那么是否比之前的方法精度要高呢?大家可以对比尝试一下!

二、整体流程概览

  1. 读取 KSC 数据(KSC.mat / KSC_gt.mat
  2. 仅用有标签像素进行标准化 + PCA 降维(与全图 transform 一致)
  3. 构训练图:基于有标签像素,做光谱 KNN 与空间 KNN,距离归一后用 RBF 映射成相似度,按 α 融合,加自环,得到 edge_index + edge_weight
  4. 训练 2 层 GCN(显式传入 edge_weight),并使用早停
  5. 全图预测:对全图做同样构图(联合 + 边权 + 自环)与推理
  6. 仅保留三种可视化(模型结果):混淆矩阵、各类准确率条形图、全图预测分类图

三、关键设计与代码讲解

3.1 空间–光谱联合构图(本篇核心)

  • 光谱相似:在 PCA 特征空间中做 KNN,得到距离矩阵 DspecD_\text{spec}Dspec

  • 空间相似:在像素坐标平面(行、列)做 KNN,得到距离矩阵 DspatD_\text{spat}Dspat

  • 归一化:为避免尺度不一致,按全局中位数将距离做归一化

  • RBF 相似度

    S=exp⁡(−d22σ2) S = \exp\Big(-\frac{d^2}{2\sigma^2}\Big) S=exp(2σ2d2)

  • 融合

    Sfused=α⋅Sspec+(1−α)⋅Sspat S_\text{fused} = \alpha \cdot S_\text{spec} + (1 - \alpha) \cdot S_\text{spat} Sfused=αSspec+(1α)Sspat

  • 自环:为每个节点加入自环(强度可设),稳定训练

直观理解:光谱近空间近都很重要,用 α\alphaα 来调两者权重;相似度越高,边越“粗”。

3.2 带权图卷积(GCNConv + edge_weight)

  • 关闭 GCNConv 默认自环,由我们手动加入(数值可控);
  • 前向传播时显式传 edge_weight,让卷积对强/弱邻居区别对待

3.3 早停与复现

  • 固定随机种子(numpy / torch / cuda);
  • 若测试集准确率在 PATIENCE 轮内无提升,则停止训练并回滚至最佳权重

四、一键可运行完整代码(含详细注释)

直接复制粘贴运行即可。根据你的路径调整 DATA_DIR
仅保留模型结果可视化:混淆矩阵、各类准确率、全图预测

1. 完整代码(配详细注释)+ 分段讲解

1.1 导入依赖与全局样式

# -*- coding: utf-8 -*-
"""
GCN for Hyperspectral Image Classification — Series (IV)
改进点:
- 空间–光谱联合构图(PCA特征KNN + 空间坐标KNN)
- RBF 边权重 + 自环;传入 GCNConv(edge_weight)
- 早停;可配置超参;固定随机种子
- 可视化(仅模型结果):混淆矩阵、类准确率、全图预测
"""

import os
import numpy as np
import scipy.io
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch_geometric.utils import add_self_loops
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.neighbors import kneighbors_graph

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

# ----------------- 全局样式 -----------------
matplotlib.rcParams['font.family'] = 'SimHei'      # 中文字体
matplotlib.rcParams['axes.unicode_minus'] = False  # 正常显示负号
plt.rcParams['figure.dpi'] = 120
sns.set_theme(context="notebook", style="whitegrid", font="SimHei")

解释

  • torch_geometric.* 是 PyG 的图数据结构与图层;
  • kneighbors_graph 构 KNN 邻接(支持输出距离/连通性);
  • seaborn+matplotlib 做美观可视化;
  • 全局样式设置确保中文与图形风格统一。

1.2 可配置超参(集中管理,便于复现与消融)

# ----------------- 可配置超参 -----------------
DATA_DIR = r"your_path"
X_FILE = "KSC.mat"
Y_FILE = "KSC_gt.mat"

PCA_DIM = 10        # PCA维度
K_SPEC = 8          # 光谱KNN邻居数(在PCA特征空间)
K_SPAT = 8          # 空间KNN邻居数(在像素坐标空间)
ALPHA = 0.7         # 融合权重:alpha*谱相似 + (1-alpha)*空间相似
SIGMA_SPEC = 1.0    # RBF:光谱距离尺度
SIGMA_SPAT = 1.0    # RBF:空间距离尺度
SELF_LOOP_VALUE = 1.0
HIDDEN = 64         # 隐层维度
DROPOUT = 0.5
LR = 0.01
WD = 5e-4
EPOCHS = 200
PATIENCE = 20       # 早停耐心
TRAIN_RATIO = 0.3
SEED = 42

解释

  • ALPHA 控制光谱/空间融合比例;
  • SIGMA_* 控制 RBF 相似度“挑剔”程度;
  • SELF_LOOP_VALUE 为每个节点的自环权重;
  • 训练相关超参集中,方便调整。

1.3 固定随机种子 + 选择设备

# ----------------- 复现性 -----------------
def set_seeds(seed=42):
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seeds(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"当前设备:{
     
     device}")

解释

  • 多库统一种子,保证每次运行结果一致;
  • 自动选择 GPU/CPU。

1.4 数据加载与基本信息

# =========================
# 数据加载
# =========================
def load_data(x_path, y_path):
    x_data = scipy.io.loadmat(x_path)
    y_data = scipy.io.loadmat(y_path)
    # .mat 里常有多个键,这里过滤掉系统键
    x_key = [k for k in x_data.keys() if not k.startswith('__')][0]
    y_key = [k for k in y_data.keys() if not k.startswith('__')][0]
    return x_data[x_key], y_data[y_key]

print("加载数据...")
X_image, y_image = load_data(os.path.join(DATA_DIR, X_FILE), os.path.join(DATA_DIR, Y_FILE))
h, w, bands = X_image.shape
print(f"图像尺寸: {
     
     h} x {
     
     w}, 波段数: {
     
     bands}")

解释

  • 读取 .mat,提取真实数组;
  • 获取高光谱图像尺寸 (H, W, Bands)

1.5 展平 + 标签预处理 + 坐标准备 + 标准化/降维

# 展平 & 标签 dtype
X_flat = X_image.reshape(-1, bands)
y_flat = y_image.reshape(-1).astype(np.int32)  # 转有符号整型,避免后续 -1 溢出

# 标注子集(仅用于训练/验证)
mask_labeled = y_flat != 0
X_labeled = X_flat[mask_labeled]
y_labeled = (y_flat[mask_labeled] - 1).astype(np.int64)   # 类别从0开始
num_classes = len(np.unique(y_labeled))
print(f"有效样本: {
     
     len(y_labeled)},类别数: {
     
     num_classes}")

# 空间坐标(全图 & 标注子集)
rows = np.repeat(np.arange(h), w)
cols = np.tile(np.arange(w), h)
coords_all = np.stack([rows, cols], axis=1).astype(np.float32)
coords_labeled = coords_all[mask_labeled]

# 标准化 + PCA(仅在标注样本上fit;全图用transform)
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_labeled)
pca = PCA(n_components=PCA_DIM, random_state=SEED)
X_pca = pca.fit_transform(X_scaled)

解释

  • (H, W, Bands) 拉平成 (N, Bands)
  • 标签 0 表示未标注,过滤后并让类别从 0 开始;
  • 为联合构图准备像素的二维坐标
  • 只在标注子集上 fit 标准化与 PCA,保证评估公平;全图只 transform

1.6 构图关键函数:距离归一化、RBF 相似、联合融合

# =========================
# 构图:光谱KNN + 空间KNN -> 融合(边权)
# =========================
def _normalize_dist(d):
    """将KNN返回的距离按全局中位数归一,避免尺度不一致"""
    data = d.da
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

遥感AI实战

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值