【Torch API】pytorch 中index_add()函数详解

import torch

# create a tensor of zeros with shape (3, 4)
t = torch.zeros((3, 4))

# create an index tensor of shape (2,)
index = torch.tensor([0, 2])

# create a tensor to add with shape (2, 4)
src = torch.ones((2, 4))

# add the src tensor to t at the indices specified in the index tensor
t.index_add(0, index, src)

print(t)

输出结果为:

tensor([[1., 1., 1., 1.],
        [0., 0., 0., 0.],
        [1., 1., 1., 1.]])

解释:
上述代码,使用了 index_add() 方法将 src 张量添加到 t 张量中。具体而言,下面这行代码:

t.index_add(0, index, src)

完成了相加的操作。下面是每个参数的含义:

0:指定执行索引操作的维度。在本例中,我们要将 src 张量添加到由 index 张量指定的 t 张量的行上,因此设置 dim=0。

index:指定要将 src 张量添加到哪些位置的索引。在本例中,我们要将 src 张量添加到 t 的第 0 行和第 2 行,因此设置 index=torch.tensor([0, 2])。

src:要添加到 t 的张量。在本例中,src 是一个全部为 1 的张量,形状为 (2, 4)。

index_add() 方法的工作原理是将 src 的每一行加到由 index 张量指定的 t 张量的相应行上。在本例中,指定了 index=torch.tensor([0, 2]),因此 src 的第一行(全为 1)会被加到 t 的第一行,第二行(同样全为 1)会被加到 t 的第三行。结果是一个张量,其中第一行和第三行为 1,其它地方为 0。

import scipy.io import torch from torch_geometric.data import Data from torch_geometric.nn import GNNExplainer from torch_geometric.utils import to_networkx import matplotlib.pyplot as plt import networkx as nx from edge_index_to_adj_node_edge import edge_index_to_adj_node_edge # 加载 .mat 文件 mat_file = "D:/GNN/wy0303-62OLD/data/data0318/data0318.mat" # 替换为你的 .mat 文件路径 mat_data = scipy.io.loadmat(mat_file) # 提取数据 edge_index = mat_data['edge_index'] # 边的索引 n_s = mat_data['n_s'] # 源节点特征 n_d = mat_data['n_d'] # 目标节点特征 m_s = mat_data['m_s'] # 可能是边特征 m_d = mat_data['m_d'] # 可能是边特征 E = mat_data['E'] # 可能是一些额外的图特征 sc = mat_data['sc'] # 可能是额外的图特征 adj_matrix = edge_index_to_adj_node_edge(edge_index) # 如果存在邻接矩阵 # 转换为 PyTorch 张量 edge_index_tensor = torch.tensor(edge_index, dtype=torch.long) node_features_tensor = torch.cat((torch.tensor(n_s, dtype=torch.float), torch.tensor(n_d, dtype=torch.float)), dim=0) # 处理节点特征(聚合特征维度,得到一个二维张量) node_features_processed = node_features_tensor.mean(dim=(1, 2, 3)) # 你可以根据需要修改此方法 # 创建 PyTorch Geometric 数据对象 data = Data(x=node_features_processed, edge_index=edge_index_tensor) # 加载预训练的模型 model = torch.load('D:/GNN/wy0303-62OLD/best_model.pt') model.eval() # 设置为评估模式 # 选择一个节点进行解释 node_idx = 0 # 选择图中的一个节点进行解释 # 使用 GNNExplainer 进行解释 explainer = GNNExplainer(model, epochs=200) print(f"edge_index type: {type(data.edge_index)}") print(f"edge_index shape: {data.edge_index.shape}") # 在调用 explain_node 时传递所有额外的输入(包括 edge_index) # 这里我们明确传递 edge_index 给模型和 explainer # n_s, n_d, m_s, m_d, E, sc,edge_index,adj_matrix explanation = explainer.explain_node( node_idx, data.x, data.edge_index, m_d=m_d, m_s=m_s, E=E, sc=sc, adj_matrix=adj_matrix, ) # 可视化结果 G = to_networkx(data) # 转换为 NetworkX 图对象 pos = nx.spring_layout(G) # 获取图的布局 # 可视化图 plt.figure(figsize=(8, 8)) nx.draw(G, pos, with_labels=True, node_color='lightblue', node_size=500) nx.draw_networkx_edges(G, pos, edgelist=explanation.edge_mask, edge_color='r'
最新发布
04-03
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值