### 关于人体姿态估计的代码实现
#### Stacked Hourglass Networks 实现
Stacked Hourglass Networks 是一种用于人体姿态估计的经典网络结构,其核心思想是通过堆叠多个 hourglass 模块来逐步细化特征图并预测关键点位置。以下是该模型的一个简单实现框架:
```python
import torch
import torch.nn as nn
class Residual(nn.Module):
    """Pre-activation bottleneck residual block (1x1 -> 3x3 -> 1x1 convolutions).

    Each convolution is preceded by batch normalisation and ReLU. Without a
    non-linearity between them, three stacked convolutions compose into a
    single linear map and the bottleneck adds no representational power.

    Args:
        num_in: Number of input channels.
        num_out: Number of output channels. The bottleneck width is
            ``num_out // 2``.

    Attributes:
        need_skip: ``True`` when ``num_in != num_out``; the shortcut then uses
            a 1x1 convolution to match the channel count.
        skip_layer: The shortcut projection. It is an ``nn.Identity`` (no
            parameters) when no projection is needed.
    """

    def __init__(self, num_in, num_out):
        super().__init__()
        mid = num_out // 2
        self.relu = nn.ReLU(inplace=True)
        self.bn = nn.BatchNorm2d(num_in)
        self.conv1 = nn.Conv2d(num_in, mid, kernel_size=1, bias=True)
        self.bn2 = nn.BatchNorm2d(mid)
        self.conv2 = nn.Conv2d(mid, mid, kernel_size=3, stride=1, padding=1, bias=True)
        self.bn3 = nn.BatchNorm2d(mid)
        self.conv3 = nn.Conv2d(mid, num_out, kernel_size=1, bias=True)
        self.need_skip = num_in != num_out
        # Only allocate the projection when the channel count changes, so an
        # identity shortcut does not carry unused parameters.
        if self.need_skip:
            self.skip_layer = nn.Conv2d(num_in, num_out, kernel_size=1, bias=True)
        else:
            self.skip_layer = nn.Identity()

    def forward(self, x):
        """Return ``F(x) + shortcut(x)``; spatial size is unchanged."""
        residual = self.skip_layer(x)
        # The in-place ReLU only touches the fresh batch-norm outputs, never x.
        out = self.conv1(self.relu(self.bn(x)))
        out = self.conv2(self.relu(self.bn2(out)))
        out = self.conv3(self.relu(self.bn3(out)))
        return out + residual
def make_pool(layer, k):
    """Build a chain of ``k`` 2x2 max-pooling stages.

    Args:
        layer: Accepted for call compatibility; its value is never read.
        k: Number of ``nn.MaxPool2d(kernel_size=2, stride=2)`` stages to chain.
            Each stage halves the spatial size, so the chain downsamples by
            ``2 ** k``.

    Returns:
        An ``nn.Sequential`` holding the ``k`` pooling stages (empty, and
        therefore a pass-through, when ``k`` is 0).
    """
    stages = [nn.MaxPool2d(kernel_size=2, stride=2) for _ in range(k)]
    return nn.Sequential(*stages)
class HourGlass(nn.Module):
    """Recursive hourglass module: downsample, process, upsample, add skip.

    Every recursion level halves the resolution exactly once and doubles it
    exactly once, so the output has the same shape as the input. Pooling by
    more than 2 per level would make the skip branch and the upsampled branch
    disagree in size and break the final addition.

    Args:
        n: Recursion depth (number of down/up-sampling levels), >= 1.
        f: Number of input and output channels.
        bn: Unused here; only forwarded to the nested hourglass.
        increase: Extra channels used inside the lower branch of this level
            (``nf = f + increase``). It is not forwarded to nested levels.

    Note:
        Input height and width must be divisible by ``2 ** n``; otherwise the
        floor in max-pooling makes the two branches differ in size.
    """

    def __init__(self, n, f, bn=None, increase=0):
        super().__init__()
        nf = f + increase
        # Skip branch at the current resolution.
        self.up1 = Residual(f, f)
        # Lower branch: a single 2x downsample per level, matching the single
        # 2x upsample in ``up2``.
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.low1 = Residual(f, nf)
        if n > 1:
            self.low2 = HourGlass(n - 1, nf, bn=bn)
        else:
            self.low2 = Residual(nf, nf)
        self.low3 = Residual(nf, f)
        self.up2 = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x):
        """Return the sum of the skip branch and the upsampled lower branch."""
        up1 = self.up1(x)
        low = self.pool1(x)
        low = self.low1(low)
        low = self.low2(low)
        low = self.low3(low)
        return up1 + self.up2(low)
class StackedHourGlassModel(nn.Module):
    """Stack of hourglass modules with intermediate heat-map supervision.

    Args:
        nstack: Number of stacked hourglass stages.
        inp_dim: Channel width of the features flowing through the stack.
        oup_dim: Number of output heat-map channels per stage.
        bn: Forwarded to ``HourGlass``.
        increase: Forwarded to ``HourGlass`` as its ``increase`` argument.

    The forward pass returns the predictions of every stage stacked along a
    new dimension 1, i.e. a tensor indexed as
    ``(batch, stage, oup_dim, height, width)``.
    """

    def __init__(self, nstack, inp_dim, oup_dim, bn=False, increase=0):
        super().__init__()
        self.nstack = nstack
        # Stem: a stride-2 7x7 convolution over a 3-channel image, followed by
        # a residual block that projects the 64 stem channels to ``inp_dim``.
        # Without that projection the hourglasses would receive 64 channels
        # and fail for any ``inp_dim`` other than 64.
        self.pre = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            Residual(64, inp_dim),
        )
        self.hgs = nn.ModuleList([
            nn.Sequential(HourGlass(4, inp_dim, bn, increase))
            for _ in range(nstack)
        ])
        self.features = nn.ModuleList([
            nn.Sequential(
                Residual(inp_dim, inp_dim),
                nn.Conv2d(inp_dim, inp_dim, kernel_size=1, bias=True),
                nn.BatchNorm2d(inp_dim),
                nn.ReLU(True),
            )
            for _ in range(nstack)
        ])
        # Per-stage prediction heads.
        self.outs = nn.ModuleList([
            nn.Conv2d(inp_dim, oup_dim, kernel_size=1, bias=True)
            for _ in range(nstack)
        ])
        # 1x1 convolutions that feed a stage's features and predictions back
        # into the input of the next stage (none needed after the last stage).
        self.merge_features = nn.ModuleList([
            nn.Conv2d(inp_dim, inp_dim, kernel_size=1, bias=True)
            for _ in range(nstack - 1)
        ])
        self.merge_preds = nn.ModuleList([
            nn.Conv2d(oup_dim, inp_dim, kernel_size=1, bias=True)
            for _ in range(nstack - 1)
        ])

    def forward(self, x):
        """Run all stages and return their predictions stacked on dim 1."""
        x = self.pre(x)
        combined_hm_preds = []
        for i in range(self.nstack):
            hg = self.hgs[i](x)
            feature = self.features[i](hg)
            preds = self.outs[i](feature)
            combined_hm_preds.append(preds)
            if i < self.nstack - 1:
                # Re-inject this stage's predictions and features so the next
                # stage can refine them.
                x = x + self.merge_preds[i](preds) + self.merge_features[i](feature)
        return torch.stack(combined_hm_preds, dim=1)
```
上述代码展示了如何构建一个简单的 Stacked Hourglass Network 来完成人体姿态估计的任务[^1]。
---
#### Lightweight Human Pose Estimation (PyTorch) 实现
`lightweight-human-pose-estimation.pytorch` 提供了一个轻量级的人体姿态估计工具包,适用于资源受限环境下的实时应用。以下是一个基本的使用流程:
##### 安装依赖项
首先安装必要的库:
```bash
pip install torch torchvision opencv-python matplotlib numpy
```
##### 下载预训练权重文件
先克隆项目仓库（下文的推理代码依赖仓库中的 `models`、`modules` 和 `val` 模块，需要在仓库根目录下运行），然后按照仓库说明下载预训练模型权重文件 `checkpoint_iter_370000.pth`，并保存到运行脚本的目录（即仓库根目录）下[^2]。
##### 推理代码示例
下面是一段推理代码片段,展示如何加载模型并对单张图片进行姿态估计:
```python
import cv2
import numpy as np
import torch

from models.with_mobilenet import PoseEstimationWithMobileNet
from modules.keypoints import extract_keypoints, group_keypoints
from modules.load_state import load_state
from val import normalize, pad_width
# Load the model
net = PoseEstimationWithMobileNet()
checkpoint_path = 'checkpoint_iter_370000.pth'
# NOTE(review): upstream `load_state(net, checkpoint)` indexes the
# deserialized checkpoint dict, so the file must be read with torch.load
# first rather than passing the path -- confirm against the checked-out repo.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
load_state(net, checkpoint)
# Inference mode: freezes batch-norm statistics and disables dropout.
net.eval()
# Single-image inference: preprocess, run the network, decode keypoints
def infer_fast(img, net, height_size=256, stride=8, upsample_ratio=4):
    """Estimate poses in one image.

    Args:
        img: Image array as returned by ``cv2.imread`` (height, width, 3).
        net: The pose network; called on a ``(1, 3, H, W)`` float tensor.
        height_size: Target size passed to ``resize_image``.
        stride: Network output stride; the input is padded to a multiple of
            it and it is used to map heat-map coordinates back to pixels.
        upsample_ratio: Factor by which heat maps and PAFs are enlarged
            before keypoint extraction.

    Returns:
        ``(pose_entries, all_keypoints)`` as produced by ``group_keypoints``,
        with the x/y columns of ``all_keypoints`` converted to coordinates of
        the original ``img`` so they can be drawn on it directly.

    NOTE(review): the calls to ``normalize``, ``pad_width``,
    ``extract_keypoints`` and ``group_keypoints`` follow the signatures in
    the upstream lightweight-human-pose-estimation.pytorch repository
    (``val.py``, ``modules/keypoints.py``) -- confirm against the version
    you have checked out.
    """
    # Preprocessing constants used by the upstream demo -- TODO confirm.
    img_mean = np.array([128, 128, 128], np.float32)
    img_scale = np.float32(1 / 256)
    pad_value = (0, 0, 0)
    num_keypoints = 18

    img_resized, scale, _ = resize_image(img, height_size)
    normalized = normalize(img_resized, img_mean, img_scale)
    # Pad so both sides are multiples of the network stride.
    # Presumably pad is (top, left, bottom, right) -- verify against val.py.
    min_dims = [normalized.shape[0], normalized.shape[1]]
    padded, pad = pad_width(normalized, stride, pad_value, min_dims)

    # HWC numpy image -> NCHW float tensor on the same device as the model.
    tensor_img = torch.from_numpy(padded).permute(2, 0, 1).unsqueeze(0).float()
    tensor_img = tensor_img.to(next(net.parameters()).device)
    with torch.no_grad():
        stages_output = net(tensor_img)

    def _to_upsampled_hwc(stage_tensor):
        # (1, C, h, w) tensor -> (h * ratio, w * ratio, C) numpy array.
        arr = np.transpose(stage_tensor.squeeze(0).cpu().numpy(), (1, 2, 0))
        return cv2.resize(arr, None, fx=upsample_ratio, fy=upsample_ratio,
                          interpolation=cv2.INTER_CUBIC)

    # The last two network outputs are taken as heat maps and PAFs.
    heatmaps = _to_upsampled_hwc(stages_output[-2])
    pafs = _to_upsampled_hwc(stages_output[-1])

    total_keypoints_num = 0
    all_keypoints_by_type = []
    for keypoint_idx in range(num_keypoints):  # iterate over keypoint types
        # extract_keypoints appends to all_keypoints_by_type and returns the
        # number of keypoints it found for this type.
        total_keypoints_num += extract_keypoints(
            heatmaps[:, :, keypoint_idx], all_keypoints_by_type, total_keypoints_num)

    pose_entries, all_keypoints = group_keypoints(all_keypoints_by_type, pafs)

    # Map keypoints from up-sampled heat-map space back to the original image:
    # undo the up-sampling and stride, remove the padding, undo the resize.
    if all_keypoints.size:
        all_keypoints[:, 0] = (all_keypoints[:, 0] * stride / upsample_ratio - pad[1]) / scale
        all_keypoints[:, 1] = (all_keypoints[:, 1] * stride / upsample_ratio - pad[0]) / scale
    return pose_entries, all_keypoints
def resize_image(image, input_height):
    """Resize ``image`` so that its shorter side equals ``input_height``.

    Args:
        image: Array with a three-element ``shape`` (height, width, channels).
        input_height: Target length, in pixels, of the shorter side.

    Returns:
        A tuple ``(resized_img, scale, (h, w))``: the resized image, the
        scale factor that was applied, and the original height and width.
    """
    orig_h, orig_w, _ = image.shape
    scale = float(input_height) / min(orig_h, orig_w)
    # cv2.resize takes the target size as (width, height).
    target_size = (
        int(round(float(orig_w) * scale)),
        int(round(float(orig_h) * scale)),
    )
    resized_img = cv2.resize(image, dsize=target_size, interpolation=cv2.INTER_LINEAR)
    return resized_img, scale, (orig_h, orig_w)
# Run on a single test image
image_path = 'test.jpg'  # replace with a real path
img = cv2.imread(image_path)
# cv2.imread returns None instead of raising when the file cannot be read.
if img is None:
    raise FileNotFoundError(f'Could not read image: {image_path}')
pose_entries, all_keypoints = infer_fast(img, net)
# NOTE(review): assumes the upstream layout -- the first 18 values of each
# pose entry are row indices into all_keypoints (negative when the joint was
# not found), and each all_keypoints row starts with (x, y). Confirm against
# modules/keypoints.py.
for pose in pose_entries:
    for kp_index in pose[:18]:
        if kp_index < 0:
            continue
        x, y = all_keypoints[int(kp_index), :2]
        cv2.circle(img, (int(x), int(y)), radius=3, color=(0, 255, 255))
cv2.imshow('Result', img)
cv2.waitKey(0)
cv2.destroyAllWindows()
```
此代码实现了从输入图像中提取骨架的关键点,并绘制在原图上显示结果[^3]。
---
#### 总结
以上分别介绍了两种不同方式的人体姿态估计实现方案:一是经典的 Stacked Hourglass Networks 方法;二是更高效的 Lightweight Human Pose Estimation 工具包。两者各有优劣,在具体应用场景中可根据需求选择合适的算法。