PyTorch Geometric异构图处理:支持多节点类型和边类型
1. 异构图学习的挑战与解决方案
在现实世界的图结构数据中,节点和边往往具有多种类型。例如学术网络包含论文、作者、会议等节点类型,以及写作、引用、发表等边类型。这种异构图(Heterogeneous Graph) 相比传统同构图(Homogeneous Graph)能更准确地建模复杂关系,但也带来了三大核心挑战:
- 类型异构性:不同节点/边类型的特征空间和语义含义差异显著
- 结构复杂性:多类型节点间存在复杂的连接模式,难以用单一卷积核捕捉
- 计算效率:大规模异构图的采样和训练需要专门优化的数据加载策略
PyTorch Geometric(PyG)通过HeteroData数据结构、类型感知的卷积层和高效采样器三大组件,为异构图学习提供了完整解决方案。本文将系统介绍这些核心功能,并通过实战案例展示如何构建生产级异构图模型。
2. 核心数据结构:HeteroData
2.1 多类型数据容器设计
PyG的HeteroData类采用类型分层存储架构,将不同类型的节点和边数据分离管理,同时保持统一的访问接口。其内部结构如图所示:
关键特性:
- 节点类型以字符串标识(如 "paper")
- 边类型以三元组 (src_type, rel_type, dst_type) 标识(如 ("author", "writes", "paper"))
- 支持类似字典的属性访问(data['author'].x)和批量收集(data.collect('x'))
2.2 数据初始化与操作
2.2.1 创建HeteroData对象
from torch_geometric.data import HeteroData
import torch
data = HeteroData()

# Node features: one matrix per node type, shaped [num_nodes, feature_dim].
data['paper'].x = torch.randn(1000, 128)  # 1000 papers, 128-dim features
data['author'].x = torch.randn(500, 64)  # 500 authors, 64-dim features
data['conference'].x = torch.randn(50, 32)  # 50 conferences, 32-dim features

# Edge indices are shaped [2, num_edges]: row 0 holds source node ids and
# row 1 holds destination node ids. Each row must therefore be sampled from
# the id range of its own node type, otherwise the graph is invalid.
data['author', 'writes', 'paper'].edge_index = torch.stack([
    torch.randint(0, 500, (2000,)),   # source: author ids
    torch.randint(0, 1000, (2000,)),  # destination: paper ids
])
data['paper', 'cites', 'paper'].edge_index = torch.randint(0, 1000, (2, 5000))
data['paper', 'published_in', 'conference'].edge_index = torch.stack([
    torch.randint(0, 1000, (1000,)),  # source: paper ids
    torch.randint(0, 50, (1000,)),    # destination: conference ids
])
print(data)
# HeteroData(
# paper={ x=[1000, 128] },
# author={ x=[500, 64] },
# conference={ x=[50, 32] },
# (author, writes, paper)={ edge_index=[2, 2000] },
# (paper, cites, paper)={ edge_index=[2, 5000] },
# (paper, published_in, conference)={ edge_index=[2, 1000] }
# )
2.2.2 核心方法与属性
| 方法/属性 | 功能描述 |
|---|---|
| node_types | 返回所有节点类型列表 |
| edge_types | 返回所有边类型列表 |
| metadata() | 返回(node_types, edge_types)元组 |
| collect(key) | 收集所有类型中名为key的属性(如data.collect('x')) |
| subgraph(subset_dict) | 根据节点子集提取子图 |
| to(device) | 将所有张量移至指定设备 |
2.2.3 数据验证与调试
HeteroData提供内置验证机制,可自动检查图结构的一致性:
# Verify that node types and edge types reference each other consistently.
data.validate(raise_on_error=True)
# If an edge type references an undefined node type, a ValueError is raised.
3. 异构图神经网络构建
3.1 类型感知的图卷积层
PyG提供多种异构图卷积层,满足不同场景需求:
3.1.1 HeteroConv:灵活组合同构卷积
HeteroConv允许为每种边类型指定不同的卷积操作,自动处理类型路由:
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv, GATConv
# Choose one convolution per edge type. Relations whose source and destination
# node types differ are bipartite, so they need an operator that accepts a
# (source, destination) feature pair: SAGEConv, or GATConv with self-loops
# disabled. GCNConv is kept only for the paper -> paper relation.
conv = HeteroConv({
    ('author', 'writes', 'paper'): SAGEConv((-1, -1), 64),
    ('paper', 'cites', 'paper'): GCNConv(-1, 64),
    ('paper', 'published_in', 'conference'): GATConv((-1, -1), 64, add_self_loops=False),
}, aggr='sum')  # how results from different edge types with the same destination are combined

# Forward pass: HeteroData exposes the per-type dictionaries directly.
x_dict = data.x_dict
edge_index_dict = data.edge_index_dict
out_dict = conv(x_dict, edge_index_dict)

# Only node types that are the destination of at least one edge type receive
# an output. 'author' is never a destination here, so it is absent.
print(out_dict.keys())  # dict_keys(['paper', 'conference'])
3.1.2 HGTConv:异构图Transformer
异构图Transformer(HGT) 是处理异构图的先进模型,通过以下机制捕捉类型信息:
- 无需预定义元路径:通过类型相关的参数自动学习隐式元路径(邻居采样由数据加载器如HGTLoader负责,而非卷积层本身)
- 类型依赖的注意力:计算注意力时考虑源/目标节点类型
- 关系特定的消息函数:为每种边类型学习独立的消息传递函数
from torch_geometric.nn import HGTConv
# The keyword for the number of attention heads is `heads` (not `num_heads`).
# Older releases also took a `group` argument; it is omitted here -- check the
# HGTConv signature of the installed version before adding it back.
hgt_conv = HGTConv(
    in_channels=-1,            # -1: infer the input size lazily on the first call
    out_channels=128,          # output feature size; must be divisible by `heads`
    metadata=data.metadata(),  # graph metadata: (node_types, edge_types)
    heads=8,                   # number of attention heads
)
# Forward pass
hgt_out_dict = hgt_conv(x_dict, edge_index_dict)
3.2 从同构图模型到异构图模型
3.2.1 to_hetero:自动转换同构模型
对于已有的同构图模型,PyG提供to_hetero工具自动转换为异构图模型:
from torch_geometric.nn import GCNConv, to_hetero
from torch_geometric.nn import SAGEConv


class GNN(torch.nn.Module):
    """Two-layer GNN written for a single node/edge type.

    The model is meant to be handed to ``to_hetero``, which replicates every
    message-passing layer once per edge type. Because most of those relations
    connect two different node types, the layers must support bipartite
    inputs; ``SAGEConv((-1, -1), ...)`` does, ``GCNConv`` does not.

    Args:
        hidden_channels: Output size of the first layer.
        out_channels: Output size of the second layer.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # (-1, -1): source and destination input sizes are inferred lazily.
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        """Apply both layers with a ReLU in between and return the result."""
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)


# Build the homogeneous model.
homogeneous_model = GNN(hidden_channels=64, out_channels=4)

# Convert it into a heterogeneous model.
# NOTE(review): node types that are never an edge destination ('author' in the
# toy graph above) are not updated by message passing; adding reverse edges
# with T.ToUndirected() beforehand is the usual remedy -- confirm for your data.
heterogeneous_model = to_hetero(
    homogeneous_model,
    metadata=data.metadata(),
    aggr='sum',  # how messages arriving from different edge types are combined
)

# Forward pass of the heterogeneous model.
out_dict = heterogeneous_model(x_dict, edge_index_dict)
3.2.2 自定义异构图模型
以DBLP学术网络的作者分类任务为例,构建完整HGT模型:
from torch_geometric.datasets import DBLP
import torch_geometric.transforms as T
# Load the DBLP dataset (4 node types and 6 edge types, per the original note).
# T.Constant(node_types='conference') is passed as the dataset transform.
dataset = DBLP(root='/tmp/DBLP', transform=T.Constant(node_types='conference'))
# Move the graph to the GPU when one is available, otherwise keep it on the CPU.
data = dataset[0].to('cuda' if torch.cuda.is_available() else 'cpu')
# torch.nn.Linear rejects in_features=-1; PyG's Linear supports lazy sizing.
from torch_geometric.nn import Linear


class HGT(torch.nn.Module):
    """Heterogeneous Graph Transformer that classifies 'author' nodes.

    The node and edge types are read from the module-level ``data`` object.

    Args:
        hidden_channels: Width of the per-type input projection and of every
            HGTConv layer.
        out_channels: Number of output classes.
        num_heads: Number of attention heads in each HGTConv layer.
        num_layers: Number of stacked HGTConv layers.
    """

    def __init__(self, hidden_channels, out_channels, num_heads, num_layers):
        super().__init__()
        # One input projection per node type; -1 defers the input size to the
        # first forward pass because each type has its own feature width.
        self.lin_dict = torch.nn.ModuleDict({
            node_type: Linear(-1, hidden_channels)
            for node_type in data.node_types
        })
        # Stack of HGT convolution layers.
        self.convs = torch.nn.ModuleList([
            HGTConv(hidden_channels, hidden_channels, data.metadata(), num_heads)
            for _ in range(num_layers)
        ])
        # Classification head.
        self.lin = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        """Return class scores for the 'author' nodes.

        Args:
            x_dict: Mapping from node type to its feature tensor.
            edge_index_dict: Mapping from edge type to its edge index tensor.
        """
        # Project every node type into the shared hidden space.
        x_dict = {
            node_type: self.lin_dict[node_type](x).relu_()
            for node_type, x in x_dict.items()
        }
        # Multi-layer HGT message passing.
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
        # Only the author embeddings are classified.
        return self.lin(x_dict['author'])
# Instantiate the model on the same device as the graph tensors.
device = next(iter(data.x_dict.values())).device
model = HGT(
    hidden_channels=64,
    out_channels=dataset.num_classes,
    num_heads=2,
    num_layers=2,
).to(device)

# Layers declared with an input size of -1 only learn their real size on the
# first forward pass, so run one before creating an optimizer.
with torch.no_grad():
    model(data.x_dict, data.edge_index_dict)
4. 异构图数据加载与采样
4.1 HGTLoader:面向大型异构图的节点采样
HGTLoader实现了异构图上的高效节点采样,通过以下策略优化性能:
- 类型感知采样:为每种节点类型维护独立的采样预算
- 元路径引导:根据节点类型间的连接模式调整采样概率
- 批次平衡:确保每个批次包含均衡的节点类型分布
from torch_geometric.loader import HGTLoader
# Per node type: how many nodes of that type to sample in each iteration
# (these are node budgets per type, not neighbours per node). Every type uses
# the same number of iterations so that each iteration has an entry to read.
num_samples = {
    'paper': [15, 10, 5],  # 3 iterations: 15, then 10, then 5 paper nodes
    'author': [10, 10, 5],
    'conference': [5, 5, 5],
    'term': [5, 5, 5],
}

# Create the HGT data loader.
# NOTE(review): `data` was moved to CUDA when it was loaded; loading with
# worker processes usually expects CPU tensors -- confirm, or keep the graph
# on the CPU and move each batch to the GPU instead.
loader = HGTLoader(
    data,
    num_samples=num_samples,
    batch_size=256,  # number of seed nodes per mini-batch
    input_nodes=('author', data['author'].train_mask),  # training nodes
    shuffle=True,
    num_workers=4,
)
import torch.nn.functional as F

# Iterate over the mini-batches.
for batch in loader:
    print(batch)
    # HeteroData(
    #   paper={ x=[..., ...] },
    #   author={ x=[..., ...], y=[...], train_mask=[...], batch_size=256 },
    #   ...
    # )
    # The seed (input) authors come first among the sampled 'author' nodes,
    # and their count is stored in `batch_size`. The loss is restricted to
    # them so that authors pulled in only as context do not contribute.
    batch_size = batch['author'].batch_size
    out = model(batch.x_dict, batch.edge_index_dict)
    loss = F.cross_entropy(out[:batch_size], batch['author'].y[:batch_size])
4.2 链接预测的数据加载
针对异构图链接预测任务,可使用LinkNeighborLoader进行边采样:
from torch_geometric.loader import LinkNeighborLoader
# Edge-level sampling configuration for link prediction.
# NOTE(review): assumes `data` holds ('user', 'rates', 'movie') edges that carry
# edge_label_index and edge_label attributes -- confirm against the graph in use.
loader = LinkNeighborLoader(
data,
num_neighbors=[5, 5], # neighbor sampling over 2 layers
edge_label_index=(('user', 'rates', 'movie'), data['user', 'rates', 'movie'].edge_label_index),
edge_label=data['user', 'rates', 'movie'].edge_label,
neg_sampling=dict(mode='binary', amount=1), # negative sampling configuration
batch_size=1024,
shuffle=True
)
5. 实战案例:电影推荐系统中的异构图学习
5.1 数据建模
将MovieLens数据集构建为异构图,包含用户(user)和电影(movie)两种节点类型,以及评分(rates)边类型:
from torch_geometric.datasets import MovieLens
import torch_geometric.transforms as T
# Load the MovieLens rating graph. PyG's `MovieLens` class is built from the
# small "ml-latest-small" release; MovieLens-1M is the separate `MovieLens1M`
# dataset class.
dataset = MovieLens(
    root='/tmp/MovieLens',
    model_name='all-MiniLM-L6-v2',  # sentence model that embeds movie titles into features
    transform=T.ToUndirected(),     # adds the reverse ('movie', 'rev_rates', 'user') edges
)
data = dataset[0]

# ToUndirected also copies the rating labels onto the reverse relation; drop
# them so the supervision signal lives on ('user', 'rates', 'movie') only.
del data['movie', 'rev_rates', 'user'].edge_label

# User nodes ship without features, so there is no data['user'].x to take a
# device from yet: build one-hot (identity) features on the movie features' device.
data['user'].x = torch.eye(data['user'].num_nodes, device=data['movie'].x.device)
5.2 模型架构
构建基于异构图SAGE的推荐模型,包含编码器和解码器:
from torch_geometric.nn import SAGEConv, to_hetero
class GNNEncoder(torch.nn.Module):
    """Two-layer GraphSAGE encoder producing node embeddings.

    Both layers use lazily inferred input sizes ``(-1, -1)`` so the encoder
    can later be converted with ``to_hetero`` for graphs whose node types have
    different feature widths.

    Args:
        hidden_channels: Output size of both SAGEConv layers.
    """

    def __init__(self, hidden_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), hidden_channels)

    def forward(self, x, edge_index):
        """Apply both layers with a ReLU in between and return the embeddings."""
        hidden = self.conv1(x, edge_index).relu()
        return self.conv2(hidden, edge_index)
class EdgeDecoder(torch.nn.Module):
    """Two-layer MLP that scores (user, movie) pairs from their embeddings.

    Args:
        hidden_channels: Size of each node embedding; the first linear layer
            takes the concatenation of a user and a movie embedding.
    """

    def __init__(self, hidden_channels):
        super().__init__()
        self.lin1 = torch.nn.Linear(2 * hidden_channels, hidden_channels)
        self.lin2 = torch.nn.Linear(hidden_channels, 1)

    def forward(self, z_dict, edge_label_index):
        """Return one score per column of ``edge_label_index``.

        Args:
            z_dict: Mapping with 'user' and 'movie' embedding tensors.
            edge_label_index: Two-row index; row 0 selects users and row 1
                selects movies.

        Returns:
            A 1-D tensor with one predicted value per (user, movie) pair.
        """
        user_idx, movie_idx = edge_label_index
        # Concatenate the user and movie embeddings of every pair.
        pair = torch.cat([z_dict['user'][user_idx], z_dict['movie'][movie_idx]], dim=-1)
        hidden = self.lin1(pair).relu()
        return self.lin2(hidden).view(-1)
class RecommendationModel(torch.nn.Module):
    """Encoder/decoder rating model for the user-movie graph.

    The GraphSAGE encoder is converted to a heterogeneous model using the
    metadata of the module-level ``data`` object; the decoder scores
    (user, movie) pairs from the resulting embeddings.

    Args:
        hidden_channels: Embedding size shared by encoder and decoder.
    """

    def __init__(self, hidden_channels):
        super().__init__()
        self.encoder = GNNEncoder(hidden_channels)
        # Replicate the encoder's layers per edge type.
        self.encoder = to_hetero(self.encoder, data.metadata(), aggr='sum')
        self.decoder = EdgeDecoder(hidden_channels)

    def forward(self, x_dict, edge_index_dict, edge_label_index):
        """Encode all nodes, then score the pairs in ``edge_label_index``."""
        z_dict = self.encoder(x_dict, edge_index_dict)
        return self.decoder(z_dict, edge_label_index)
5.3 训练与评估
# Split the rating edges into train / validation / test graphs.
train_data, val_data, test_data = T.RandomLinkSplit(
num_val=0.1,
num_test=0.1,
neg_sampling_ratio=0.0,
edge_types=[('user', 'rates', 'movie')],
rev_edge_types=[('movie', 'rev_rates', 'user')],
)(data)
# Build the model on the same device as the node features.
model = RecommendationModel(hidden_channels=32).to(data.x_dict['user'].device)

# The encoder's layers infer their input sizes lazily, so run it once before
# creating the optimizer; only then do all parameters have their final shape.
with torch.no_grad():
    model.encoder(train_data.x_dict, train_data.edge_index_dict)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
import torch.nn.functional as F


# Training step
def train():
    """Run one full-batch optimisation step on the training graph.

    Uses the module-level ``model``, ``optimizer`` and ``train_data``.

    Returns:
        The mean-squared-error training loss as a Python float.
    """
    model.train()
    optimizer.zero_grad()
    rating_edges = train_data['user', 'rates', 'movie']
    pred = model(
        train_data.x_dict,
        train_data.edge_index_dict,
        rating_edges.edge_label_index,
    )
    # Ratings are cast to float so they match the dtype of the predictions.
    loss = F.mse_loss(pred, rating_edges.edge_label.float())
    loss.backward()
    optimizer.step()
    return float(loss)
import torch.nn.functional as F


# Evaluation step
@torch.no_grad()
def test(data):
    """Return the RMSE of the module-level ``model`` on one data split.

    Args:
        data: A split whose ('user', 'rates', 'movie') edges carry
            ``edge_label_index`` and ``edge_label``.

    Returns:
        A scalar tensor holding the root-mean-squared error.
    """
    model.eval()
    rating_edges = data['user', 'rates', 'movie']
    pred = model(
        data.x_dict,
        data.edge_index_dict,
        rating_edges.edge_label_index,
    )
    return F.mse_loss(pred, rating_edges.edge_label.float()).sqrt()
# Training loop: one optimisation step per epoch, then report the RMSE on the
# training and validation splits.
for epoch in range(1, 201):
    loss = train()
    train_rmse = test(train_data)
    val_rmse = test(val_data)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train RMSE: {train_rmse:.4f}, Val RMSE: {val_rmse:.4f}')
6. 高级应用与最佳实践
6.1 异构图可视化
使用PyG的可视化工具展示异构图结构:
# NOTE(review): `draw_hetero_graph` does not appear in the documented
# torch_geometric.visualization API -- confirm the helper exists in the
# installed version before relying on this snippet.
from torch_geometric.visualization import draw_hetero_graph
# Draw the heterogeneous graph (simplified version)
draw_hetero_graph(
data.edge_index_dict,
data.metadata(),
node_color={
'author': 'skyblue',
'paper': 'lightgreen',
'conference': 'orange'
},
node_size=1000,
font_size=8,
edge_color='gray',
alpha=0.5
)
6.2
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考



