DIoU损失函数
论文链接:https://arxiv.org/pdf/1911.08287
DIoU损失函数(Distance Intersection over Union Loss)是一种在目标检测任务中常用的损失函数,用于优化边界框的位置。这种损失函数是IoU损失函数的改进版,其不仅考虑了边界框之间的重叠区域,还考虑了它们中心点之间的距离,从而提供更加精确的位置优化。以下是DIoU损失函数的设计原理和计算步骤的详细介绍:
设计原理
一、IoU的局限性
- IoU(Intersection over Union)损失函数主要基于预测框和真实框之间的交并比,这个比例值越大表示预测框越接近真实框。
- 但IoU损失函数在预测框和真实框没有重叠时无法提供有效的梯度信息,这限制了模型的学习效率。
二、DIoU的引入
- DIoU损失在IoU的基础上增加了中心点距离的考量,这使得即使在两个框不重叠的情况下也能有效地进行梯度下降。
- 通过考虑框的几何中心距离,DIoU损失有助于减少边界框的尺寸误差,并加速收敛。
计算步骤
一、计算IoU
- 计算两个边界框A和B的交集面积I。
- 计算两个边界框的并集面积U。
- IoU计算公式为:
二、计算框的中心距离
- 设预测框的中心为
,真实框的中心为
。
- 中心点距离 d 的计算公式为:
。
三、计算归一化中心距离
- 计算包围预测框和真实框的最小闭合矩形(称为最小闭合框),并求出其对角线长度。
- 归一化中心距离为
,这样可以确保距离的比例适应不同大小的边界框。
四、计算DIoU
- DIoU损失函数定义为:
- 其中,
表示中心点距离的归一化平方,这样确保了距离项在损失函数中占有合适的权重。
使用PyTorch实现DIoU计算的源代码
import torch
def diou_loss(pred_boxes, gt_boxes):
"""
计算 DIoU 损失。
:param pred_boxes: 预测的边界框,形状为 (batch_size, 4),格式为 (x1, y1, x2, y2)
:param gt_boxes: 真实的边界框,形状为 (batch_size, 4),格式为 (x1, y1, x2, y2)
:return: DIoU 损失值
"""
# 计算交集的坐标
inter_x1 = torch.max(pred_boxes[:, 0], gt_boxes[:, 0])
inter_y1 = torch.max(pred_boxes[:, 1], gt_boxes[:, 1])
inter_x2 = torch.min(pred_boxes[:, 2], gt_boxes[:, 2])
inter_y2 = torch.min(pred_boxes[:, 3], gt_boxes[:, 3])
# 计算交集的面积
inter_area = torch.clamp(inter_x2 - inter_x1, min=0) * torch.clamp(inter_y2 - inter_y1, min=0)
# 计算预测框和真实框的面积
pred_area = (pred_boxes[:, 2] - pred_boxes[:, 0]) * (pred_boxes[:, 3] - pred_boxes[:, 1])
gt_area = (gt_boxes[:, 2] - gt_boxes[:, 0]) * (gt_boxes[:, 3] - gt_boxes[:, 1])
# 计算并集的面积
union_area = pred_area + g