[Learning DGL framework functions] with graph.local_scope()

This article explains what the local_scope function in the DGL library does and how it prevents feature values from being modified unintentionally when writing custom convolutions and model computations. Examples show how to manipulate graph features temporarily inside a function while keeping the original graph unchanged, and why in-place operations are still reflected on the original graph.


Function introduction

  When writing custom convolution operators with the DGL framework, or when reading the source code of DGL's built-in functions, you will often come across code of the following form:

with graph.local_scope():
    if isinstance(feat, tuple):
        feat, _ = feat  # dst feature not used
    cj = graph.srcdata['cj']
    ci = graph.dstdata['ci']
    if self.device is not None:
        cj = cj.to(self.device)
        ci = ci.to(self.device)
What does the with graph.local_scope() statement actually do? Here is a brief introduction. graph is a graph defined with DGL, and the official documentation describes local_scope as follows:
Enter a local scope context for the graph.

By entering a local scope, any out-place mutation to the feature data will not reflect to the original graph, thus making it easier to use in a function scope (e.g. forward computation of a model).

If set, the local scope will use same initializers for node features and edge features.
Notes
Inplace operations do reflect to the original graph. This function also has little overhead when the number of feature tensors in this graph is small.
Roughly, it means the following:
  Inside a local_scope() block, any modification to node or edge features will no longer affect the original feature values of the graph once you leave that local scope. Put plainly: within this scope you can read and modify the graph's features, but only temporarily; once you exit the scope, everything is restored. The purpose is to make computation convenient: when we do message passing and aggregation on a graph, we often just want to compute a value without changing the original graph. A minimal sketch of this usage pattern follows.
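To make the pattern concrete, here is a minimal sketch of a custom layer whose forward pass stores a temporary feature inside local_scope and returns the aggregated result; the layer name MeanAggLayer and the feature keys 'h' / 'h_agg' are illustrative choices of my own, not part of DGL:

import dgl
import dgl.function as fn
import torch
import torch.nn as nn

class MeanAggLayer(nn.Module):
    """Illustrative layer: aggregates neighbor features without polluting the graph."""
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, graph, feat):
        with graph.local_scope():
            # 'h' and 'h_agg' only exist inside the local scope; they disappear on exit.
            graph.ndata['h'] = feat
            graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_agg'))
            return self.linear(graph.ndata['h_agg'])

g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])))
layer = MeanAggLayer(4, 2)
out = layer(g, torch.randn(g.num_nodes(), 4))
print('h' in g.ndata)  # False: the temporary feature did not leak out of the scope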

Example

Let's look at an example:

import dgl
import torch

def foo(g):
    with g.local_scope():
        g.edata['h'] = torch.ones((g.num_edges(), 3))
        g.edata['h2'] = torch.ones((g.num_edges(), 3))
        return g.edata['h']


# local_scope avoids changing the graph features when exiting the function.
g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([0, 0, 2])))
g.edata['h'] = torch.zeros((g.num_edges(), 3))
newh = foo(g)
print(g.edata['h'])  # still a tensor of all zeros
# tensor([[0., 0., 0.],
#         [0., 0., 0.],
#         [0., 0., 0.]])

print('h2' in g.edata)  # the new feature set inside the function scope is not found
# False

# Why the example below still changes the original graph's features is explained after the output.
# In-place operations will still reflect to the original graph.
def foo(g):
    with g.local_scope():
        # in-place operation
        g.edata['h'] += 1
        return g.edata['h']


g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([0, 0, 2])))
g.edata['h'] = torch.zeros((g.num_edges(), 1))
newh = foo(g)
print(g.edata['h'])  # the result changes
# tensor([[1.],
#         [1.],
#         [1.]])
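Why do in-place operations still reflect to the original graph? local_scope only gives the graph a temporary feature dictionary; the entries of that temporary dictionary initially point to the very same tensors as the original graph. An out-of-place assignment such as g.edata['h'] = g.edata['h'] + 1 creates a new tensor and rebinds the key only in the temporary dictionary, so the original graph is untouched. An in-place operation such as g.edata['h'] += 1 mutates the shared tensor itself, and that mutation remains visible from the original graph after the scope exits. The sketch below (illustrative, not from the DGL documentation) makes the shared storage visible by comparing data pointers and contrasting the two kinds of assignment:

import dgl
import torch

g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([0, 0, 2])))
g.edata['h'] = torch.zeros((g.num_edges(), 1))
original = g.edata['h']

with g.local_scope():
    # The feature seen inside the scope shares storage with the tensor outside it.
    print(g.edata['h'].data_ptr() == original.data_ptr())  # expected: True

    # Out-of-place: builds a new tensor and rebinds 'h' only in the local frame.
    g.edata['h'] = g.edata['h'] + 1

print(g.edata['h'])  # still all zeros: the out-of-place assignment did not leak

with g.local_scope():
    # In-place: mutates the shared tensor directly.
    g.edata['h'] += 1

print(g.edata['h'])  # all ones: the in-place change is visible on the original graph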

References

[1] dgl.DGLGraph.local_scope (DGL documentation)
[2] local_scope in DGL (Chinese blog post)
