PyG搭建R-GCN实现节点分类

前言

R-GCN的原理请见:ESWC 2018 | R-GCN:基于图卷积网络的关系数据建模

数据处理

导入数据:

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'DBLP')
dataset = DBLP(path)
graph = dataset[0]
print(graph)

输出如下:

HeteroData(
  author={
    x=[4057, 334],
    y=[4057],
    train_mask=[4057],
    val_mask=[4057],
    test_mask=[4057]
  },
  paper={ x=[14328, 4231] },
  term={ x=[7723, 50] },
  conference={ num_nodes=20 },
  (author, to, paper)={ edge_index=[2, 19645] },
  (paper, to, author)={ edge_index=[2, 19645] },
  (paper, to, term)={ edge_index=[2, 85810] },
  (paper, to, conference)={ edge_index=[2, 14328] },
  (term, to, paper)={ edge_index=[2, 85810] },
  (conference, to, paper)={ edge_index=[2, 14328] }
)

可以发现,DBLP数据集中有作者(author)、论文(paper)、术语(term)以及会议(conference)四种类型的节点。DBLP中包含14328篇论文(paper), 4057位作者(author), 20个会议(conference), 7723个术语(term)。作者分为四个领域:数据库、数据挖掘、机器学习、信息检索。

任务:对author节点进行分类,一共4类。

由于conference节点没有特征,因此需要预先设置特征:

graph['conference'].x = torch.randn((graph['conference'].num_nodes, 50))

所有conference节点的特征都随机初始化。

获取一些有用的数据:

num_classes = torch.max(graph['author'].y).item() + 1
train_mask, val_mask, test_mask = graph['author'].train_mask, graph['author'].val_mask, graph['author'].test_mask
y = graph['author'].y

node_types, edge_types = graph.metadata()
num_nodes = graph['author'].x.shape[0]
num_relations = len(edge_types)
init_sizes = [graph[x].x.shape[1] for x in node_types]
homogeneous_graph = graph.to_homogeneous()
in_feats, hidden_feats = 128, 64

模型搭建

首先导入包:

from torch_geometric.nn import RGCNConv

模型参数:
在这里插入图片描述

  1. in_channels:输入通道,比如节点分类中表示每个节点的特征数,一般设置为-1。
  2. out_channels:输出通道,最后一层GCNConv的输出通道为节点类别数(节点分类)。
  3. num_relations:关系数。
  4. num_bases:如果使用基函数分解正则化,则其表示要使用的基数。
  5. num_blocks:如果使用块对角分解正则化,则其表示要使用的块数。
  6. aggr:聚合方式,默认为mean

于是模型搭建如下:

class RGCN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(RGCN, self).__init__()
        self.conv1 = RGCNConv(in_channels, hidden_channels,
                              num_relations=num_relations, num_bases=30)
        self.conv2 = RGCNConv(hidden_channels, out_channels,
                              num_relations=num_relations, num_bases=30)
        self.lins = torch.nn.ModuleList()
        for i in range(len(node_types)):
            lin = nn.Linear(init_sizes[i], in_channels)
            self.lins.append(lin)

    def trans_dimensions(self, g):
        data = copy.deepcopy(g)
        for node_type, lin in zip(node_types, self.lins):
            data[node_type].x = lin(data[node_type].x)

        return data

    def forward(self, data):
        data = self.trans_dimensions(data)
        homogeneous_data = data.to_homogeneous()
        edge_index, edge_type = homogeneous_data.edge_index, homogeneous_data.edge_type
        x = self.conv1(homogeneous_data.x, edge_index, edge_type)
        x = self.conv2(x, edge_index, edge_type)
        x = x[:num_nodes]
        return x

输出一下模型:

model = RGCN(in_feats, hidden_feats, num_classes).to(device)
RGCN(
  (conv1): RGCNConv(128, 64, num_relations=6)
  (conv2): RGCNConv(64, 4, num_relations=6)
  (lins): ModuleList(
    (0): Linear(in_features=334, out_features=128, bias=True)
    (1): Linear(in_features=4231, out_features=128, bias=True)
    (2): Linear(in_features=50, out_features=128, bias=True)
    (3): Linear(in_features=50, out_features=128, bias=True)
  )
)

1. 前向传播

查看官方文档中RGCNConv的输入输出要求:
在这里插入图片描述
可以发现,RGCNConv中需要输入的是节点特征x、边索引edge_index以及边类型edge_type

我们输出初始化特征后的DBLP数据集:

HeteroData(
  author={
    x=[4057, 334],
    y=[4057],
    train_mask=[4057],
    val_mask=[4057],
    test_mask=[4057]
  },
  paper={ x=[14328, 4231] },
  term={ x=[7723, 50] },
  conference={
    num_nodes=20,
    x=[20, 50]
  },
  (author, to, paper)={ edge_index=[2, 19645] },
  (paper, to, author)={ edge_index=[2, 19645] },
  (paper, to, term)={ edge_index=[2, 85810] },
  (paper, to, conference)={ edge_index=[2, 14328] },
  (term, to, paper)={ edge_index=[2, 85810] },
  (conference, to, paper)={ edge_index=[2, 14328] }
)

可以发现,DBLP中并没有上述要求的三个值。因此,我们首先需要将其转为同质图:

homogeneous_graph = graph.to_homogeneous()
Data(node_type=[26128], edge_index=[2, 239566], edge_type=[239566])

转为同质图后虽然有了edge_indexedge_type,但没有所有节点的特征x,这是因为在将异质图转为同质图的过程中,只有所有节点的特征维度相同时才能将所有节点的特征进行合并。因此,我们首先需要将所有节点的特征转换到同一维度(这里以128为例):

def trans_dimensions(self, g):
    data = copy.deepcopy(g)
    for node_type, lin in zip(node_types, self.lins):
        data[node_type].x = lin(data[node_type].x)

    return data

转换后的data中所有类型节点的特征维度都为128,然后再将其转为同质图:

data = self.trans_dimensions(data)
homogeneous_data = data.to_homogeneous()
Data(node_type=[26128], x=[26128, 128], edge_index=[2, 239566], edge_type=[239566])

此时,我们就可以将homogeneous_data输入到RGCNConv中:

x = self.conv1(homogeneous_data.x, edge_index, edge_type)
x = self.conv2(x, edge_index, edge_type)

输出的x包含所有节点的信息,我们只需要取前4057个,也就是author节点的特征:

x = x[:num_nodes]

2. 反向传播

在训练时,我们首先利用前向传播计算出输出:

f = model(graph)

f即为最终得到的每个节点的4个概率值,但在实际训练中,我们只需要计算出训练集的损失,所以损失函数这样写:

loss = loss_function(f[train_mask], y[train_mask])

然后计算梯度,反向更新!

3. 训练

训练时返回验证集上表现最优的模型:

def train():
    model = RGCN(in_feats, hidden_feats, num_classes).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01,
                                 weight_decay=1e-4)
    loss_function = torch.nn.CrossEntropyLoss().to(device)
    min_epochs = 5
    best_val_acc = 0
    final_best_acc = 0
    model.train()
    for epoch in range(100):
        f = model(graph)
        loss = loss_function(f[train_mask], y[train_mask])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # validation
        val_acc, val_loss = test(model, val_mask)
        test_acc, test_loss = test(model, test_mask)
        if epoch + 1 > min_epochs and val_acc > best_val_acc:
            best_val_acc = val_acc
            final_best_acc = test_acc
        print('Epoch{:3d} train_loss {:.5f} val_acc {:.3f} test_acc {:.3f}'.
              format(epoch, loss.item(), val_acc, test_acc))

    return final_best_acc

4. 测试

@torch.no_grad()
def test(model, mask):
    model.eval()
    out = model(graph)
    loss_function = torch.nn.CrossEntropyLoss().to(device)
    loss = loss_function(out[mask], y[mask])
    _, pred = out.max(dim=1)
    correct = int(pred[mask].eq(y[mask]).sum().item())
    acc = correct / int(test_mask.sum())

    return acc, loss.item()

实验结果

数据集采用DBLP网络,训练100轮,分类正确率为93.77%:

RGCN Accuracy: 0.9376727049431992

完整代码

代码地址:GNNs-for-Node-Classification。原创不易,下载时请给个follow和star!感谢!!

### 关系图卷积网络 (R-GCN) 的原理 关系图卷积网络(R-GCN)扩展了传统的图卷积网络(GCN),以处理具有不同类型边的关系图。在传统 GCN 中,节点聚合来自邻居的信息而不区分这些连接的具体性质;而在 R-GCN 中,则考虑到了不同类型的边所代表的不同关系。 具体来说,在每层传播过程中,R-GCN 对于每一个特定类型 \(r\) 的关系都定义了一个单独的权重矩阵 \(\mathbf{W}_r^{(l)}\)[^1]。这意味着当更新某个节点表示时,会根据不同种类的关系分别计算加权求和后的特征向量,并最终通过某种方式组合起来形成新的节点嵌入: \[h_i^{(l+1)}=\sigma\left(\sum_{r \in \mathcal{R}} \sum_{j \in N_r(i)} \frac{1}{c_{i, r}} \mathbf{W}_{r}^{(l)} h_j^{(l)}+\mathbf{b}\right)\] 其中,\(N_r(i)\) 表示与节点 i 有第 r 类型关系相连的所有邻接点集合;而常数项 \(\frac{1}{c_{i,r}}\) 则用来防止梯度爆炸或者消失的问题[^2]。 ### 实现细节 为了有效地训练大规模稀疏图上的 R-GCN 模型,通常采用如下策略来简化上述公式: - **参数共享**:对于某些任务可以假设相同方向上的所有关系拥有相似的重要性,因此可以在不同的关系间共享同一个权重矩阵。 - **基滤波器方法**:利用一组基础变换矩阵去近似任意数量的关系转换操作,减少所需学习参数的数量。 此外,还可能涉及到其他技术如批标准化(batch normalization),dropout 正则化等以提高泛化能力[^3]。 ```python import dgl.nn.pytorch as dglnn import torch.nn.functional as F class RGNNLayer(nn.Module): def __init__(self, in_feat, out_feat, num_rels): super(RGNNLayer, self).__init__() self.conv = dglnn.RelGraphConv(in_feat, out_feat, num_rels) def forward(self, g, feat, etypes): return F.relu(self.conv(g, feat, etypes)) ``` 这段代码展示了如何使用 DGL 库构建一个简单的单层 R-GCN 层。`dglnn.RelGraphConv` 是专门为实现 R-GCN 设计的一个模块,它接受输入特征 `feat`, 边类型 `etypes` 和图结构 `g` 来执行消息传递过程。 ### 应用场景 R-GCN 已经被成功应用于多个领域内的实际问题解决当中,尤其是在涉及复杂关联模式的数据集上表现优异。例如,在知识库补全任务中,R-GCN 可帮助识别并填补已知事实间的空白部分——即所谓的“链接预测”。另一个典型应用场景是在实体分类方面,这里的目标是从给定的一系列属性描述中推断出未知类别标签。
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Cyril_KI

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值