系列回顾:
(一)CNN 基础:高光谱图像分类可视化全流程,链接:https://mp.weixin.qq.com/s/P4IOG0WTDuoBEprfGSWMTQ
(二)HybridNet(CNN+Transformer):提升全局感受野,链接:https://mp.weixin.qq.com/s/Zlev4Z0b3VE7a6jOOpzaAA
(三)GCN 入门实战:基于光谱 KNN 的图卷积分类与全图预测,链接:https://mp.weixin.qq.com/s/Vo5QNA7gkqbYhYg10krbnQ
(四)空间–光谱联合构图的 GCN:RBF 边权 + 自环 + 早停,得到更稳更自然的全图分类结果,链接:https://mp.weixin.qq.com/s/G7VnMzhby4Fvmjwh_R_z7A
本篇(五):在(四)的基础上,加入 GAT(注意力图网络)与构图消融,并实现分块全图预测防 OOM,确保在显存有限的设备上也能稳定跑完全图。
一、这篇要解决什么问题?
-
构图到底该怎么选?
只用光谱?只用空间?还是融合?我们做一键消融:pure_spectral / pure_spatial / fused。 -
GAT vs. GCN?
GCN需要显式边权(连续相似度);GAT让模型自己学习“看谁更重要”。本篇支持一键切换:MODEL='GCN' | 'GAT'。 -
全图预测容易爆显存?
一次性为整幅图构 KNN 图很吃内存。本篇提供分块局部构图 + 内核拼接,用BLOCK_SIZE / OVERLAP控制,显著降低内存占用。
二、这篇都做了哪些事情?
- 构图消融:光谱KNN、空间KNN、空间–光谱融合(RBF边权 + 对称化 + 无向图 + 自环)
- 模型双模:GCN(显式传 edge_weight) / GAT(连通性图即可)
- 早停 + 固定种子:可复现且节省时间
- 分块全图预测:局部构图与推理,仅写回“内核区域”,减少边界效应与显存压力
- 可视化:混淆矩阵、各类别准确率条形图、全图分类图
三、方法解释
3.1 构图(含消融)
- 光谱:在 PCA 特征空间做 KNN,距离 → 中位数归一 → RBF 相似
- 空间:在像素坐标平面做 KNN,距离 → 中位数归一 → RBF 相似
- 融合:
S = α·S_spec + (1-α)·S_spat,然后对称化并转无向图,最后加自环 - GCN:使用连续边权;GAT:使用连通性图(不需要权重)
3.2 分块全图预测
- 将图像划成重叠块(
BLOCK_SIZE,OVERLAP),对每个块单独构图并推理 - 仅把去掉 overlap 的内核区域写回全图,保证拼接平滑
- 极端未覆盖像素兜底为 0 类(可改成最近邻填补)
3.3 结果展示


四、完整可运行代码(含详细注释)
直接复制运行。按需修改
DATA_DIR / X_FILE / Y_FILE。
默认数据:KSC。依赖:torch,torch_geometric,scikit-learn,matplotlib,seaborn,scipy.
# -*- coding: utf-8 -*-
"""
Series (V): GAT & Graph Construction Ablation for HSI Classification (PyTorch Geometric)
- 解决内存溢出问题(全图分块处理:局部构图 + 内核拼接)
- 邻接矩阵强制对称化、无向化;GCN用连续RBF边权,GAT用连通性
- 完善早停逻辑与复现性
- 构图消融:pure_spectral / pure_spatial / fused
- 可视化:混淆矩阵、各类别准确率、全图预测
"""
import os
import numpy as np
import scipy.io
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.utils import add_self_loops, to_undirected
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.neighbors import kneighbors_graph
from sklearn.model_selection import train_test_split
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
# ----------------- 全局样式 -----------------
matplotlib.rcParams['font.family'] = 'SimHei'
matplotlib.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 120
sns.set_theme(context="notebook", style="whitegrid", font="SimHei")
# ----------------- 可配置超参(可一键消融&切换) -----------------
DATA_DIR = r"your_path"
X_FILE = "KSC.mat"
Y_FILE = "KSC_gt.mat"
# PCA
PCA_DIM = 15 # 若设为0,则自动选择解释方差≥PCA_VAR_THRESHOLD的维度
PCA_VAR_THRESHOLD = 0.95
# 构图消融:'pure_spectral' | 'pure_spatial' | 'fused'
GRAPH_MODE = 'fused'
# 模型:'GCN' | 'GAT'
MODEL = 'GAT'
# KNN与RBF
K_SPEC = 6
K_SPAT = 6
ALPHA = 0.7
SIGMA_SPEC = 1.0
SIGMA_SPAT = 1.0
SELF_LOOP_VALUE = 1.0
# 训练
HIDDEN = 64
DROPOUT = 0.5
LR = 0.01
WD = 5e-4
MAX_EPOCHS = 200
PATIENCE = 15
TRAIN_RATIO = 0.3
SEED = 42
# GAT参数
GAT_HEADS = 4
GAT_CONCAT = True
# 分块全图预测
BLOCK_SIZE = 128
OVERLAP = 32
# ----------------- 复现性 -----------------
def set_seeds(seed=42):
import random
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
set_seeds(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"当前设备:{
device}")
# =========================
# 数据加载与预处理
# =========================
def load_data(x_path, y_path):
"""加载高光谱数据和标签,自动识别键名"""
x_data = scipy.io.loadmat(x_path)
y_data = scipy.io.loadmat(y_path)
x_key = [k for k in x_data.keys() if not k.startswith('__')][0]
y_key = [k for k in y_data.keys() if not k.startswith('__')][0]
return x_data[x_key], y_data[y_key]
print("加载数据...")
X_image, y_image = load_data(os.path.join(DATA_DIR, X_FILE), os.path.join(DATA_DIR, Y_FILE))
h, w, bands = X_image.shape
print(f"图像尺寸: {
h} x {
w}, 波段数: {
bands}")
# 展平 & 标签处理
X_flat = X_image.reshape(-1, bands)
y_flat = y_image.reshape(-1).astype(np.int32)
mask_labeled = y_flat != 0
X_labeled = X_flat[mask_labeled]
y_labeled = (y_flat[mask_labeled] - 1).astype(np.int64)
num_classes = int(len(np.unique(y_labeled)))
print(f"有效样本: {
len(y_labeled)},类别数: {
num_classes}")
# 空间坐标(全图&标注)
rows_all = np.repeat(np.arange(h), w)
cols_all = np.tile(np.arange(w), h)
coords_all = np.stack([rows_all, cols_all], axis=1).astype(np.float32)
coords_labeled = coords_all[mask_labeled]
# 标准化 + PCA(仅在标注样本上fit;全图transform)

最低0.47元/天 解锁文章
1277

被折叠的 条评论
为什么被折叠?



