In DETR, once the model is built and a forward pass has produced predictions, we compute the loss between the predictions and the ground truth, back-propagate, and update the model parameters.
DETR matches predictions to ground-truth labels using the Hungarian matching algorithm.
The relevant code lives in three files: models/matcher.py, models/detr.py, and engine.py.
Loss computation: SetCriterion
First, the loss function is defined in detr.py:
criterion = SetCriterion(num_classes, matcher=matcher, weight_dict=weight_dict,
                         eos_coef=args.eos_coef, losses=losses)
criterion.to(device)
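For context, here is a minimal sketch of how the weight dict, including the per-layer entries for the five auxiliary decoder outputs, can be assembled. The coefficient values match the defaults quoted in the comments below, and the `_0` ... `_4` suffixing follows the DETR repo's convention:

# Sketch: weight_dict with auxiliary-loss entries for the 5 intermediate decoder layers.
weight_dict = {'loss_ce': 1, 'loss_bbox': 5, 'loss_giou': 2}
aux_weight_dict = {}
for i in range(5):  # dec_layers - 1 intermediate layers
    aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()})
weight_dict.update(aux_weight_dict)  # 18 entries total: 3 losses x 6 decoder layers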
Then, in train_one_epoch in engine.py, after the forward pass the outputs and ground-truth targets are passed to criterion to compute the loss:
# compute losses; loss_dict holds 'loss_ce' + 'loss_bbox' + 'loss_giou',
# plus 'class_error' + 'cardinality_error' which are only used for logging
loss_dict = criterion(outputs, targets)
# loss weights, e.g. {'loss_ce': 1, 'loss_bbox': 5, 'loss_giou': 2}
weight_dict = criterion.weight_dict
# total loss = regression losses loss_bbox (L1) + loss_giou, plus classification loss loss_ce
losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
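As a quick sanity check, here is a toy, self-contained version of that weighted sum. Note that entries such as 'class_error', which have no key in weight_dict, are skipped and only logged:

import torch

loss_dict = {'loss_ce': torch.tensor(0.8), 'loss_bbox': torch.tensor(0.3),
             'loss_giou': torch.tensor(0.5), 'class_error': torch.tensor(12.5)}
weight_dict = {'loss_ce': 1, 'loss_bbox': 5, 'loss_giou': 2}
losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
print(losses)  # tensor(3.3000) = 0.8*1 + 0.3*5 + 0.5*2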
The contents of outputs:
pred_boxes: torch.Size([2, 100, 4])
pred_logits: torch.Size([2, 100, 7]); "logits" strictly means the raw fully-connected-layer outputs, and here they are the per-class scores (7 = 6 classes + the no-object category)
'aux_outputs' = {list: 5}: five dicts, each with pred_logits + pred_boxes, holding the outputs of the five decoder layers before the last one
The contents of targets:
Since batch_size = 2, there are two images; the first has 4 labels and the second has 21.
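A hypothetical targets list matching this batch could be built as follows. Boxes are normalized (cx, cy, w, h), and the label range assumes 6 foreground classes, since pred_logits has 7 = 6 + no-object channels:

import torch

targets = [
    {'boxes': torch.rand(4, 4),  'labels': torch.randint(0, 6, (4,))},   # image 1: 4 objects
    {'boxes': torch.rand(21, 4), 'labels': torch.randint(0, 6, (21,))},  # image 2: 21 objects
]
# The real dataloader also attaches 'image_id', 'area', 'iscrowd', 'orig_size' and 'size'.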

Now let's walk through what SetCriterion actually does.
First it takes the last decoder layer's outputs for classification and box regression, with shapes torch.Size([2, 100, 7]) and torch.Size([2, 100, 4]): two images, 100 predictions per image, each with 7 class scores and normalized box coordinates (center-x, center-y, w, h):
outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}
The Hungarian algorithm then matches predicted boxes to ground-truth boxes; indices holds the matching result:
indices = self.matcher(outputs_without_aux, targets)

For example, (tensor([41, 51, 75, 87]), tensor([0, 3, 1, 2])) means the first image has 4 labels (ground-truth boxes), and among the 100 predictions, numbers 41, 51, 75, and 87 are the matched ones, paired with ground-truth boxes 0, 3, 1, and 2 respectively.
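Under the hood, the matcher calls scipy's linear_sum_assignment on a per-image cost matrix. Here is a toy sketch with a random cost matrix; the real cost combines classification, L1 box, and GIoU terms:

import numpy as np
from scipy.optimize import linear_sum_assignment

cost = np.random.rand(100, 4)  # [num_queries, num_gt] cost matrix for one image
pred_idx, gt_idx = linear_sum_assignment(cost)
print(pred_idx)  # 4 prediction indices out of 100, e.g. [41 51 75 87]
print(gt_idx)    # the gt box each prediction is matched to, e.g. [0 3 1 2]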
Next it counts the number of ground-truth boxes in the batch, which is used to normalize the losses:
num_boxes = sum(len(t["labels"]) for t in targets)
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
After that, the individual losses are computed.
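To make the classification branch concrete before the full listing, here is a self-contained sketch of how loss_ce is formed: unmatched queries get the no-object class (index num_classes), which empty_weight down-weights by eos_coef. The shapes reuse the example above, except the second image is given a single gt box for brevity:

import torch
import torch.nn.functional as F

num_classes = 6
pred_logits = torch.randn(2, 100, num_classes + 1)
indices = [(torch.tensor([41, 51, 75, 87]), torch.tensor([0, 3, 1, 2])),
           (torch.tensor([7]), torch.tensor([0]))]
targets = [{'labels': torch.tensor([2, 0, 5, 1])}, {'labels': torch.tensor([3])}]

# Every query defaults to the no-object class ...
target_classes = torch.full(pred_logits.shape[:2], num_classes, dtype=torch.int64)
# ... and matched queries take the label of their assigned gt box.
for b, (pred_idx, gt_idx) in enumerate(indices):
    target_classes[b, pred_idx] = targets[b]['labels'][gt_idx]

empty_weight = torch.ones(num_classes + 1)
empty_weight[-1] = 0.1  # eos_coef: down-weight the no-object class
loss_ce = F.cross_entropy(pred_logits.transpose(1, 2), target_classes, empty_weight)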
The full implementation is as follows:
import torch
from torch import nn

from util.misc import get_world_size, is_dist_avail_and_initialized


class SetCriterion(nn.Module):
    """ This class computes the loss for DETR.
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """
    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses):
        """ Create the criterion.
        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            eos_coef: relative classification weight applied to the no-object category
            losses: list of all the losses to be applied. See get_loss for list of available losses.
        """
        super().__init__()
        self.num_classes = num_classes  # number of dataset classes (no-object excluded)
        self.matcher = matcher          # HungarianMatcher(): bipartite matching
        self.weight_dict = weight_dict  # dict: 18 entries = (loss_ce + loss_bbox + loss_giou) x 6 decoder layers
        self.eos_coef = eos_coef        # 0.1
        self.losses = losses            # list: 3 ['labels', 'boxes', 'cardinality']
        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = self.eos_coef  # tensor of size num_classes + 1: all ones, no-object entry = 0.1
        self.register_buffer('empty_weight', empty_weight)

    def forward(self, outputs, targets):
        """ This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format.
                      'pred_logits' = Tensor[bs, 100, num_classes + 1] and 'pred_boxes' = Tensor[bs, 100, 4]
                      come from the last decoder layer; 'aux_outputs' is a list of 5 such dicts,
                      one per intermediate decoder layer.
             targets: list of dicts, such that len(targets) == batch_size. Each image carries
                      'boxes', 'labels', 'image_id', 'area', 'iscrowd', 'orig_size' and 'size'.
                      The expected keys in each dict depend on the losses applied, see each loss' doc.
        """
        # Last decoder layer's outputs only: pred_logits[bs, 100, 7] + pred_boxes[bs, 100, 4]
        outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}

        # Hungarian matching: from the 100 predictions, find the one matched to each of the N gt
        # boxes; the other 100 - N predictions become background. The result is one
        # (prediction indices, gt indices) tuple per image, e.g. (Tensor([5, 35, 63]), Tensor([1, 0, 2])).
        indices = self.matcher(outputs_without_aux, targets)

        # Average number of target boxes across all nodes, for normalization purposes
        # (get_world_size / is_dist_avail_and_initialized live in util.misc)
        num_boxes = sum(len(t["labels"]) for t in targets)
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float,
                                    device=next(iter(outputs.values())).device)
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_boxes)
        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))

        # For the auxiliary losses, repeat the matching and loss computation with the
        # output of each intermediate decoder layer.
        if 'aux_outputs' in outputs:
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    if loss == 'masks':
                        # Intermediate masks losses are too costly to compute, we ignore them.
                        continue
                    kwargs = {'log': False} if loss == 'labels' else {}  # logging only for the last layer
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                    losses.update(l_dict)

        return losses

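For completeness, the box-regression branch (loss_bbox + loss_giou) over the matched pairs can also be sketched standalone. This hedged version swaps the repo's util.box_ops helpers for torchvision.ops equivalents (assuming a recent torchvision) and feeds dummy matched boxes:

import torch
import torch.nn.functional as F
from torchvision.ops import box_convert, generalized_box_iou

num_boxes = 4
src_boxes = torch.rand(num_boxes, 4)     # matched predictions, normalized cxcywh
target_boxes = torch.rand(num_boxes, 4)  # their matched gt boxes, normalized cxcywh

# L1 loss over coordinates, normalized by the number of gt boxes in the batch
loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none').sum() / num_boxes
# GIoU loss: convert to xyxy, take the diagonal (matched pairs only)
giou = generalized_box_iou(box_convert(src_boxes, 'cxcywh', 'xyxy'),
                           box_convert(target_boxes, 'cxcywh', 'xyxy'))
loss_giou = (1 - torch.diag(giou)).sum() / num_boxes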
To summarize: DETR computes its loss with SetCriterion, combining classification, box-regression (L1), and GIoU losses. Hungarian matching decides which predictions correspond to which ground-truth boxes, so only the matched (positive) predictions incur box losses; the classification and localization errors of the matched pairs then drive the gradient updates during training.