深度学习与遥感入门(五)|GAT & 构图消融 + 分块全图预测:更稳更快的高光谱图分类(PyTorch Geometric 实战)

系列回顾:

(一)CNN 基础:高光谱图像分类可视化全流程,链接:https://mp.weixin.qq.com/s/P4IOG0WTDuoBEprfGSWMTQ
(二)HybridNet(CNN+Transformer):提升全局感受野,链接:https://mp.weixin.qq.com/s/Zlev4Z0b3VE7a6jOOpzaAA
(三)GCN 入门实战:基于光谱 KNN 的图卷积分类与全图预测,链接:https://mp.weixin.qq.com/s/Vo5QNA7gkqbYhYg10krbnQ
(四)空间–光谱联合构图的 GCN:RBF 边权 + 自环 + 早停,得到更稳更自然的全图分类结果,链接:https://mp.weixin.qq.com/s/G7VnMzhby4Fvmjwh_R_z7A
本篇(五):在(四)的基础上,加入 GAT(注意力图网络)与构图消融,并实现分块全图预测防 OOM,确保在显存有限的设备上也能稳定跑完全图。

一、这篇要解决什么问题?

  1. 构图到底该怎么选?
    只用光谱?只用空间?还是融合?我们做一键消融pure_spectral / pure_spatial / fused

  2. GAT vs. GCN?
    GCN需要显式边权(连续相似度);GAT让模型自己学习“看谁更重要”。本篇支持一键切换MODEL='GCN' | 'GAT'

  3. 全图预测容易爆显存?
    一次性为整幅图构 KNN 图很吃内存。本篇提供分块局部构图 + 内核拼接,用 BLOCK_SIZE / OVERLAP 控制,显著降低内存占用

二、这篇都做了哪些事情?

  • 构图消融:光谱KNN、空间KNN、空间–光谱融合(RBF边权 + 对称化 + 无向图 + 自环)
  • 模型双模:GCN(显式传 edge_weight) / GAT(连通性图即可)
  • 早停 + 固定种子:可复现且节省时间
  • 分块全图预测:局部构图与推理,仅写回“内核区域”,减少边界效应与显存压力
  • 可视化:混淆矩阵、各类别准确率条形图、全图分类图

三、方法解释

3.1 构图(含消融)

  • 光谱:在 PCA 特征空间做 KNN,距离 → 中位数归一 → RBF 相似
  • 空间:在像素坐标平面做 KNN,距离 → 中位数归一 → RBF 相似
  • 融合:S = α·S_spec + (1-α)·S_spat,然后对称化并转无向图,最后加自环
  • GCN:使用连续边权;GAT:使用连通性图(不需要权重)

3.2 分块全图预测

  • 将图像划成重叠块(BLOCK_SIZEOVERLAP),对每个块单独构图并推理
  • 仅把去掉 overlap 的内核区域写回全图,保证拼接平滑
  • 极端未覆盖像素兜底为 0 类(可改成最近邻填补)

3.3 结果展示

在这里插入图片描述
在这里插入图片描述

四、完整可运行代码(含详细注释)

直接复制运行。按需修改 DATA_DIR / X_FILE / Y_FILE
默认数据:KSC。依赖:torch, torch_geometric, scikit-learn, matplotlib, seaborn, scipy.

# -*- coding: utf-8 -*-
"""
Series (V): GAT & Graph Construction Ablation for HSI Classification (PyTorch Geometric)
- 解决内存溢出问题(全图分块处理:局部构图 + 内核拼接)
- 邻接矩阵强制对称化、无向化;GCN用连续RBF边权,GAT用连通性
- 完善早停逻辑与复现性
- 构图消融:pure_spectral / pure_spatial / fused
- 可视化:混淆矩阵、各类别准确率、全图预测
"""

import os
import numpy as np
import scipy.io
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.utils import add_self_loops, to_undirected

from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.neighbors import kneighbors_graph
from sklearn.model_selection import train_test_split

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

# ----------------- 全局样式 -----------------
matplotlib.rcParams['font.family'] = 'SimHei'
matplotlib.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 120
sns.set_theme(context="notebook", style="whitegrid", font="SimHei")

# ----------------- 可配置超参(可一键消融&切换) -----------------
DATA_DIR = r"your_path"
X_FILE = "KSC.mat"
Y_FILE = "KSC_gt.mat"

# PCA
PCA_DIM = 15                 # 若设为0,则自动选择解释方差≥PCA_VAR_THRESHOLD的维度
PCA_VAR_THRESHOLD = 0.95

# 构图消融:'pure_spectral' | 'pure_spatial' | 'fused'
GRAPH_MODE = 'fused'

# 模型:'GCN' | 'GAT'
MODEL = 'GAT'

# KNN与RBF
K_SPEC = 6
K_SPAT = 6
ALPHA   = 0.7
SIGMA_SPEC = 1.0
SIGMA_SPAT = 1.0
SELF_LOOP_VALUE = 1.0

# 训练
HIDDEN = 64
DROPOUT = 0.5
LR = 0.01
WD = 5e-4
MAX_EPOCHS = 200
PATIENCE = 15
TRAIN_RATIO = 0.3
SEED = 42

# GAT参数
GAT_HEADS = 4
GAT_CONCAT = True

# 分块全图预测
BLOCK_SIZE = 128
OVERLAP = 32

# ----------------- 复现性 -----------------
def set_seeds(seed=42):
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seeds(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"当前设备:{
     
     device}")

# =========================
# 数据加载与预处理
# =========================
def load_data(x_path, y_path):
    """加载高光谱数据和标签,自动识别键名"""
    x_data = scipy.io.loadmat(x_path)
    y_data = scipy.io.loadmat(y_path)
    x_key = [k for k in x_data.keys() if not k.startswith('__')][0]
    y_key = [k for k in y_data.keys() if not k.startswith('__')][0]
    return x_data[x_key], y_data[y_key]

print("加载数据...")
X_image, y_image = load_data(os.path.join(DATA_DIR, X_FILE), os.path.join(DATA_DIR, Y_FILE))
h, w, bands = X_image.shape
print(f"图像尺寸: {
     
     h} x {
     
     w}, 波段数: {
     
     bands}")

# 展平 & 标签处理
X_flat = X_image.reshape(-1, bands)
y_flat = y_image.reshape(-1).astype(np.int32)
mask_labeled = y_flat != 0
X_labeled = X_flat[mask_labeled]
y_labeled = (y_flat[mask_labeled] - 1).astype(np.int64)
num_classes = int(len(np.unique(y_labeled)))
print(f"有效样本: {
     
     len(y_labeled)},类别数: {
     
     num_classes}")

# 空间坐标(全图&标注)
rows_all = np.repeat(np.arange(h), w)
cols_all = np.tile(np.arange(w), h)
coords_all = np.stack([rows_all, cols_all], axis=1).astype(np.float32)
coords_labeled = coords_all[mask_labeled]

# 标准化 + PCA(仅在标注样本上fit;全图transform)
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

遥感AI实战

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值