【dgl框架函数学习】with graph.local_scope()

本文详细介绍了DGL库中local_scope函数的作用,如何在自定义卷积和模型计算中防止特征值意外修改。通过实例展示了如何在函数中临时操作图特征并保持原图不变,以及为何在位运算下会反映到原图。

函数介绍

  在使用dgl框架自定义卷积算子或者看dgl实现函数的源代码时,总会碰到如下的代码形式:

 with graph.local_scope():
      if isinstance(feat, tuple):
           feat, _ = feat  # dst feature not used
       cj = graph.srcdata['cj']
       ci = graph.dstdata['ci']
       if self.device is not None:
           cj = cj.to(self.device)
           ci = ci.to(self.device)
with graph.local_scope()

函数到底起到什么作用,这里做一下介绍,graph使用dgl定义的一个图,官方的介绍如下:
Enter a local scope context for the graph.

By entering a local scope, any out-place mutation to the feature data will not reflect to the original graph, thus making it easier to use in a function scope (e.g. forward computation of a model).

If set, the local scope will use same initializers for node features and edge features.

import dgl.data import os import torch import torch.nn.functional as F import dgl.nn.pytorch as dglnn import torch.nn as nn from mmpose.apis.inferencers import MMPoseInferencer class Classifier(nn.Module): def __init__(self, in_dim, hidden_dim, n_classes): super(Classifier, self).__init__() self.conv1 = dglnn.GraphConv(in_dim, hidden_dim, allow_zero_in_degree=True) self.conv2 = dglnn.GraphConv(hidden_dim, hidden_dim, allow_zero_in_degree=True) self.classify = nn.Linear(hidden_dim, n_classes) def forward(self, g, h): # 应用图卷积和激活函数 h = F.relu(self.conv1(g, h)) h = F.relu(self.conv2(g, h)) with g.local_scope(): g.ndata['h'] = h # 使用平均读出计算图表示 hg = dgl.mean_nodes(g, 'h') return self.classify(hg) file_path = os.path.dirname(os.path.abspath(__file__)) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 关键点检测模型加载 init_args = {'pose2d': os.path.join(file_path, 'key_point_model/configs/hand_2d_keypoint/rtmpose/hand5/rtmpose-m_8xb256-210e_hand5-256x256.py'), 'pose2d_weights': os.path.join(file_path, 'key_point_model/rtmpose-m_simcc-hand5_pt-aic-coco_210e-256x256-74fb594_20230320.pth'), 'scope': 'mmpose', 'device': None, 'det_model': 'whole-image', 'det_weights': None, 'det_cat_ids': 0, 'pose3d': None, 'pose3d_weights': None, 'show_progress': False} inferencer = MMPoseInferencer(**init_args) # 分类模型加载 model_path = os.path.join(file_path, 'cls_model/GCN.pt') model = torch.load(model_path) model = model.to(device) model.eval() data_path = os.path.join(file_path, 'dataset/val/') # 读取手势图片 files = [f for f in os.listdir(data_path)] for filename in files: img_path = os.path.join(data_path, filename) call_args = {'inputs': img_path, 'show': False, 'draw_bbox': False, 'draw_heatmap': False, 'bbox_thr': 0.3, 'nms_thr': 0.3, 'pose_based_nms': False, 'kpt_thr': 0
最新发布
03-20
评论 4
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值