以PyG官方的数据集和示例代码来复现一下这个问题:
import torch
import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.nn import SAGEConv, to_hetero
# Load OGB-MAG with the 'metapath2vec' preprocessing option.
dataset = OGB_MAG(
    root='files/pyg_data',
    preprocess='metapath2vec',
)
# The first dataset entry is the heterogeneous graph; print its summary.
data = dataset[0]
print(data)
class GNN(torch.nn.Module):
    """Two stacked ``SAGEConv`` layers with a ReLU in between.

    Args:
        hidden_channels: Output size of the first convolution.
        out_channels: Output size of the second convolution.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # (-1, -1) leaves the input sizes unspecified; the script runs one
        # forward pass under no_grad afterwards to initialize these lazy
        # modules.
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        """Apply ``conv1``, ReLU, then ``conv2`` and return the result."""
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x
# Build the homogeneous model, then convert it to a heterogeneous one using
# the graph's metadata, summing the results of the different edge types.
model = GNN(hidden_channels=64, out_channels=dataset.num_classes)
# NOTE: with the graph loaded above, this is the call that fails in the
# traceback shown in this post (NotImplementedError raised inside to_hetero).
model = to_hetero(model, data.metadata(), aggr='sum')

with torch.no_grad():  # Initialize lazy modules.
    out = model(data.x_dict, data.edge_index_dict)
print(out)
输出信息:
HeteroData(
paper={
x=[736389, 128],
year=[736389],
y=[736389],
train_mask=[736389],
val_mask=[736389],
test_mask=[736389]
},
author={ x=[1134649, 128] },
institution={ x=[8740, 128] },
field_of_study={ x=[59965, 128] },
(author, affiliated_with, institution)={ edge_index=[2, 1043998] },
(author, writes, paper)={ edge_index=[2, 7145660] },
(paper, cites, paper)={ edge_index=[2, 5416271] },
(paper, has_topic, field_of_study)={ edge_index=[2, 7505078] }
)
my_env/lib/python3.8/site-packages/torch_geometric/nn/to_hetero_transformer.py:145: UserWarning: There exist node types ({'author'}) whose representations do not get updated during message passing as they do not occur as destination type in any edge type. This may lead to unexpected behaviour.
warnings.warn(
Traceback (most recent call last):
File "try2.py", line 25, in <module>
model = to_hetero(model, data.metadata(), aggr='sum')
File "env_path/lib/python3.8/site-packages/torch_geometric/nn/to_hetero_transformer.py", line 118, in to_hetero
return transformer.transform()
File "env_path/lib/python3.8/site-packages/torch_geometric/nn/fx.py", line 157, in transform
getattr(self, op)(node, node.target, node.name)
File "env_path/lib/python3.8/site-packages/torch_geometric/nn/to_hetero_transformer.py", line 294, in call_method
args, kwargs = self.map_args_kwargs(node, key)
File "env_path/lib/python3.8/site-packages/torch_geometric/nn/to_hetero_transformer.py", line 397, in map_args_kwargs
args = tuple(_recurse(v) for v in node.args)
File "env_path/lib/python3.8/site-packages/torch_geometric/nn/to_hetero_transformer.py", line 397, in <genexpr>
args = tuple(_recurse(v) for v in node.args)
File "env_path/lib/python3.8/site-packages/torch_geometric/nn/to_hetero_transformer.py", line 387, in _recurse
raise NotImplementedError
NotImplementedError
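从上面的 UserWarning 可以很容易地看出来,问题出在 `author` 这种节点上:它在任何一种边类型中都不是目标节点(即没有入边),因此其表示在消息传递过程中不会被更新,`to_hetero` 转换时随即抛出了 `NotImplementedError`。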
解决方案就是使所有类型的节点都有入边。例如,在加载数据集时加上 `transform=T.ToUndirected()`:
import torch
import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.nn import SAGEConv, to_hetero
# Load OGB-MAG as before, additionally applying the ToUndirected transform.
dataset = OGB_MAG(
    root='files/pyg_data',
    preprocess='metapath2vec',
    transform=T.ToUndirected(),
)
# The first dataset entry is the heterogeneous graph; print its summary.
data = dataset[0]
print(data)
class GNN(torch.nn.Module):
    """Two stacked ``SAGEConv`` layers with a ReLU in between.

    Args:
        hidden_channels: Output size of the first convolution.
        out_channels: Output size of the second convolution.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # (-1, -1) leaves the input sizes unspecified; the script runs one
        # forward pass under no_grad afterwards to initialize these lazy
        # modules.
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        """Apply ``conv1``, ReLU, then ``conv2`` and return the result."""
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x
# Build the homogeneous model, then convert it to a heterogeneous one using
# the graph's metadata, summing the results of the different edge types.
model = GNN(hidden_channels=64, out_channels=dataset.num_classes)
model = to_hetero(model, data.metadata(), aggr='sum')

with torch.no_grad():  # Initialize lazy modules.
    out = model(data.x_dict, data.edge_index_dict)
# `out` is a dict keyed by node type (see the printed result in this post).
print(out)
这样就把异质图转换成了无向图:从下面的输出可以看到,图中多出了 `rev_affiliated_with`、`rev_writes`、`rev_has_topic` 这几种反向边,`author` 节点因此也有了入边,模型就能得到正常的输出结果:
HeteroData(
paper={
x=[736389, 128],
year=[736389],
y=[736389],
train_mask=[736389],
val_mask=[736389],
test_mask=[736389]
},
author={ x=[1134649, 128] },
institution={ x=[8740, 128] },
field_of_study={ x=[59965, 128] },
(author, affiliated_with, institution)={ edge_index=[2, 1043998] },
(author, writes, paper)={ edge_index=[2, 7145660] },
(paper, cites, paper)={ edge_index=[2, 10792672] },
(paper, has_topic, field_of_study)={ edge_index=[2, 7505078] },
(institution, rev_affiliated_with, author)={ edge_index=[2, 1043998] },
(paper, rev_writes, author)={ edge_index=[2, 7145660] },
(field_of_study, rev_has_topic, paper)={ edge_index=[2, 7505078] }
)
{'paper': tensor([[-0.8212, -0.2630, -0.7286, ..., 1.1904, 0.1617, -0.5388],
[-1.2484, -0.3707, -1.0336, ..., 0.9618, -0.0373, -0.1125],
[-0.5375, 0.0357, -0.6772, ..., 1.2185, 0.2292, -0.2130],
...,
[-0.9934, -0.2688, -0.9547, ..., 1.3144, 0.1519, -0.2015],
[-1.4711, -0.6607, -0.7509, ..., 2.3383, 0.6815, -1.0679],
[-0.4352, -0.4255, -0.6907, ..., 1.1532, 0.1152, -0.9703]]), 'author': tensor([[-0.2782, 0.1771, 0.4187, ..., -0.5233, -0.2969, 0.2438],
[-0.4543, 0.1019, 0.1637, ..., -0.7748, -0.2809, 0.2598],
[-0.1613, -0.0481, -0.2491, ..., -0.6227, -0.4217, 0.1335],
...,
[-0.4908, 0.2382, 0.2973, ..., -0.7266, -0.2486, 0.6449],
[-0.2819, 0.0125, 0.9843, ..., -1.9652, -0.4280, -0.4842],
[-0.4236, -0.1222, 1.0246, ..., -2.0615, -0.3246, -0.1771]]), 'institution': tensor([[ 0.3911, -1.3527, -0.6624, ..., 0.2732, 0.5270, 0.5756],
[ 0.1512, -0.6687, -0.6516, ..., 0.1482, 0.2535, 0.1935],
[ 0.1933, -1.1643, -0.4936, ..., 0.5382, 0.3407, 0.2199],
...,
[ 0.1489, -0.3021, -0.3390, ..., 0.2690, 0.1571, -0.0781],
[ 0.1855, -0.4848, -0.3205, ..., 0.4728, 0.0659, 0.1500],
[ 0.1724, -0.0682, -0.0894, ..., 0.1189, 0.1230, -0.2249]]), 'field_of_study': tensor([[ 0.1929, -0.5402, -0.5714, ..., -0.4296, 0.4376, -0.0660],
[-0.2281, 0.0773, -0.0486, ..., -0.0544, -0.2894, 0.2706],
[-0.2798, -0.1967, -0.3376, ..., -0.3098, -0.1610, 0.1120],
...,
[ 0.0775, -0.5927, -0.6084, ..., -0.3190, 0.2483, -0.1418],
[ 0.0286, -0.7393, -0.6629, ..., -0.4745, 0.8461, -0.1554],
[-0.0804, -0.5598, -0.8517, ..., -0.2317, 0.3234, -0.0520]])}
该博客介绍了在使用PyTorch Geometric库(PyG)处理异质图时遇到的一个常见问题,即某些节点没有入边导致模型无法正确运行。通过引入`ToUndirected`转换,将图转换为无向图,使得所有节点均有入边,从而解决了这一问题。博主展示了如何在OGB_MAG数据集上应用SAGEConv模型,并给出了转换前后的代码示例和输出结果。
2万+

被折叠的 条评论
为什么被折叠?



