DETR源码学习(三)之损失函数与后处理

DETR模型使用SetCriterion计算损失,包括分类、回归和GIoU损失。匈牙利匹配算法用于确定预测框与真实框的匹配,以计算正样本的损失。损失函数考虑了预测框与目标框的匹配,以及匹配后的分类和定位误差。在训练过程中,匹配后的预测框与真实框的损失会被用来更新模型的梯度。

在DETR模型中,在完成DETR模型的构建后,我们送入数据在完成前向传播后就需要使用预测值与真实值进行计算损失来进行反向传播进而更新梯度。
在DETR模型中,其标签匹配采用的是匈牙利匹配算法。
主要涉及models/matcher.py、models/detr.py和engine.py三个文件。

损失计算:SetCriterion

首先在detr.py中会先定义好损失函数:

criterion = SetCriterion(num_classes, matcher=matcher, weight_dict=weight_dict,eos_coef=args.eos_coef, losses=losses)
criterion.to(device)

然后在engine.py的train_one_epoch中前向推理结束后将输出结果与真实标签作为参数并调用criterion函数,计算损失:

# 计算损失  loss_dict: 'loss_ce' + 'loss_bbox' + 'loss_giou'    用于log日志: 'class_error' + 'cardinality_error'
loss_dict = criterion(outputs, targets)
# 权重系数 {'loss_ce': 1, 'loss_bbox': 5, 'loss_giou': 2}
weight_dict = criterion.weight_dict   
# 总损失 = 回归损失:loss_bbox(L1)+loss_bbox  +   分类损失:loss_ce
losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)

outputs的值
pred_boxes:torch.Size([2, 100, 4])
pred_logits:torch.Size([2, 100, 7]),这里的logits本为表示全连接层的输出,这里指的就是标签。
‘aux_output’={list:5} 0-4 每个都是dict:2 pred_logits+pred_boxes 表示5个decoder前面层的输出

target的值

,batch-size=2,故其图片有两张,第一张中有标签4个,第二张中有标签21个。

在这里插入图片描述
下面来看一下SetCriterion的具体工作:

首先获取decoder最后一层的输出结果,分别对应类别预测与box预测,分别为torch.Size([2, 100, 7])与torch.Size([2, 100, 4]),即共两张图片,每张图片输出100个预测结果,分别为7个类别信息和box信息:x,y,w,h

outputs_without_aux = {
   
   k: v for k, v in outputs.items() if k != 'aux_outputs'}

使用匈牙利匹配算法来进行预测框与真实框匹配,这里可以看到通过匈牙利匹配的结果:

indices = self.matcher(outputs_without_aux, targets)

在这里插入图片描述

如:(tensor([41, 51, 75, 87]), tensor([0, 3, 1, 2]))即第一张图片中总共有4个标签(真实框),那么匹配上的100个预测结果的编号为:41, 51, 75, 87。
随后统计该batch内的标签数量:

num_boxes = sum(len(t["labels"]) for t in targets)
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)

随后便是计算损失值了。

完整代码实现如下:

class SetCriterion(nn.Module):
    """ This class computes the loss for DETR.
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """
    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses):
        """ Create the criterion.
        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            eos_coef: relative classification weight applied to the no-object category
            losses: list of all the losses to be applied. See get_loss for list of available losses.
        """
        super().__init__()
        self.num_classes = num_classes     # 数据集类别数
        self.matcher = matcher             # HungarianMatcher()  匈牙利算法 二分图匹配
        self.weight_dict = weight_dict     # dict: 18  3x6  6个decoder的损失权重   6*(loss_ce+loss_giou+loss_bbox)
        self.eos_coef = eos_coef           # 0.1
        self.losses = losses               # list: 3  ['labels', 'boxes', 'cardinality']
        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = self.eos_coef   # tensro: 92   前91=1  92=eos_coef=0.1
        self.register_buffer('empty_weight', empty_weight)
        
    def forward(self, outputs, targets):
        """ This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
                      dict: 'pred_logits'=Tensor[bs, 100, 92个class]  'pred_boxes'=Tensor[bs, 100, 4]  最后一个decoder层输出
                             'aux_output'={list:5}  0-4  每个都是dict:2 pred_logits+pred_boxes 表示5个decoder前面层的输出
             targets: list of dicts, such that len(targets) == batch_size.   list: bs
                      每张图片包含以下信息:'boxes'、'labels'、'image_id'、'area'、'iscrowd'、'orig_size'、'size'
                      The expected keys in each dict depends on the losses applied, see each loss' doc
        """
        # dict: 2   最后一个decoder层输出  pred_logits[bs, 100, 7个class] + pred_boxes[bs, 100, 4]
        outputs_without_aux = {
   
   k: v for k, v in outputs.items() if k != 'aux_outputs'}

        # 匈牙利算法  解决二分图匹配问题  从100个预测框中找到和N个gt框一一对应的预测框  其他的100-N个都变为背景
        # Retrieve the matching between the outputs of the last layer and the targets  list:1
        # tuple: 2    0=Tensor3=Tensor[5, 35, 63]  匹配到的3个预测框  其他的97个预测框都是背景
        #             1=Tensor3=Tensor[1, 0, 2]    对应的三个gt框
        indices = self
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

彭祥.

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值