import os
import tempfile
import requests
import cv2
import numpy as np
from mmpose.apis import init_pose_model, inference_top_down_pose_model
class MMPose:
def __init__(self, backbone='SCNet'):
# 由于SCNet可能不可用,我们使用HRNet作为替代
if backbone == 'SCNet':
print("注意:SCNet暂不可用,已自动替换为HRNet")
backbone = 'hrnet'
self.backbone = backbone
self.model = None
self.initialize_model()
def initialize_model(self):
"""初始化姿态估计模型"""
# 创建临时配置文件 - 使用与预训练权重完全匹配的配置
with tempfile.NamedTemporaryFile(suffix='.py', delete=False, mode='w') as temp_file:
temp_file.write("""
# HRNet 标准配置
model = dict(
type='TopDown',
backbone=dict(
type='HRNet',
in_channels=3,
extra=dict(
stage1=dict(
num_modules=1,
num_branches=1,
block='BOTTLENECK',
num_blocks=(4,),
num_channels=(64,)),
stage2=dict(
num_modules=1,
num_branches=2,
block='BASIC',
num_blocks=(4, 4),
num_channels=(48, 96)),
stage3=dict(
num_modules=4,
num_branches=3,
block='BASIC',
num_blocks=(4, 4, 4),
num_channels=(48, 96, 192)),
stage4=dict(
num_modules=3,
num_branches=4,
block='BASIC',
num_blocks=(4, 4, 4, 4),
num_channels=(48, 96, 192, 384))),
),
keypoint_head=dict(
type='TopDownSimpleHead',
in_channels=48,
out_channels=17,
loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
train_cfg=dict(),
test_cfg=dict(
flip_test=True,
post_process='default',
shift_heatmap=True,
modulate_kernel=11))
data_cfg = dict(
image_size=[192, 256],
heatmap_size=[48, 64],
num_output_channels=17,
num_joints=17,
dataset_channel=[
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
],
inference_channel=[
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
])
""")
config_path = temp_file.name
# 使用临时配置文件初始化模型
checkpoint = 'https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth'
self.model = init_pose_model(
config_path,
checkpoint,
device='cpu'
)
# 清理临时文件
os.unlink(config_path)
print("模型初始化成功!")
def inference(self, img, device='cpu', show=False):
"""执行姿态估计推理"""
# 确保模型已初始化
if self.model is None:
self.initialize_model()
# 检查图片是否存在
if not os.path.exists(img):
print(f"错误:图片文件 {img} 不存在!")
print("正在下载示例图片...")
self.download_example_image(img)
# 执行推理 - 确保使用正确的边界框格式
pose_results, _ = inference_top_down_pose_model(
self.model,
img,
bbox_thr=0.3,
format='xyxy', # 使用xyxy格式避免格式错误
dataset='coco',
return_heatmap=False
)
return pose_results
def download_example_image(self, save_path):
"""下载示例图片"""
url = "https://raw.githubusercontent.com/open-mmlab/mmpose/main/tests/data/coco/000000000785.jpg"
response = requests.get(url)
if response.status_code == 200:
os.makedirs(os.path.dirname(save_path), exist_ok=True)
with open(save_path, 'wb') as f:
f.write(response.content)
print(f"示例图片已下载至: {save_path}")
else:
print(f"无法下载示例图片,状态码: {response.status_code}")
print("请手动提供图片文件")
# 原始代码保持不变
img_path = '/home/jovyan/1.jpg' # 指定进行推理的图片路径
model = MMPose(backbone='SCNet') # 实例化mmpose模型
result = model.inference(img=img_path, device='cpu', show=False) # 在CPU上进行推理
print(result)报错:注意:SCNet暂不可用,已自动替换为HRNet
load checkpoint from http path: https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth
The model and loaded state dict do not match exactly
size mismatch for keypoint_head.final_layer.weight: copying a param with shape torch.Size([17, 48, 1, 1]) from checkpoint, the shape in current model is torch.Size([17, 256, 1, 1]).
missing keys in source state_dict: keypoint_head.deconv_layers.0.weight, keypoint_head.deconv_layers.1.weight, keypoint_head.deconv_layers.1.bias, keypoint_head.deconv_layers.1.running_mean, keypoint_head.deconv_layers.1.running_var, keypoint_head.deconv_layers.3.weight, keypoint_head.deconv_layers.4.weight, keypoint_head.deconv_layers.4.bias, keypoint_head.deconv_layers.4.running_mean, keypoint_head.deconv_layers.4.running_var, keypoint_head.deconv_layers.6.weight, keypoint_head.deconv_layers.7.weight, keypoint_head.deconv_layers.7.bias, keypoint_head.deconv_layers.7.running_mean, keypoint_head.deconv_layers.7.running_var
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
Cell In[12], line 135
133 img_path = '/home/jovyan/1.jpg' # 指定进行推理的图片路径
134 model = MMPose(backbone='SCNet') # 实例化mmpose模型
--> 135 result = model.inference(img=img_path, device='cpu', show=False) # 在CPU上进行推理
136 print(result)
Cell In[12], line 107, in MMPose.inference(self, img, device, show)
104 self.download_example_image(img)
106 # 执行推理 - 确保使用正确的边界框格式
--> 107 pose_results, _ = inference_top_down_pose_model(
108 self.model,
109 img,
110 bbox_thr=0.3,
111 format='xyxy', # 使用xyxy格式避免格式错误
112 dataset='coco',
113 return_heatmap=False
114 )
116 return pose_results
File /opt/conda/envs/mmedu/lib/python3.8/site-packages/mmpose/apis/inference.py:422, in inference_top_down_pose_model(model, img_or_path, person_results, bbox_thr, format, dataset, dataset_info, return_heatmap, outputs)
420 # Select bboxes by score threshold
421 if bbox_thr is not None:
--> 422 assert bboxes.shape[1] == 5
423 valid_idx = np.where(bboxes[:, 4] > bbox_thr)[0]
424 bboxes = bboxes[valid_idx]
AssertionError:
最新发布