GCN torch_geometric utils scatter_方法源码解析及样例

本文详细解析了torch_scatter库中的scatter_方法,该方法用于根据指定索引在特定维度上聚合张量值。通过源码分析,介绍了参数name(聚合方式)、src(源张量)、index(索引)、dim(维度)和dim_size(运算后返回的维度)的用法。并提供了多个示例,展示了不同参数设置下如何进行节点信息聚合,包括加法、均值和最大值最小值计算。该方法常用于图神经网络的消息传递层。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

源码位置

torch_geometric\utils\scatter.py

源码及注释

代码相当于在原有torch_scatter包的基础上进行二次开发,1.4.3之后版本移除了,相当于只要理解torch.scatter(…, reduce=‘add’) 就好了。
下面是开发者说的原话:
“The scatter_ call got removed in one of the more recent versions, and you can now simply use torch.scatter(…, reduce=‘add’) for the same effect.”

import torch_scatter


def scatter_(name, src, index, dim=0, dim_size=None):
    r"""Aggregates all values from the :attr:`src` tensor at the indices
    specified in the :attr:`index` tensor along the first dimension.
    If multiple indices reference the same location, their contributions
    are aggregated according to :attr:`name` (either :obj:`"add"`,
    :obj:`"mean"` or :obj:`"max"`).

    Args:
        name (string): The aggregation to use (:obj:`"add"`, :obj:`"mean"`,
            :obj:`"min"`, :obj:`"max"`).
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements to scatter.
        dim (int, optional): The axis along which to index. (default: :obj:`0`)
        dim_size (int, optional): Automatically create output tensor with size
            :attr:`dim_size` in the first dimension. If set to :attr:`None`, a
            minimal sized output tensor is returned. (default: :obj:`None`)

    :rtype: :class:`Tensor`
    """
    # 例行断言
    assert name in ['add', 'mean', 'min', 'max']

    # 返回对象属性值
    op = getattr(torch_scatter, 'scatter_{}'.format(name))
    # 获取对象属性后返回值可直接使用,看具体案例
    out = op(src, index, dim, None, dim_size)
    out = out[0] if isinstance(out, tuple) else out

    # 限制取最值后数据约束
    if name == 'max':
        out[out < -10000] = 0
    elif name == 'min':
        out[out > 10000] = 0

    return out

torch_scatter方法

结合官网和以下网址大家可以看得懂(看不懂没关系,多看几遍,我也是看了好几遍才懂的,不得不佩服其设计),这里不做过多解释。
官网 Docs > torch.Tensor > torch.Tensor.scatter_

官网Docs » Scatter

PyTorch笔记之scatter()函数的使用

pytorch中的 scatter_()函数使用和详解

scatter_方法解析

常规运行

在理解上述原有开发包的基础上,我们再来看一下此方法的作用。
在上一篇博客里面用一句话介绍过,按传入方式聚合节点信息。
其是用于创建消息传递层的基类。
为了便于理解,这里给出一个具体例子:
假设我们有一个图,图中有四个节点,节点特征维度为3。
其关系如下图所示(为方便理解我画成了单向图):
在这里插入图片描述
以下代码可以在scatter_方法调试中使用:

#src 节点特征
a = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0], [4.0, 4.0, 4.0]])
#index 索引
b = torch.tensor([2, 3, 0, 1])
#dim 维度信息
c = 1
#dim_size 返回的维度
e = 4
# op = getattr(torch_scatter, 'scatter_{}'.format(name))
# out = op(src, index, dim, None, dim_size)
qq = getattr(torch_scatter, 'scatter_{}'.format("add"))
qq(a, b, c, None, e)

tensor([[2., 2., 2.],
        [4., 4., 4.],
        [1., 1., 1.],
        [3., 3., 3.]])

可以看到他进行了一次消息传递,以节点1为例,他接收了来自节点3的消息,所以特征变为[2., 2., 2.]。

节点特征维度

此样例说明节点信息聚合时,各个维度相互独立。

a = torch.tensor([[1.0, 0.0, 1.0], [2.0, 0.0, 2.0], [3.0, 0.0, 3.0], [4.0, 0.0, 4.0]])
b = torch.tensor([2, 3, 0, 1])
c = 1
e = 4
qq = getattr(torch_scatter, 'scatter_{}'.format("add"))
qq(a, b, c, None, e)

tensor([[2.0, 0.0, 2.0],
        [4.0, 0.0, 4.0],
        [1.0, 0.0, 1.0],
        [3.0, 0.0, 3.0]])

dim_size(运算后返回的维度)参数

只修改dim_size ,可以看到返回维度增加。

a = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0], [4.0, 4.0, 4.0]])
b = torch.tensor([2, 3, 0, 1])
c = 1
e = 8
qq = getattr(torch_scatter, 'scatter_{}'.format("add"))
qq(a, b, c, None, e)

tensor([[2., 2., 2.],
        [4., 4., 4.],
        [1., 1., 1.],
        [3., 3., 3.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])

index(索引)参数

这里相当于修改了节点的指向,修改后图:
在这里插入图片描述

a = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0], [4.0, 4.0, 4.0]])
b = torch.tensor([2, 0, 0, 1])
c = 1
e = 4
qq = getattr(torch_scatter, 'scatter_{}'.format("add"))
qq(a, b, c, None, e)

tensor([[5., 5., 5.],
        [4., 4., 4.],
        [1., 1., 1.],
        [0., 0., 0.]])

因为这里使用的“add”方法,所以同时传递到一个节点的信息被累加。

name(聚合方式)

很显然,当我把聚合方式改为均值,除了节点1有两个值传入被平均,其他保持不变;同理最值一样道理。

a = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0], [4.0, 4.0, 4.0]])
b = torch.tensor([2, 0, 0, 1])
c = 1
e = 4
qq = getattr(torch_scatter, 'scatter_{}'.format("mean"))
qq(a, b, c, None, e)

tensor([[2.5000, 2.5000, 2.5000],
        [4.0000, 4.0000, 4.0000],
        [1.0000, 1.0000, 1.0000],
        [0.0000, 0.0000, 0.0000]])

PS:torch_scatter(函数库)转torch.scatter_add(内置原生函数)

要转换的代码:

deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)

scatter_add 是 PyTorch 函数库 torch_scatter 中的一个函数,而 torch.scatter_add 是 PyTorch 内置的原生函数。两者的调用方式略有不同。以下是将 scatter_add 转换为 torch.scatter_add 的重写方法。
假设代码中的参数如下:
edge_weight:包含要累加的数据的张量(源张量)。
row:指定目标索引位置的张量(即 index 参数)。
dim=0:表示累加操作发生在第 0 维。
dim_size=num_nodes:目标张量在第 0 维的大小。

转换代码:

要将 scatter_add 转换为 torch.scatter_add,我们首先需要创建一个目标张量 output,然后在该张量上执行 torch.scatter_add 操作。

import torch

# 假设已知 edge_weight、row 和 num_nodes
edge_weight = torch.tensor([0.5, 1.0, 1.5, 2.0])  # 举例的源张量
row = torch.tensor([0, 1, 0, 2])  # 索引张量
num_nodes = 3  # 输出张量的第 0 维大小

# 创建目标张量 output,初始化为零
output = torch.zeros(num_nodes, dtype=edge_weight.dtype)
# 使用 torch.scatter_add
deg = torch.scatter_add(output,dim=0,index=row,src=edge_weight))

output 是大小为 (num_nodes,) 的张量,初始化为零,以便累加 edge_weight 中的值。
scatter_add(0, row, edge_weight) 表示在 output 张量的第 0 维按 row 中的索引位置累加
edge_weight 中的值。

感谢及参考博文

部分内容参考以下链接,这里表示感谢 Thanks♪(・ω・)ノ
官方说明 Docs > torch.Tensor > torch.Tensor.scatter
https://pytorch.org/docs/master/generated/torch.Tensor.scatter_.html#torch.Tensor.scatter_
官网说明 Docs » Scatter
https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html
参考博文1 PyTorch笔记之scatter()函数的使用
https://www.zhangshengrong.com/p/rG1V7g5ka3/
参考博文2 pytorch中的 scatter_()函数使用和详解
https://blog.youkuaiyun.com/t20134297/article/details/105755817
参考GPT4o

### 关于图神经网络GCN) #### 概念 图神经网络(Graph Convolutional Network, GCN)是一种专门用于处理图结构数据的深度学习方法。传统的卷积神经网络(CNN)适用于网格状的数据结构,如图像;而GCN则能够有效处理节点之间存在复杂关系的非欧氏空间数据,即图结构数据[^3]。 #### 原理 GCN的核心在于通过聚合邻居节点的信息来更新当前节点表示。具体来说,在每一层中,每个节点会收集其相邻节点的状态,并基于这些状态计算新的特征向量。这一过程可以通过定义在图上的卷积操作完成,其中涉及到的关键组件之一就是图拉普拉斯矩阵,它是描述图拓扑特性的数学工具[^2]。 对于半监督分类任务而言,GCN利用已标记本指导未标记本的学习过程。该算法假设连接紧密的节点更可能属于同一类别,从而使得标签能够在整个图上传播开来。这种机制允许即使只有少量标注过的实也能训练出性能良好的模型[^1]。 #### 实现 以下是使用PyTorch框架实现简单版本GCN的一个子: ```python import torch from torch.nn import Parameter import torch.nn.functional as F from torch_geometric.nn.conv import MessagePassing from torch_geometric.utils import add_remaining_self_loops class GCNConv(MessagePassing): def __init__(self, in_channels, out_channels): super(GCNConv, self).__init__(aggr='add') # "Add" aggregation. self.lin = torch.nn.Linear(in_channels, out_channels) def forward(self, x, edge_index, edge_weight=None): if edge_weight is None: edge_weight = torch.ones((edge_index.size(1), ), dtype=x.dtype, device=edge_index.device) row, col = edge_index deg = degree(col, x.size(0), dtype=x.dtype) deg_inv_sqrt = deg.pow(-0.5) norm = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=self.lin(x), norm=norm) def propagate(self, edge_index, size, x, norm): src_node_features = x msg = norm.view(-1, 1) * src_node_features[edge_index[0]] aggregated_messages = scatter_add(msg, edge_index[1], dim=0, dim_size=size[1]) return aggregated_messages ``` 这段代码展示了如何构建一个简单的两层GCN架构并应用于节点级别的预测任务。这里`MessagePassing`类来自`torch_geometric`库,提供了方便的消息传递接口以简化编码工作流。 #### 应用 GCNs广泛应用于多个领域,包括但不限于社交网络分析、推荐系统、生物信息学以及化学分子性质预测等场景。如,在药物研发过程中,科学家们可以借助GCN研究化合物之间的相互作用模式,进而加速新药筛选流程。同地,在电子商务平台上,平台运营者也可以运用此类技术优化个性化商品推荐服务。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值