我们知道yolov8在流行的yolov5的架构上进行了扩展。在多个方面提供了改进。Loss 计算方面采用了 TaskAlignedAssigner 正样本分配策略,并引入了 Distribution Focal Loss。实际上可以看出,YOLOv8 主要参考了最近提出的诸如 YOLOX、YOLOv6、YOLOv7 和 PPYOLOE 等算法的相关设计,本身的创新点不多,偏向工程实践,主推的还是 ultralytics 这个框架本身。
but anyway,我们还是学习一下这个集大成者。
前面学习了yolov8的backbone,neck和head操作。特征提取完后就需要进行loss计算了。
1,yolov8-pose loss代码总体解析
我们以关键点检测为例。来学习Yolov8中训练过程中的损失函数计算。
正常流程都是模型在训练的时候,即调用forward()时候,会调用self.loss()函数。
那么我们就首先看看模型的定义,回到代码进行check。因为我们是以keypoint detection为主。所以仍然从Pose开始学习。
首先通过PoseModel进行寻找,如下所示,PoseModel里面实际上什么也没有写,除了init_criterion()函数。
我们继续进行check,因为PoseModel 是继承了DetectionModel,而DetectionModel 也是重写了init_criterion()函数,其他也没有有用的信息。
我们继续check其基类BaseModel,在这里我们发现了猫腻,找到了我们需要的函数loss function:
我们可以看到,第一次调用self.loss()时,是通过init_criterion()初始化损失函数模块,然后使用self.forward()函数得到预测结果,最后使用self.criterion()函数来计算损失。而所谓的self.criterion()函数就等于 self.init_criterion()。因为我们所用的是Pose模型,那么PoseModel重写了 self.init_criterion()函数,也就是最初的V8PoseLoss。通过V8PoseLoss来计算损失,然后供模型训练的反向传播更新参数使用。
既然整个流程我们清楚了,那么然后去check 对应的V8PoseLoss函数。
首先我们可以从下图看到yolov8是将各类任务(图像分类,目标检测,关键点检测,实例分割,旋转目标检测)的损失函数都放在了utils/loss.py中了。
然后我们进入V8PoseLoss函数中:
class v8PoseLoss(v8DetectionLoss):
"""Criterion class for computing training losses."""
def __init__(self, model): # model must be de-paralleled
"""Initializes v8PoseLoss with model, sets keypoint variables and declares a keypoint loss instance."""
super().__init__(model)
self.kpt_shape = model.model[-1].kpt_shape
self.bce_pose = nn.BCEWithLogitsLoss()
is_pose = self.kpt_shape == [17, 3]
nkpt = self.kpt_shape[0] # number of keypoints
sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt
self.keypoint_loss = KeypointLoss(sigmas=sigmas)
def __call__(self, preds, batch):
"""Calculate the total loss and detach it."""
loss = torch.zeros(5, device=self.device) # box, cls, dfl, kpt_location, kpt_visibility
feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
(self.reg_max * 4, self.nc), 1
)
# B, grids, ..
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
pred_distri = pred_distri.permute(0, 2, 1).contiguous()
pred_kpts = pred_kpts.permute(0, 2, 1).contiguous()
dtype = pred_scores.dtype
imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
# Targets
batch_size = pred_scores.shape[0]
batch_idx = batch["batch_idx"].view(-1, 1)
targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
# Pboxes
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape)) # (b, h*w, 17, 3)
_, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
pred_scores.detach().sigmoid(),
(pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
anchor_points * stride_tensor,
gt_labels,
gt_bboxes,
mask_gt,
)
target_scores_sum = max(target_scores.sum(), 1)
# Cls loss
# loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
loss[3] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
# Bbox loss
if fg_mask.sum():
target_bboxes /= stride_tensor
loss[0], loss[4] = self.bbox_loss(
pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
)
keypoints = batch["keypoints"].to(self.device).float().clone()
keypoints[..., 0] *= imgsz[1]
keypoints[..., 1] *= imgsz[0]
loss[1], loss[2] = self.calculate_keypoints_loss(
fg_mask, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
)
loss[0] *= self.hyp.box # box gain
loss[1] *= self.hyp.pose # pose gain
loss[2] *= self.hyp.kobj # kobj gain
loss[3] *= self.hyp.cls # cls gain
loss[4] *= self.hyp.dfl # dfl gain
return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
我们简单梳理一下这个V8PoseLoss的类。它首先是继承V8DetectionLoss这个类。其主要用于计算姿态估计任务中的损失。包括边界框(bbox)损失,分类(cls)损失,分布焦点损失(DFL),关键点位置损失(kpt_location)和关键点可见性损失(kpt_visibility)。
其初始化方法主要是一些定义:
- 设置关键点形状:从模型中获取关键点的形状kpt_shape。
- 创建BCE损失函数:用于关键点可见性的二元交叉熵损失。
- 判断是否为姿态任务:如果kpt_shape 是[17, 3],则认为是姿态任务(注意:这个个人理解主要来来查看sigmas的,如果不存在,那么默认设置sigmas)
- 设置关键点数量:nkpt是关键点的数量
- 设置关键点损失的权重:如果是姿态任务,使用预定义的OKS_SIGMA;否则使用均匀权重(