【dgl框架函数学习】with graph.local_scope()

本文详细介绍了DGL库中local_scope函数的作用,如何在自定义卷积和模型计算中防止特征值意外修改。通过实例展示了如何在函数中临时操作图特征并保持原图不变,以及为何在位运算下会反映到原图。

函数介绍

  在使用dgl框架自定义卷积算子或者看dgl实现函数的源代码时,总会碰到如下的代码形式:

 with graph.local_scope():
      if isinstance(feat, tuple):
           feat, _ = feat  # dst feature not used
       cj = graph.srcdata['cj']
       ci = graph.dstdata['ci']
       if self.device is not None:
           cj = cj.to(self.device)
           ci = ci.to(self.device)
with graph.local_scope()

函数到底起到什么作用,这里做一下介绍,graph使用dgl定义的一个图,官方的介绍如下:
Enter a local scope context for the graph.

By entering a local scope, any out-place mutation to the feature data will not reflect to the original graph, thus making it easier to use in a function scope (e.g. forward computation of a model).

If set, the local scope will use same initializers for node features and edge features.

import dgl.data import os import torch import torch.nn.functional as F import dgl.nn.pytorch as dglnn import torch.nn as nn from mmpose.apis.inferencers import MMPoseInferencer class Classifier(nn.Module): def __init__(self, in_dim, hidden_dim, n_classes): super(Classifier, self).__init__() self.conv1 = dglnn.GraphConv(in_dim, hidden_dim, allow_zero_in_degree=True) self.conv2 = dglnn.GraphConv(hidden_dim, hidden_dim, allow_zero_in_degree=True) self.classify = nn.Linear(hidden_dim, n_classes) def forward(self, g, h): # 应用图卷积和激活函数 h = F.relu(self.conv1(g, h)) h = F.relu(self.conv2(g, h)) with g.local_scope(): g.ndata['h'] = h # 使用平均读出计算图表示 hg = dgl.mean_nodes(g, 'h') return self.classify(hg) file_path = os.path.dirname(os.path.abspath(__file__)) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 关键点检测模型加载 init_args = {'pose2d': os.path.join(file_path, 'key_point_model/configs/hand_2d_keypoint/rtmpose/hand5/rtmpose-m_8xb256-210e_hand5-256x256.py'), 'pose2d_weights': os.path.join(file_path, 'key_point_model/rtmpose-m_simcc-hand5_pt-aic-coco_210e-256x256-74fb594_20230320.pth'), 'scope': 'mmpose', 'device': None, 'det_model': 'whole-image', 'det_weights': None, 'det_cat_ids': 0, 'pose3d': None, 'pose3d_weights': None, 'show_progress': False} inferencer = MMPoseInferencer(**init_args) # 分类模型加载 model_path = os.path.join(file_path, 'cls_model/GCN.pt') model = torch.load(model_path) model = model.to(device) model.eval() data_path = os.path.join(file_path, 'dataset/val/') # 读取手势图片 files = [f for f in os.listdir(data_path)] for filename in files: img_path = os.path.join(data_path, filename) call_args = {'inputs': img_path, 'show': False, 'draw_bbox': False, 'draw_heatmap': False, 'bbox_thr': 0.3, 'nms_thr': 0.3, 'pose_based_nms': False, 'kpt_thr': 0
最新发布
03-20
<think>好的,我需要帮助用户实现DGL GCN与MMPose手势关键点检测模型的联合推理流程。首先,我得明确用户的需求。他们可能正在构建一个手势分类系统,利用图神经网络(GCN)处理关键点数据,并结合MMPose进行关键点检测。接下来,我需要考虑如何将这两个部分结合起来。 首先,用户可能需要先通过MMPoseInferencer检测手势的关键点。这涉及到加载预训练模型和处理输入图像,得到关键点的坐标。这部分我应该建议用户如何初始化MMPose的推理器,处理图像,并提取关键点数据。 然后,关键点数据需要转换成图结构,供GCN处理。这里可能需要将每个关键点作为图的节点,关键点之间的连接作为边。比如,人体手部关键点通常有预定义的连接关系,用户可能需要根据这些关系构建图的边。DGL中可以使用dgl.graph来创建图,并指定节点和边。 接下来是GCN模型的构建。用户需要定义一个基于DGL的GCN模型,可能使用GraphConv层。需要说明如何定义消息传递、聚合和更新函数,或者直接使用DGL提供的图卷积层。还要考虑输入特征的维度,比如关键点的坐标(x, y)可能作为节点特征,或者加入其他信息如置信度。 然后是如何整合两个模型的推理流程。可能需要将MMPose的输出转换为DGL图的格式,然后输入GCN模型进行分类。这里需要注意数据格式的转换,比如将关键点坐标从NumPy数组转为PyTorch张量,并确保维度匹配。 另外,可能涉及到的性能优化,比如批处理或GPU加速。需要提醒用户数据处理和模型加载时的设备设置,如使用CUDA。此外,预处理和后处理的步骤也很重要,比如归一化关键点坐标或对GCN的输出进行softmax处理得到分类结果。 还需要考虑异常情况,比如MMPose没有检测到关键点时如何处理,或者图结构构建不正确导致的错误。建议用户添加必要的检查,比如判断关键点是否存在,或者边索引是否正确。 最后,引用相关的参考资料,比如DGL的官方文档和教程,以及MMPose的使用示例,确保用户能够进一步查阅详细信息。可能还需要提到如何结合半监督学习,如果用户的数据标注有限的话,可以参考Kipf等人的方法。 总结步骤应该是:1.使用MMPose检测关键点;2.构建图结构;3.定义GCN模型;4.联合推理流程;5.预处理和后处理;6.性能优化和异常处理。同时,提供代码示例帮助用户具体实现。</think>以下是基于DGL框架实现GCN与MMPose手势关键点检测模型联合推理的分步解决方案: ### 一、系统架构设计 1. **MMPose关键点检测**:使用MMPoseInferencer提取手势关键点坐标和置信度 2. **图结构构建**:将关键点转换为图结构,节点特征包含坐标$(x,y)$和置信度$c$,边连接遵循人体手部拓扑 3. **GCN模型定义**:基于DGLGraphConv层构建图卷积网络 4. **联合推理流程**:建立从关键点检测到图分类的端到端pipeline ### 二、关键实现步骤 #### 1. MMPose关键点检测 ```python from mmpose.apis import MMPoseInferencer # 初始化检测器 inferencer = MMPoseInferencer( pose2d='hand', # 手势关键点检测模型 device='cuda:0' # GPU加速 ) # 执行推理 result = next(inferencer('input.jpg', show=False)) # 提取关键点数据 (shape: [21,3]) keypoints = result['predictions'][0][0]['keypoints'] # 21个手部关键点 ``` #### 2. 构建图结构 ```python import dgl import torch # 定义节点特征 (shape: [21,3]) node_features = torch.FloatTensor(keypoints) # [x, y, confidence] # 定义边连接关系(示例手部拓扑) edges = [ (0,1),(1,2),(2,3),(3,4), # 拇指 (0,5),(5,6),(6,7),(7,8), # 食指 (0,9),(9,10),(10,11),(11,12), # 中指 (0,13),(13,14),(14,15),(15,16),# 无名指 (0,17),(17,18),(18,19),(19,20) # 小指 ] # 创建DGL图 g = dgl.graph(([u for u,_ in edges], [v for _,v in edges])) g.ndata['feat'] = node_features ``` #### 3. GCN模型定义 ```python from dgl.nn import GraphConv class GCNClassifier(nn.Module): def __init__(self, in_dim, hidden_dim, num_classes): super().__init__() self.conv1 = GraphConv(in_dim, hidden_dim) self.conv2 = GraphConv(hidden_dim, hidden_dim) self.classifier = nn.Linear(hidden_dim, num_classes) def forward(self, g): h = F.relu(self.conv1(g, g.ndata['feat'])) h = F.relu(self.conv2(g, h)) with g.local_scope(): g.ndata['h'] = h hg = dgl.mean_nodes(g, 'h') # 图级特征 return self.classifier(hg) ``` #### 4. 联合推理流程 ```python def joint_inference(img_path): # 步骤1:关键点检测 pose_result = next(inferencer(img_path, show=False)) # 步骤2:构建图数据 g = construct_graph(pose_result) # 包含上述图构建逻辑 # 步骤3:GCN推理 model = load_trained_gcn() # 加载预训练模型 logits = model(g) # 后处理 return F.softmax(logits, dim=1) ``` ### 三、关键优化点 1. **特征增强**:在节点特征中加入相对坐标偏移量$\Delta x, \Delta y$,计算方式为: $$(\Delta x_i, \Delta y_i) = (x_i - x_{root}, y_i - y_{root})$$ 2. **动态图构建**:根据置信度自动过滤低质量检测结果,调整边连接策略 3. **多尺度特征融合**:在GCN中增加跳跃连接: $$h^{(l+1)} = \sigma(\hat{D}^{-1/2}\hat{A}\hat{D}^{-1/2}h^{(l)}W^{(l)}) + h^{(l)}$$ ### 四、常见问题处理 1. **关键点缺失**:当MMPose检测失败时,可采用滑动窗口平均法进行插值补偿 2. **图结构异常**:通过DGL的异常检测机制验证边索引有效性: ```python assert torch.all(g.edges()[0] < g.num_nodes()), "Invalid edge source" ``` 3. **设备兼容性**:统一数据设备类型 ```python g = g.to('cuda:0') if torch.cuda.is_available() else g ``` [^1]: DGL官方文档提供了图神经网络的实现范式 [^2]: 图卷积层的自定义实现需要正确处理消息传递机制 [^3]: GCN的数学形式化描述为$H^{(l+1)} = \sigma(\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}H^{(l)}W^{(l)})$ [^4]: 半监督节点分类的实现方法可参考Cora数据集案例
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值