Focal IoU损失函数
在目标检测任务中,评估预测边界框的质量是一个重要环节。传统的 IoU(Intersection over Union)损失函数虽然能够评估预测边界框与真实边界框的重叠程度,但在某些情况下存在一些问题。例如,当预测边界框与真实边界框的重叠度较低时,IoU 损失函数的梯度会非常小,导致模型难以进行优化。此外,正负样本分布不均衡也是目标检测任务中的一个常见问题,这会影响模型的训练效果。
为了解决这些问题,研究者提出了 Focal IoU Loss。该损失函数旨在平衡高质量样本和低质量样本对损失的贡献,同时强调困难样本在训练过程中的重要性。
设计原理
Focal IoU Loss 的设计原理主要基于以下两个方面:
- 平衡高质量样本和低质量样本对损失的贡献:通过引入一个调节因子,使得高质量样本(即 IoU 值较大的样本)在损失计算中占据更大的权重,而低质量样本(即 IoU 值较小的样本)的权重则相对较小。这样可以在一定程度上避免低质量样本对模型训练的负面影响。
- 强调困难样本在训练过程中的重要性:类似于 Focal Loss,Focal IoU Loss 也为具有挑战性的样本(即 IoU 值较低的样本)分配更高的权重。这样可以使模型在训练过程中更加关注这些困难样本,从而提高模型的泛化能力。
计算步骤
Focal IoU Loss 的计算步骤大致如下:
- 计算预测边界框与真实边界框的 IoU 值。
- 根据 IoU 值的大小,将样本分为高质量样本和低质量样本。具体划分标准可以根据实际情况进行调整。
- 引入一个调节因子 α,用于平衡高质量样本和低质量样本对损失的贡献。α 的取值可以根据实际情况进行调整,通常取值在 0 到 1 之间。
- 对于高质量样本,直接使用原始的 IoU 损失函数进行计算;对于低质量样本,则根据 IoU 值的大小为其分配一个较小的权重。
- 将所有样本的损失进行加权求和,得到最终的 Focal IoU Loss。
添加 Focal IoU损失函数(基于MMYOLO)
由于MMYOLO中没有实现Focal IoU损失函数,所以需要在mmyolo/models/iou_loss.py中添加Focal IoU的计算和对应的iou_mode,修改完以后在终端运行
python setup.py install
再在配置文件中进行修改即可。
包含前面所有IoU损失函数替换的mmyolo/models/iou_loss.py如下(Focal IoU只实现了Focal CIoU,其他同理):
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from mmdet.models.losses.utils import weight_reduce_loss
from mmdet.structures.bbox import HorizontalBoxes
from mmyolo.registry import MODELS
class WIoU_Scale:
''' monotonous: {
None: origin v1
True: monotonic FM v2
False: non-monotonic FM v3
}
momentum: The momentum of running mean'''
iou_mean = 1.
monotonous = False
_momentum = 1 - 0.5 ** (1 / 7000)
_is_train = True
def __init__(self, iou):
self.iou = iou
self._update(self)
@classmethod
def _update(cls, self):
if cls._is_train: cls.iou_mean = (1 - cls._momentum) * cls.iou_mean + \
cls._momentum * self.iou.detach().mean().item()
@classmethod
def _scaled_loss(cls, self, gamma=1.9, delta=3):
if isinstance(self.monotonous, bool):
if self.monotonous:
return (self.iou.detach() / self.iou_mean).sqrt()
else:
beta = self.iou.detach() / self.iou_mean
alpha = delta * torch.pow(gamma, beta - delta)
return beta / alpha
return 1
def bbox_overlaps(pred: torch.Tensor,
target: torch.Tensor,
iou_mode: str = 'ciou',
bbox_format: str = 'xywh',
siou_theta: float = 4.0,
gamma: float = 0.5,
eps: float = 1e-7,
focal: bool = False,) -> torch.Tensor:
r"""Calculate overlap between two set of bboxes.
`Implementation of paper `Enhancing Geometric Factors into
Model Learning and Inference for Object Detection and Instance
Segmentation <https://arxiv.org/abs/2005.03572>`_.
In the CIoU implementation of YOLOv5 and MMDetection, there is a slight
difference in the way the alpha parameter is computed.
mmdet version:
alpha = (ious > 0.5).float() * v / (1 - ious + v)
YOLOv5 version:
alpha = v / (v - ious + (1 + eps)
Args:
pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2)
or (x, y, w, h),shape (n, 4).
target (Tensor): Corresponding gt bboxes, shape (n, 4).
iou_mode (str): Options are ('iou', 'ciou', 'giou', 'siou').
Defaults to "ciou".
bbox_format (str): Options are "xywh" and "xyxy".
Defaults to "xywh".
siou_theta (float): siou_theta for SIoU when calculate shape cost.
Defaults to 4.0.
eps (float): Eps to avoid log(0).
Returns:
Tensor: shape (n, ).
"""
assert iou_mode in ('iou', 'ciou', 'giou', 'siou','diou','eiou','innersiou','innerciou','shapeiou','wiou')
assert bbox_format in ('xyxy', 'xywh')
if bbox_format == 'xywh':
pred = HorizontalBoxes.cxcywh_to_xyxy(pred)
target = HorizontalBoxes.cxcywh_to_xyxy(target)
bbox1_x1, bbox1_y1 = pred[..., 0], pred[..., 1]
bbox1_x2, bbox1_y2 = pred[..., 2], pred[..., 3]
bbox2_x1, bbox2_y1 = target[..., 0], target[..., 1]
bbox2_x2, bbox2_y2 = target[..., 2], target[..., 3]
# Overlap
overlap = (torch.min(bbox1_x2, bbox2_x2) -
torch.max(bbox1_x1, bbox2_x1)).clamp(0) * \
(torch.min(bbox1_y2, bbox2_y2) -
torch.max(bbox1_y1, bbox2_y1)).clamp(0)
# Union
w1, h1 = bbox1_x2 - bbox1_x1, bbox1_y2 - bbox1_y1
w2, h2 = bbox2_x2 - bbox2_x1, bbox2_y2 - bbox2_y1
union = (w1 * h1) + (w2 * h2) - overlap + eps
h1 = bbox1_y2 - bbox1_y1 + eps
h2 = bbox2_y2 - bbox2_y1 + eps
# IoU
ious = overlap / union
# enclose area
enclose_x1y1 = torch.min(pred[..., :2], target[..., :2])
enclose_x2y2 = torch.max(pred[..., 2:], target[..., 2:])
enclose_wh = (enclose_x2y2 - enclose_x1y1).clamp(min=0)
enclose_w = enclose_wh[..., 0] # cw
enclose_h = enclose_wh[..., 1] # ch
if iou_mode == 'ciou':
# CIoU = IoU - ( (ρ^2(b_pred,b_gt) / c^2) + (alpha x v) )
# calculate enclose area (c^2)
enclose_area = enclose_w**2 + enclose_h**2 + eps
# calculate ρ^2(b_pred,b_gt):
# euclidean distance between b_pred(bbox2) and b_gt(bbox1)
# center point, because bbox format is xyxy -> left-top xy and
# right-bottom xy, so need to / 4 to get center point.
rho2_left_item = ((bbox2_x1 + bbox2_x2) - (bbox1_x1 + bbox1_x2))**2 / 4
rho2_right_item = ((bbox2_y1 + bbox2_y2) -
(bbox1_y1 + bbox1_y2))**2 / 4
rho2 = rho2_left_item + rho2_right_item # rho^2 (ρ^2)
# Width and height ratio (v)
wh_ratio = (4