【知识图谱系列】GCNII模型探索DeepGNN的Over-Smoothing问题

GCNII解决Over-Smoothing问题

作者:CHEONG

公众号:AI机器学习与知识图谱

研究方向:自然语言处理与知识图谱

GCNII (ICML 2020) 分享,GCNII全称:Graph Convolutional Networks via Initial residual and Identity Mapping

GCNII汇报ppt版可通过关注公众号【AI机器学习与知识图谱】,回复关键词:GCNII 来获得,供学习者使用!可添加微信号【17865190919】进学习交流群,加好友时备注来自优快云。原创不易,转载请告知并注明出处!


一、Motivation

在计算机视觉中,模型CNN随着其层次加深可以学习到更深层次的特征信息,叠加64层或128层是十分正常的现象,且能较浅层取得更优的效果。

图卷积神经网络GCNs是一种针对图结构数据的深度学习方法,但目前大多数的GCN模型都是浅层的,如GCN,GAT模型都是在2层时取得最优效果,随着加深模型效果就会大幅度下降,经研究GCN随着模型层次加深会出现Over-Smoothing问题,Over-Smoothing既相邻的节点随着网络变深就会越来越相似,最后学习到的nodeembedding便无法区分。

图片

上图中,随着模型层次加深,在Cora数据上Test Accuracy逐渐向下降,Quantitative Metric for Smoothness给Over-smoothness提出一个定量的指标SVMGSVM_GSVMG,如下公式所示:

图片
图片
图片

SVMGSVM_GSVMG衡量了图中任意两个节点之间的欧氏距离之和,SVMGSVM_GSVMG越小表示图学习时Over-Smoothing越严重当,当SVMG=0SVM_G=0SVMG=0时,图中所有节点完全相同,也可以从图中看出随着层次的加深,SVMGSVM_GSVMG的值越来越小。



二、Method

GCNII为了解决GCN在深层时出现的Over-Smoothing问题,提出了Initial ResidualIdentit Mapping两个简单技巧,成功解决了GCN深层时的Over-Smoothing问题。

1、Initial residual

残差一直是解决Over-Smoothing的最常用的技巧之一,传统GCN加residualconnection用公式表示为:

图片

GCNII Initial Residual不是从前一层获取信息,而是从初始层进行残差连接,并且设置了获取的权重。这里初始层initial representation不是原始输入feature,而是由输入feature经过线性变换后得到,如下公式所示:

图片
图片

但Initial Residual不是GCNII首次提出,而是ICLR 2019模型APPNP中提出。

图片

2、Identity Mapping

仅仅使用残差只能缓解Over-Smoothing问题,因此GCNII借鉴了ResNet的思想有了Identity Mapping,Initial Residual的想法是在当前层representation和初始层representation之间进行权重选择,而Identity Mapping是在参数W和单位矩阵I之间设置权重选择,如下公式所示:

图片

从上面公式看出,前半部分是Initialresidual,后半部分是IdentityMapping,其中α和β是超参,GCNII论文中也给出了为什么IdentityMapping可以起到缓解DeepGNN出现Over-Smoothing问题,总结来说:IdentityMapping可以起到加快模型的收敛速度,减少有效信息的损失。



三、Conclusion

1、实验数据

图片

实验中Cora, Citeseer, Pubmed三个引文数据,是同质图数据,常用于Transductive Learning类任务,三种数据都由以下八个文件组成,存储格式类似:

2、实验结果

实验结果在Cora, citeseer, pubmed三个数据上都进行DeepGNN测试,测试结果可以看出随着网络层级的加深,模型不仅没有像传统GNN出现Over-Smoothing而效果下降,反而模型效果随着深度增加而不断提升,解决了传统DeepGNN存在的Over-Smoothing问题。

图片

GCNII汇报ppt版可通过关注公众号【AI机器学习与知识图谱】,回复关键词:GCNII 来获得,供学习者使用!有用的话就点个赞呗!

### 概述 在图神经网络(GNN)驱动的社交推荐系统中,过平滑(over-smoothing问题是一个关键挑战,表现为随着网络层数的增加,节点表示逐渐趋同,失去区分性。为缓解这一问题,可以探索将序列模型(sequential models)集成到GNN框架中。序列模型擅长捕捉时间依赖性和动态变化,能够提供额外的上下文信息,从而增强节点表示的表达能力,减少过平滑效应。 ### 序列模型与GNN的结合方式 1. **时序信息建模** 在社交推荐系统中,用户的行为和互动具有明显的时序特性。通过引入序列模型(如RNN、LSTM或Transformer),可以捕捉用户的动态行为模式。这些模型能够对用户的交互序列进行建模,提取出时间维度上的特征,并将其与GNN生成的结构特征相结合。例如,可以将用户的点击、浏览、购买等行为序列输入到LSTM中,生成用户的时序嵌入,再与GNN输出的用户嵌入进行融合[^1]。 2. **多模态特征融合** 序列模型不仅可以处理时间序列数据,还可以与其他类型的输入结合,形成多模态特征表示。在社交推荐系统中,用户可能同时具有文本、图像等多种形式的交互数据。通过将这些数据分别输入到不同的序列模型中,提取各自的特征,再与GNN生成的图结构特征进行融合,能够增强节点表示的丰富性和准确性[^2]。 3. **层次化建模** 可以采用层次化的方式,先使用GNN对社交网络结构进行建模,得到初步的节点表示;然后,将这些节点表示作为输入,送入序列模型中进行进一步的优化。这种层次化的建模方式能够有效缓解过平滑问题,因为序列模型可以在更高层次上对节点表示进行调整和细化,保留更多的局部特征[^3]。 4. **注意力机制的引入** 序列模型中的注意力机制(如Transformer中的自注意力机制)可以用于动态调整不同邻居节点对目标节点的影响权重。在传统的GNN中,所有邻居节点对目标节点的影响是等同的,这可能导致信息混淆。而通过引入注意力机制,可以根据节点之间的相似性或相关性,自动调整权重分配,从而提高节点表示的质量。例如,可以设计一个基于Transformer的GNN模型,在聚合邻居信息时引入注意力机制,使得每个节点能够更精准地选择对其最有帮助的邻居节点[^4]。 5. **递归更新策略** 序列模型的递归特性可以用于设计递归更新策略,逐步优化节点表示。具体来说,可以在每一轮迭代中,使用序列模型对当前的节点表示进行更新,逐步逼近更优的表示。这种递归更新策略可以避免一次性更新带来的信息丢失,同时也有助于缓解过平滑问题。例如,可以设计一个基于RNN的更新模块,每次迭代时根据前一次的节点表示生成新的表示,逐步提升模型的性能。 6. **对抗训练与正则化** 为了进一步缓解过平滑问题,可以在模型中引入对抗训练或正则化策略。对抗训练可以通过生成对抗样本,增强模型对噪声和异常值的鲁棒性;而正则化策略则可以通过约束节点表示的分布,防止其过于趋同。例如,可以在损失函数中加入一个正则化项,鼓励相邻节点的表示保持一定的差异性,从而避免过平滑现象的发生。 ### 实验与评估 在实际应用中,可以通过实验验证不同序列模型与GNN结合的效果。常见的评估指标包括准确率、召回率、F1分数等。此外,还可以通过可视化节点表示的方式,直观地观察不同模型在缓解过平滑问题上的表现。实验结果表明,引入序列模型可以显著提升GNN在社交推荐系统中的性能,尤其是在处理大规模、高密度的社交网络时,效果更为明显。 ### 示例代码 以下是一个简单的示例代码,展示了如何将LSTM与GNN结合,用于社交推荐系统的节点表示学习: ```python import torch import torch.nn as nn import torch.nn.functional as F from torch_geometric.nn import GCNConv class GNNWithLSTM(nn.Module): def __init__(self, num_features, hidden_dim, num_classes): super(GNNWithLSTM, self).__init__() self.gcn = GCNConv(num_features, hidden_dim) self.lstm = nn.LSTM(hidden_dim, hidden_dim, batch_first=True) self.classifier = nn.Linear(hidden_dim, num_classes) def forward(self, data): x, edge_index = data.x, data.edge_index # GNN部分 x_gcn = F.relu(self.gcn(x, edge_index)) # 将GNN输出作为LSTM输入 x_lstm, _ = self.lstm(x_gcn.unsqueeze(1)) # 分类 out = self.classifier(x_lstm.squeeze(1)) return out ``` 在该示例中,`GCNConv`用于提取图结构特征,`LSTM`则用于捕捉时序信息。最终的输出通过一个分类器进行预测。 ###
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值