Fixing the PyG error: from torch_geometric.nn.pool.topk_pool import topk, filter_adj

This article describes an import error for the topk_pool module encountered while building a graph neural network with PyTorch's PyG library. The cause is an API change introduced in newer versions. It shows how to resolve the problem by replacing the imports and writing the missing functions by hand, along with an analysis of the relevant source code.

Problem:

Building a graph neural network with PyTorch's PyG raises the following error:

can not import topk, filter_adj from torch_geometric.nn.pool.topk_pool
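
A quick check to confirm this on your installation (a minimal, illustrative snippet; the import line is exactly the one from the error):

try:
    # Old function-based API; newer PyG releases no longer export these names.
    from torch_geometric.nn.pool.topk_pool import topk, filter_adj
    print("old topk / filter_adj are still available")
except ImportError as exc:
    print("import failed:", exc)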

Solution

This is a version issue: the old function-based API was replaced with classes in newer PyG releases.
topk => SelectTopK
filter_adj => FilterEdges

from torch_geometric.nn.pool.connect import FilterEdges
from torch_geometric.nn.pool.select import SelectTopK
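
For reference, the new class-based API is wired together roughly as below. This is only a sketch to show why a plain import swap is not enough (the new names are nn.Module classes rather than drop-in functions); the call signatures and attribute names follow my reading of the PyG 2.3+ source, so double-check them against your installed version:

import torch
from torch_geometric.nn.pool.connect import FilterEdges
from torch_geometric.nn.pool.select import SelectTopK

x = torch.randn(6, 16)                          # 6 nodes, 16 features each
edge_index = torch.tensor([[0, 1, 2, 3, 4],
                           [1, 2, 3, 4, 5]])
batch = torch.zeros(6, dtype=torch.long)        # all nodes belong to one graph

select = SelectTopK(in_channels=16, ratio=0.5)  # learnable scoring + top-k node selection
connect = FilterEdges()                         # keeps only edges between surviving nodes

s = select(x, batch)                            # SelectOutput; s.node_index lists the kept nodes
c = connect(s, edge_index, batch=batch)         # ConnectOutput with filtered edge_index / batch
x_pooled = x[s.node_index]                      # gather features of the kept nodes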

However, simply swapping the imports did not fix the existing code, since the new classes are not drop-in replacements for the old functions.
So I looked into the SelectTopK / FilterEdges source code and found that the topk and filter_adj helpers are still defined there, but importing them directly did not work either.
Writing the two functions out by hand in layers.py makes the model runnable again:

from typing import Optional, Tuple, Union

import torch
from torch import Tensor
from torch_geometric.utils import cumsum, scatter
from torch_geometric.utils.num_nodes import maybe_num_nodes


def topk(
        x: Tensor,
        ratio: Optional[Union[float, int]],
        batch: Tensor,
        min_score: Optional[float] = None,
        tol: float = 1e-7,
) -> Tensor:
    if min_score is not None:
        # Make sure that we do not drop all nodes in a graph.
        scores_max = scatter(x, batch, reduce='max')[batch] - tol
        scores_min = scores_max.clamp(max=min_score)

        perm = (x > scores_min).nonzero().view(-1)
        return perm

    if ratio is not None:
        num_nodes = scatter(batch.new_ones(x.size(0)), batch, reduce='sum')

        if ratio >= 1:
            k = num_nodes.new_full((num_nodes.size(0),), int(ratio))
        else:
            k = (float(ratio) * num_nodes.to(x.dtype)).ceil().to(torch.long)

        x, x_perm = torch.sort(x.view(-1), descending=True)
        batch = batch[x_perm]
        batch, batch_perm = torch.sort(batch, descending=False, stable=True)

        arange = torch.arange(x.size(0), dtype=torch.long, device=x.device)
        ptr = cumsum(num_nodes)
        batched_arange = arange - ptr[batch]
        mask = batched_arange < k[batch]

        return x_perm[batch_perm[mask]]

    # Mirror the upstream implementation: one of the two arguments is required.
    raise ValueError("At least one of the 'ratio' and 'min_score' parameters "
                     "must be specified")


def filter_adj(
        edge_index: Tensor,
        edge_attr: Optional[Tensor],
        node_index: Tensor,
        cluster_index: Optional[Tensor] = None,
        num_nodes: Optional[int] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    if cluster_index is None:
        cluster_index = torch.arange(node_index.size(0),
                                     device=node_index.device)

    mask = node_index.new_full((num_nodes,), -1)
    mask[node_index] = cluster_index

    row, col = edge_index[0], edge_index[1]
    row, col = mask[row], mask[col]
    mask = (row >= 0) & (col >= 0)
    row, col = row[mask], col[mask]

    if edge_attr is not None:
        edge_attr = edge_attr[mask]

    return torch.stack([row, col], dim=0), edge_attr
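
With the two helpers in place, layers.py can call them in a TopKPooling-style forward pass. The sketch below is illustrative only and assumes it lives in the same layers.py; score_lin (a linear layer producing one score per node) and the ratio value are placeholder names of mine, not part of the original code:

import torch
from torch import Tensor

def topk_pool_forward(x: Tensor, edge_index: Tensor, batch: Tensor,
                      score_lin: torch.nn.Linear, ratio: float = 0.5):
    # Score every node, keep the top `ratio` fraction per graph, and drop
    # all edges that no longer connect two surviving nodes.
    score = score_lin(x).view(-1)                      # one scalar score per node
    perm = topk(score, ratio, batch)                   # indices of the kept nodes
    x = x[perm] * torch.tanh(score[perm]).view(-1, 1)  # gate features by their score
    batch = batch[perm]
    edge_index, _ = filter_adj(edge_index, None, perm, num_nodes=score.size(0))
    return x, edge_index, batch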

Reference: the official documentation

https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/pool/topk_pool.html
