yolov8/yolov10 损失函数代码解析

torch文档:https://pytorch.org/docs/stable/tensors.html

损失函数

代码位置https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/loss.py
假设batch_size=4 number_classes=6 输入img_size=640(训练过程中自动转换为[640,640])

class v8DetectionLoss:
    """Criterion class for computing training losses."""

    def __init__(self, model, tal_topk=10):  # model must be de-paralleled
        """Initializes v8DetectionLoss with the model, defining model-related properties and BCE loss function."""
        device = next(model.parameters()).device  # get model device
        h = model.args  # 超参数

        m = model.model[-1]  # Detect() module
        self.bce = nn.BCEWithLogitsLoss(reduction="none")
        self.hyp = h
        self.stride = m.stride  # 步长 8 16 32 
        # 假设输入图片640x640 
        # 那么对应anchor就是 [640/8,640/8][640/16,640/16][640/32,640/32]
        # ->[80,80][40,40][20,20]->6400+1600+400=8400这里和yolov5计算相同
        self.nc = m.nc  # 类别数量 假设为6方便计算其他值
        self.no = m.nc + m.reg_max * 4	# 每个anchor输出数量 类别+位置*dfl通道数=6+16*4=
        self.reg_max = m.reg_max	# dfl损失通道数,固定16
        self.device = device

        self.use_dfl = m.reg_max > 1

        self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0)
        self.bbox_loss = BboxLoss(m.reg_max).to(device)
        self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)	# [0,1,...15]

    def preprocess(self, targets, batch_size, scale_tensor):
        """Preprocesses the target counts and matches with the input batch size to output a tensor."""
        nl, ne = targets.shape	# nl目标数量  ne=6对应真实的坐标xywh+batch中img的id+类别
        if nl == 0:
            out = torch.zeros(batch_size, 0, ne - 1, device=self.device)
        else:
            i = targets[:, 0]  # 对应batch['batch_idx'] batch中img的id
            _, counts = i.unique(return_counts=True)	# 统计重复元素个数
            counts = counts.to(dtype=torch.int32)
            out = torch.zeros(batch_size, counts.max(), ne - 1, device=self.device)	# 输出张量
            for j in range(batch_size):
                matches = i == j	# 匹配第j个batch中的img_id目标
                n = matches.sum()	# 统计上述个数
                if n:	# 如果gt有目标 (一张训练的图片中有标注目标)
                    out[j, :n] = targets[matches, 1:]	# 取出对应目标的类别和box
            out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor)) # box的真实坐标xyxy
        return out

    def bbox_decode(self, anchor_points, pred_dist):
        """Decode predicted object bounding box coordinates from anchor points and distribution."""
        if self.use_dfl:
            b, a, c = pred_dist.shape  # batch, anchors, channels=m.reg_max*4
            pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
            # 这里假设类别为6 batch_size为4,那么b,a,c=4,8400,64
            # pred_dist.view -> 4,8400,64 -> 4,8400,4,16 最后一维对应dfl通道数
            # softmax(3)对上诉的结果的dim=3(对dfl通道16那维)进行归一化概率(和为1,实际为 e^i/e^(all) )
            # matmul这里直接理解为矩阵乘法  4,4800,4,16 * 16 = 4,4800,4  所以上诉两行代码就是做dfl矩阵乘法
        return dist2bbox(pred_dist, anchor_points, xywh=False)

    def __call__(self, preds, batch):
        """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
        loss = torch.zeros(3, device=self.device)  # box, cls, dfl 这里可以看出损失分为box,分类,dfl
        feats = preds[1] if isinstance(preds, tuple) else preds	
        
        # 相当于 torch.cat((feat[0].view(4,70,-1), (feat[1].view(4,70,-1), (feat[2].view(4,70,-1),).split(16*4,6),1)
        # 说人话就是三种步长的特征值(前向传播结果)连起来再单独把作为分类得分的为一组,其余的为一组  
        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -</
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值