解决PyG 报错 from torch_geometric.nn.pool.topk_pool import topk, filter_adj

文章讲述了在使用PyTorch的PyG库构建图神经网络时遇到关于topk_pool模块导入错误的问题,原因在于版本更新导致的语法变化。作者提供了通过替换和自定义函数解决此问题的方法,以及相关源码分析。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

问题:

使用Pytorch 的 PyG 搭建 图神经网络 报错

can not import topk, filter_adj from torch_geometric.nn.pool.topk_pool 

解决

版本问题 语法变化
topk => SelectTopk
filter_adj => FilterEdges

from torch_geometric.nn.pool.connect import FilterEdges
from torch_geometric.nn.pool.select import SelectTopK

发现替换后不可以
于是进去看SelectTopK\FilterEdges 源码
发现里面有 topk, filter_adj 方法 但是直接 import 也不能用
于是手动写函数出来再 layers.py 里即可运行

def topk(
        x: Tensor,
        ratio: Optional[Union[float, int]],
        batch: Tensor,
        min_score: Optional[float] = None,
        tol: float = 1e-7,
) -> Tensor:
    if min_score is not None:
        # Make sure that we do not drop all nodes in a graph.
        scores_max = scatter(x, batch, reduce='max')[batch] - tol
        scores_min = scores_max.clamp(max=min_score)

        perm = (x > scores_min).nonzero().view(-1)
        return perm

    if ratio is not None:
        num_nodes = scatter(batch.new_ones(x.size(0)), batch, reduce='sum')

        if ratio >= 1:
            k = num_nodes.new_full((num_nodes.size(0),), int(ratio))
        else:
            k = (float(ratio) * num_nodes.to(x.dtype)).ceil().to(torch.long)

        x, x_perm = torch.sort(x.view(-1), descending=True)
        batch = batch[x_perm]
        batch, batch_perm = torch.sort(batch, descending=False, stable=True)

        arange = torch.arange(x.size(0), dtype=torch.long, device=x.device)
        ptr = cumsum(num_nodes)
        batched_arange = arange - ptr[batch]
        mask = batched_arange < k[batch]

        return x_perm[batch_perm[mask]]


def filter_adj(
        edge_index: Tensor,
        edge_attr: Optional[Tensor],
        node_index: Tensor,
        cluster_index: Optional[Tensor] = None,
        num_nodes: Optional[int] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    if cluster_index is None:
        cluster_index = torch.arange(node_index.size(0),
                                     device=node_index.device)

    mask = node_index.new_full((num_nodes,), -1)
    mask[node_index] = cluster_index

    row, col = edge_index[0], edge_index[1]
    row, col = mask[row], mask[col]
    mask = (row >= 0) & (col >= 0)
    row, col = row[mask], col[mask]

    if edge_attr is not None:
        edge_attr = edge_attr[mask]

    return torch.stack([row, col], dim=0), edge_attr

参考官方文档

https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/pool/topk_pool.html

### 关于 PyTorch Geometric 中 TransformerConv 的用法与实现细节 `TransformerConv` 是 PyTorch Geometric 提供的一种基于图神经网络 (GNN) 的卷积操作模块,它实现了类似于 Transformer 架构中的自注意力机制来处理图数据。以下是其主要功能和实现细节: #### 1. **基本概念** `TransformerConv` 使用节点特征以及可选的边属性来进行消息传递,并通过自注意力机制动态调整不同邻居节点的重要性权重[^5]。 #### 2. **参数说明** `TransformerConv` 的核心参数如下: - `in_channels`: 输入节点特征的维度。 - `out_channels`: 输出节点特征的维度。 - `heads`: 多头注意力的数量,默认为 1。 - `concat`: 如果设置为 True,则多个头部的结果会被拼接;如果 False,则取平均值。 - `beta`: 是否启用额外的学习参数 β 来控制来自中心节点自身的贡献。 - `edge_dim`: 边特征的维度(如果有)。如果没有提供此参数,则不考虑边特征的影响。 这些参数使得模型能够灵活适应不同的输入场景并增强表达能力[^6]。 #### 3. **前向传播过程** 在前向传播过程中,`TransformerConv` 主要执行以下几个步骤: - 对每个节点及其邻居应用线性变换以生成查询 (`Q`) 和键值对 (`K`, `V`)。 - 利用自定义的距离度量或者边上的附加信息计算注意力分数。 - 将加权后的邻域表示聚合到当前节点上作为最终更新结果。 具体公式可以概括为: \[ \text{Attention}(h_i, h_j) = \frac{\exp(\text{LeakyReLU}(\vec{a}^\top [\mathbf{W}_q h_i || \mathbf{W}_k h_j]))}{\sum_{j' \in N(i)} \exp(\text{LeakyReLU}(\vec{a}^\top [\mathbf{W}_q h_i || \mathbf{W}_k h_{j'}]))}, \] 其中 \(||\) 表示连接操作,\(N(i)\) 表示节点 i 的邻居集合[^7]。 #### 4. **代码示例** 下面是一个简单的例子展示如何使用 `TransformerConv` 层构建 GNN 模型: ```python import torch from torch_geometric.nn import TransformerConv from torch_geometric.data import Data # 定义超参 num_nodes = 5 node_features = torch.randn(num_nodes, 8) # 假设有 5 个节点,每个有 8 维特征 edges = torch.tensor([[0, 1], [1, 2], [2, 3], [3, 4]], dtype=torch.long).t().contiguous() # 边索引 data = Data(x=node_features, edge_index=edges) # 初始化 TransformerConv 层 conv_layer = TransformerConv(in_channels=8, out_channels=16, heads=2, concat=True, beta=False) # 执行前向传播 output = conv_layer(data.x, data.edge_index) print(output.shape) # 应该打印出 (5, 32),因为有两个 head 被 concatenation ``` 上述代码片段展示了如何初始化一个具有双头注意机制的 `TransformerConv` 并将其应用于给定的数据对象之上[^8]。 --- ###
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

zoe_ya

如果你成功申请,可以打赏杯奶茶

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值