一、背景
工作项目需要用到关键点检测,yolo系列在工业视觉领域应用相对广泛,技术积累相对成熟(翻译:出bug方便搜),在检测速度和准确率方面也表现优异,毫无疑问成为首选。作为彼时yolo家族的最新款,本人也想尝尝鲜。
原始的yolov7-pose是针对人体17个关键点的检测模型,因此需要对其改造以适应项目需求(项目的使用场景中只需要3个关键点),开始漫长的改造和调试之路。
源码:GitHub - WongKinYiu/yolov7 at pose

二、具体工作
1. 准备数据集
数据标注工具使用的是labelme,生成标注文件,由于labelme无法直接生成yolo格式的标签,因此需要将其生成的coco格式的json文件转换为yolo格式的txt文件。
yolo格式的关键点标签文件内容如下:

前五个为目标类别和标注框中心点和长宽,根据图像宽高统一进行归一化,后面即为关键点坐标和关键点可见度
,
和
同样进行了归一化操作,
,0表示该点不存在,1表示该点不可见(遮挡),2表示该点可见。
转换代码网上很多,不多做介绍,主要参考[1]。
数据集生成之后,在data/coco_kpts.yaml中修改自己的数据集路径及类别信息。

2. 代码修改
数据集准备好之后就可以进行模型训练了,如果需要预训练权重可以从作者github上下载。由于官方提供的是针对人体关键点的检测代码,且并没有提供一个统一的关键点数量修改入口,在数据集加载、损失函数计算、画图等文件里默认都是1个类别和17个关键点,所以训练之前需要自己修改相关代码文件,将其中与关键点数量相关的部分全都修改为自己需要的关键点数量。
(1)cfg/yolov7-w6-pose.yaml
模型结构和anchor的配置文件,将里面的nc和nkpt修改为自己需要检测的类别数和关键点数量。

(2)model/yolo.py
Detect()对应的是yolo的head部分,处理模型最终的预测值,将其转换为我们最终需要的预测坐标和类别信息等。此脚本中作者在yolov5的Detect()中加入了对关键点预测信息的处理方法,并新增了IDetect()和IKeypoint(),三者之间大同小异,用哪个都可以,以下用IKeypoint()为例进行修改。
Line265-266,坐标种类划分
# 将6改为5+nc
x_det = x[i][..., :5 + self.nc]
x_kpt = x[i][..., 5 + self.nc:]
Line283-284,关键点坐标处理
# 17改为nkpt
x_kpt[..., 0::3] = (x_kpt[..., ::3] * 2. - 0.5 + kpt_grid_x.repeat(1,1,1,1,self.nkpt)) * self.stride[i] # xy
x_kpt[..., 1::3] = (x_kpt[..., 1::3] * 2. - 0.5 + kpt_grid_y.repeat(1,1,1,1,self.nkpt)) * self.stride[i] # xy
此部分修改仅在模型推理的后处理阶段有效,训练过程中并不会用到。
(3)utils/datasets.py
在create_dataloader(),LoadImagesAndLabels(),random_perspective()中添加参数nkpt
def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix='', tidl_load=False,
kpt_label=False, nkpt=0):
class LoadImagesAndLabels(Dataset): # for training/testing
def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
cache_images=False, single_cls=False, stride=32, pad=0.0, prefix='', square=False, tidl_load=False,
kpt_label=True, nkpt=0):
def random_perspective(img, targets=(), segments=(), degrees=10, translate=.1, scale=.1, shear=10, perspect

最低0.47元/天 解锁文章
1997





