Yolo v8的损失函数使用的是交叉熵损失函数。但是对于难分类的样本检测效果较差,如垃圾、垃圾桶。
Focal Loss 是一种用于解决类别不平衡问题的损失函数,因此尝试将其引入v8结构。
ps:原本v8的loss.py文件中自带了focal loss的函数,但是调用自带代码,报错。原因未知。因此,尝试重新修改Focal-loss部分的代码,并记录过程:
1、找到yolo/utils/loss.py

2、添加新的focal_loss的代码
class FocalLoss(nn.Module):
def __init__(self, alpha=None, gamma=2, num_classes=80, size_average=True):
"""
focal_loss损失函数, -α(1-yi)**γ *ce_loss(xi,yi)
步骤详细的实现了 focal_loss损失函数.
:param alpha: 阿尔法α,类别权重. 当α是列表时,为各类别权重,当α为常数时,类别权重为[α, 1-α, 1-α, ....],常用于 目标检测算法中抑制背景类 , retainnet中设置为0.25
:param gamma: 伽马γ,难易样本调节参数. retainnet中设置为2
:param num_classes: 类别数量
:param size_average: 损失计算方式,默认取均值
"""
super(FocalLoss, self).__init__()
self.size_average = size_average
if alpha is None:
self.alpha = torch.ones(num_classes)
elif isinstance(alpha, list):
assert len(alpha) == num_classes # α可以以list方式输入,size:[num_classes] 用于对不同类别精细地赋予权重
self.alpha = torch.Tensor(alpha)
else:
assert alpha < 1 # 如果α为一个常数,则降低第一类的影响,在目标检测中第一类为背景类
self.alpha = torch.zeros(num_classes)
self.alpha[0] += alpha
self.alpha[1:] += (1 - alpha) # α 最终为 [ α, 1-α, 1-α, 1-α, 1-α, ...] size:[num_classes]
self.gamma = gamma
# print('Focal Loss:')
# print(' Alpha = {}'.format(self.alpha))
# print(' Gamma = {}'.format(self.gamma))
def forward(self, preds, labels):
preds = preds.view(-1, preds.size(-1))
alpha = self.alpha.to(preds.device)
preds_logsoft = F.log_softmax(preds, dim=1) # log_softmax
preds_softmax = torch.exp(preds_logsoft) # softmax
labels = labels.to(torch.int64)
preds_softmax = preds_softmax.gather(1, labels.view(-1, 1)) # 这部分实现nll_loss ( crossempty = log_softmax + nll )
preds_logsoft = preds_logsoft.gather(1, labels.view(-1, 1))
alpha = self.alpha.gather(0, labels.view(-1))
print(alpha.is_cuda, labels.is_cuda)
loss = -torch.mul(torch.pow((1 - preds_softmax), self.gamma),
preds_logsoft) # torch.pow((1-preds_softmax), self.gamma) 为focal loss中 (1-pt)**γ
loss = torch.mul(alpha, loss.t())
if self.size_average:
loss = loss.mean()
else:
loss = loss.sum()
return loss
2、在loss.py中找到DetectionLoss类(我是做检测所以只修改了detection的loss)

3、定义focal_loss变量

3、将分类的loss改为focal_loss

4、正常训练即可
ps:会报int64类型错误,需要在focal_loss类里加上一句
labels = labels.to(torch.int64)
文章介绍了在Yolov8模型中遇到的对于难分类样本(如垃圾、垃圾桶)检测效果不佳的问题。作者尝试使用FocalLoss来改善这种情况,因为FocalLoss能有效处理类别不平衡问题。在实现过程中,作者在loss.py文件中添加了自定义的FocalLoss类,并修改了DetectionLoss类以使用FocalLoss代替原来的交叉熵损失。虽然遇到了类型错误,但通过将标签转换为int64解决了这个问题,从而能够正常进行训练。

2万+





