SumPooling
import dgl
import torch
from dgl.nn.pytorch.glob import SumPooling
# 创建一个示例图
g = dgl.graph(([0, 1, 2], [1, 2, 0]))
# 假设每个节点有一个特征向量
node_features = torch.randn(3, 10) # 3个节点,每个节点有10维特征
g.ndata['h'] = node_features
# 使用SumPooling进行全局池化
sum_pooling = SumPooling()
global_feature = sum_pooling(g, g.ndata['h'])
print(g.ndata['h'])
print("Global poo

最低0.47元/天 解锁文章
794

被折叠的 条评论
为什么被折叠?



