关键点检测(8)——yolov8-loss的代码解析

  我们知道yolov8在流行的yolov5的架构上进行了扩展。在多个方面提供了改进。Loss 计算方面采用了 TaskAlignedAssigner 正样本分配策略,并引入了 Distribution Focal Loss。实际上可以看出,YOLOv8 主要参考了最近提出的诸如 YOLOX、YOLOv6、YOLOv7 和 PPYOLOE 等算法的相关设计,本身的创新点不多,偏向工程实践,主推的还是 ultralytics 这个框架本身。

  but anyway,我们还是学习一下这个集大成者。

 前面学习了yolov8的backbone,neck和head操作。特征提取完后就需要进行loss计算了。

1,yolov8-pose loss代码总体解析

  我们以关键点检测为例。来学习Yolov8中训练过程中的损失函数计算。

  正常流程都是模型在训练的时候,即调用forward()时候,会调用self.loss()函数。

  那么我们就首先看看模型的定义,回到代码进行check。因为我们是以keypoint detection为主。所以仍然从Pose开始学习。

        首先通过PoseModel进行寻找,如下所示,PoseModel里面实际上什么也没有写,除了init_criterion()函数。

        我们继续进行check,因为PoseModel 是继承了DetectionModel,而DetectionModel 也是重写了init_criterion()函数,其他也没有有用的信息。

        我们继续check其基类BaseModel,在这里我们发现了猫腻,找到了我们需要的函数loss function:

        我们可以看到,第一次调用self.loss()时,是通过init_criterion()初始化损失函数模块,然后使用self.forward()函数得到预测结果,最后使用self.criterion()函数来计算损失。而所谓的self.criterion()函数就等于 self.init_criterion()。因为我们所用的是Pose模型,那么PoseModel重写了 self.init_criterion()函数,也就是最初的V8PoseLoss。通过V8PoseLoss来计算损失,然后供模型训练的反向传播更新参数使用。

  既然整个流程我们清楚了,那么然后去check 对应的V8PoseLoss函数。

        首先我们可以从下图看到yolov8是将各类任务(图像分类,目标检测,关键点检测,实例分割,旋转目标检测)的损失函数都放在了utils/loss.py中了。

   然后我们进入V8PoseLoss函数中:

class  v8PoseLoss(v8DetectionLoss):
    """Criterion class for computing training losses."""

    def __init__(self, model):  # model must be de-paralleled
        """Initializes v8PoseLoss with model, sets keypoint variables and declares a keypoint loss instance."""
        super().__init__(model)
        self.kpt_shape = model.model[-1].kpt_shape
        self.bce_pose = nn.BCEWithLogitsLoss()
        is_pose = self.kpt_shape == [17, 3]
        nkpt = self.kpt_shape[0]  # number of keypoints
        sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt
        self.keypoint_loss = KeypointLoss(sigmas=sigmas)

    def __call__(self, preds, batch):
        """Calculate the total loss and detach it."""
        loss = torch.zeros(5, device=self.device)  # box, cls, dfl, kpt_location, kpt_visibility
        feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
            (self.reg_max * 4, self.nc), 1
        )

        # B, grids, ..
        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
        pred_kpts = pred_kpts.permute(0, 2, 1).contiguous()

        dtype = pred_scores.dtype
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # Targets
        batch_size = pred_scores.shape[0]
        batch_idx = batch["batch_idx"].view(-1, 1)
        targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
        targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
        gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)

        # Pboxes
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
        pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape))  # (b, h*w, 17, 3)

        _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
            pred_scores.detach().sigmoid(),
            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )

        target_scores_sum = max(target_scores.sum(), 1)

        # Cls loss
        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
        loss[3] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

        # Bbox loss
        if fg_mask.sum():
            target_bboxes /= stride_tensor
            loss[0], loss[4] = self.bbox_loss(
                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
            )
            keypoints = batch["keypoints"].to(self.device).float().clone()
            keypoints[..., 0] *= imgsz[1]
            keypoints[..., 1] *= imgsz[0]

            loss[1], loss[2] = self.calculate_keypoints_loss(
                fg_mask, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
            )

        loss[0] *= self.hyp.box  # box gain
        loss[1] *= self.hyp.pose  # pose gain
        loss[2] *= self.hyp.kobj  # kobj gain
        loss[3] *= self.hyp.cls  # cls gain
        loss[4] *= self.hyp.dfl  # dfl gain

        return loss.sum() * batch_size, loss.detach()  # loss(box, cls, dfl)

   我们简单梳理一下这个V8PoseLoss的类。它首先是继承V8DetectionLoss这个类。其主要用于计算姿态估计任务中的损失。包括边界框(bbox)损失,分类(cls)损失,分布焦点损失(DFL),关键点位置损失(kpt_location)和关键点可见性损失(kpt_visibility)。

  其初始化方法主要是一些定义:

  • 设置关键点形状:从模型中获取关键点的形状kpt_shape。
  • 创建BCE损失函数:用于关键点可见性的二元交叉熵损失。
  • 判断是否为姿态任务:如果kpt_shape 是[17, 3],则认为是姿态任务(注意:这个个人理解主要来来查看sigmas的,如果不存在,那么默认设置sigmas)
  • 设置关键点数量:nkpt是关键点的数量
  • 设置关键点损失的权重:如果是姿态任务,使用预定义的OKS_SIGMA;否则使用均匀权重(
评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

战争热诚

如果帮助到你,可以请我喝杯咖啡

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值