【Focal Loss】简单理解 及 Pytorch 代码 Focal Loss for Dense Object Detection

本文深入探讨Focal Loss与Lovasz Softmax在解决类别不平衡问题上的应用,解析其数学原理与代码实现,尤其针对图像分割任务,提供详实的Python代码示例。
部署运行你感兴趣的模型镜像

 

一、首先回顾下“交叉熵loss Cross Entropy Loss” 

      CE(Pi)=-log(Pi)

二、一般地说,我们数据集会存在类别不平衡问题,很多人会在loss上对应不同类别设置不同系数

                    

       loss就变成了上面的样子

三、Focal loss其实就是通过数学公式上的改变,扩大了不平衡因素在loss上的影响

                           

        这里的at 依然是和“二里面的”那个手动的一个意思!,这是他们文中的意思!我们来分析下!首先,这个的存在,无论为“预测器”分对分错,都会减小loss的值!!!

                        

          从上面一段话,我们可以看出,当某一个类别分类预测概率越来越高(分类准确的情况下),那么loss会因为的存在而被大大降低!这也是我们所希望的!!!当一个类别预测概率小于0.5,也就是基本上被判断到了别的类别去了!分错了!loss相对于“分对了”的情况而言最多只能减少4倍!!!loss减少量是“分对情况”下的“扩大”几十倍,百倍,千倍!

四、代码

这些,github上也有,那么怎么用呢?,顺边给大家lovasz_softmax loss的代码

使用小例子,图像分割为例子

#正常ce loss
loss  = F.bceloss(F.softmax(pred), target)

#Focal loss
loss  = FocalLoss2d(F.softmax(pred), target)

#lovasz_softmax
loss  = lovasz_softmax(F.softmax(pred1,dim=1), target)
class FocalLoss2d(torch.nn.Module):
    def __init__(self,  weight=None, size_average=True, ignore_index=255):
        super(FocalLoss2d, self).__init__()
        self.gamma = 2
        self.nll_loss = torch.nn.NLLLoss2d(weight, size_average)
    def forward(self, inputs, targets):
        return self.nll_loss((1 - F.softmax(inputs,1)) ** self.gamma * F.log_softmax(inputs,1), targets)

def softmax_focalloss(input, target, weight, size_average, ignore_index, reduce=True):
    return F.nll_loss((1 - F.softmax(inputs,1)) ** gamma * F.log_softmax(inputs,1), targets)
"""
Lovasz-Softmax and Jaccard hinge loss in PyTorch
Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)
"""

from __future__ import print_function, division

import torch
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
try:
    from itertools import  ifilterfalse
except ImportError: # py3k
    from itertools import  filterfalse


def lovasz_grad(gt_sorted):
    """
    Computes gradient of the Lovasz extension w.r.t sorted errors
    See Alg. 1 in paper
    """
    p = len(gt_sorted)
    gts = gt_sorted.sum()
    intersection = gts - gt_sorted.float().cumsum(0)
    union = gts + (1 - gt_sorted).float().cumsum(0)
    jaccard = 1. - intersection / union
    if p > 1: # cover 1-pixel case
        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
    return jaccard


def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
    """
    IoU for foreground class
    binary: 1 foreground, 0 background
    """
    if not per_image:
        preds, labels = (preds,), (labels,)
    ious = []
    for pred, label in zip(preds, labels):
        intersection = ((label == 1) & (pred == 1)).sum()
        union = ((label == 1) | ((pred == 1) & (label != ignore))).sum()
        if not union:
            iou = EMPTY
        else:
            iou = float(intersection) / union
        ious.append(iou)
    iou = mean(ious)    # mean accross images if per_image
    return 100 * iou


def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
    """
    Array of IoU for each (non ignored) class
    """
    if not per_image:
        preds, labels = (preds,), (labels,)
    ious = []
    for pred, label in zip(preds, labels):
        iou = []
        for i in range(C):
            if i != ignore: # The ignored label is sometimes among predicted classes (ENet - CityScapes)
                intersection = ((label == i) & (pred == i)).sum()
                union = ((label == i) | ((pred == i) & (label != ignore))).sum()
                if not union:
                    iou.append(EMPTY)
                else:
                    iou.append(float(intersection) / union)
        ious.append(iou)
    ious = map(mean, zip(*ious)) # mean accross images if per_image
    return 100 * np.array(ious)


# --------------------------- BINARY LOSSES ---------------------------


def lovasz_hinge(logits, labels, per_image=True, ignore=None):
    """
    Binary Lovasz hinge loss
      logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
      labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
      per_image: compute the loss per image instead of per batch
      ignore: void class id
    """
    if per_image:
        loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))
                          for log, lab in zip(logits, labels))
    else:
        loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
    return loss


def lovasz_hinge_flat(logits, labels):
    """
    Binary Lovasz hinge loss
      logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
      labels: [P] Tensor, binary ground truth labels (0 or 1)
      ignore: label to ignore
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.
    signs = 2. * labels.float() - 1.
    errors = (1. - logits * Variable(signs))
    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
    perm = perm.data
    gt_sorted = labels[perm]
    grad = lovasz_grad(gt_sorted)
    loss = torch.dot(F.relu(errors_sorted), Variable(grad))
    return loss


def flatten_binary_scores(scores, labels, ignore=None):
    """
    Flattens predictions in the batch (binary case)
    Remove labels equal to 'ignore'
    """
    scores = scores.view(-1)
    labels = labels.view(-1)
    if ignore is None:
        return scores, labels
    valid = (labels != ignore)
    vscores = scores[valid]
    vlabels = labels[valid]
    return vscores, vlabels


class StableBCELoss(torch.nn.modules.Module):
    def __init__(self):
         super(StableBCELoss, self).__init__()
    def forward(self, input, target):
         neg_abs = - input.abs()
         loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
         return loss.mean()


def binary_xloss(logits, labels, ignore=None):
    """
    Binary Cross entropy loss
      logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
      labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
      ignore: void class id
    """
    logits, labels = flatten_binary_scores(logits, labels, ignore)
    loss = StableBCELoss()(logits, Variable(labels.float()))
    return loss


# --------------------------- MULTICLASS LOSSES ---------------------------


def lovasz_softmax(probas, labels, only_present=False, per_image=False, ignore=None):
    """
    Multi-class Lovasz-Softmax loss
      probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1)
      labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
      only_present: average only on classes present in ground truth
      per_image: compute the loss per image instead of per batch
      ignore: void class labels
    """
    if per_image:
        loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), only_present=only_present)
                          for prob, lab in zip(probas, labels))
    else:
        loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), only_present=only_present)
    return loss


def lovasz_softmax_flat(probas, labels, only_present=False):
    """
    Multi-class Lovasz-Softmax loss
      probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
      labels: [P] Tensor, ground truth labels (between 0 and C - 1)
      only_present: average only on classes present in ground truth
    """
    C = probas.size(1)
    losses = []
    for c in range(C):
        fg = (labels == c).float() # foreground for class c
        if only_present and fg.sum() == 0:
            continue
        errors = (Variable(fg) - probas[:, c]).abs()
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        perm = perm.data
        fg_sorted = fg[perm]
        losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
    return mean(losses)


def flatten_probas(probas, labels, ignore=None):
    """
    Flattens predictions in the batch
    """
    B, C, H, W = probas.size()
    probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B * H * W, C = P, C
    labels = labels.view(-1)
    if ignore is None:
        return probas, labels
    valid = (labels != ignore)
    vprobas = probas[valid.nonzero().squeeze()]
    vlabels = labels[valid]
    return vprobas, vlabels

def xloss(logits, labels, ignore=None):
    """
    Cross entropy loss
    """
    return F.cross_entropy(logits, Variable(labels), ignore_index=255)


# --------------------------- HELPER FUNCTIONS ---------------------------

def mean(l, ignore_nan=False, empty=0):
    """
    nanmean compatible with generators.
    """
    l = iter(l)
    if ignore_nan:
        l = ifilterfalse(np.isnan, l)
    try:
        n = 1
        acc = next(l)
    except StopIteration:
        if empty == 'raise':
            raise ValueError('Empty mean')
        return empty
    for n, v in enumerate(l, 2):
        acc += v
    if n == 1:
        return acc
    return acc / n

 

您可能感兴趣的与本文相关的镜像

GPT-oss:20b

GPT-oss:20b

图文对话
Gpt-oss

GPT OSS 是OpenAI 推出的重量级开放模型,面向强推理、智能体任务以及多样化开发场景

### Focal Loss 论文 PDF 原版下载 Focal Loss 的原始论文名为《Focal Loss for Dense Object Detection》,由 Lin et al. 提出。该论文的主要贡献在于引入了一种新的损失函数——Focal Loss,用于解决密集目标检测中的类别不平衡问题[^3]。 以下是获取原版论文的方法: 1. **官方链接**: 可通过 ArXiv 平台访问论文的公开版本。具体地址为: [https://arxiv.org/abs/1708.02002](https://arxiv.org/abs/1708.02002) 2. **GitHub 资源**: 如果需要更详细的解读或相关实现代码,可以通过 GitHub 上的相关项目找到补充资料。例如,您提到的一个资源已经提供了关于 Focal Loss 的公式推导和解释[^1]。 3. **学术搜索引擎**: 使用 Google Scholar 或其他学术搜索引擎输入关键词 “Focal Loss for Dense Object Detection”,即可找到多种格式的下载选项。 #### 关于 Focal Loss 的核心概念 Focal Loss 是在标准交叉熵基础上改进的一种损失函数,其主要目的是降低容易分类样本的影响,使得模型能够更多关注难以分类的样本。公式定义如下: \[ FL(p_t) = -(1-p_t)^\gamma \log(p_t) \] 其中 \( p_t \) 表示预测的概率值,\( \gamma \) 是可调参数,通常取值大于零。 ```python import torch import torch.nn as nn class FocalLoss(nn.Module): def __init__(self, gamma=2, alpha=None, size_average=True): super(FocalLoss, self).__init__() self.gamma = gamma self.alpha = alpha if isinstance(alpha,(float,int)): self.alpha = torch.Tensor([alpha,1-alpha]) if isinstance(alpha,list): self.alpha = torch.Tensor(alpha) self.size_average = size_average def forward(self, input, target): if input.dim()>2: input = input.view(input.size(0),input.size(1),-1) # N,C,H,W => N,C,H*W input = input.transpose(1,2) # N,C,H*W => N,H*W,C input = input.contiguous().view(-1,input.size(2)) # N,H*W,C => N*H*W,C target = target.view(-1,1) logpt = nn.functional.log_softmax(input, dim=-1) logpt = logpt.gather(1,target) logpt = logpt.view(-1) pt = torch.exp(logpt) if self.alpha is not None: if self.alpha.type()!=input.data.type(): self.alpha = self.alpha.type_as(input.data) at = self.alpha.gather(0,target.data.view(-1)) logpt = logpt * at loss = -1 * (1-pt)**self.gamma * logpt if self.size_average: return loss.mean() else: return loss.sum() ``` 上述代码实现了 Focal LossPyTorch 版本,便于实际应用时参考。 ---
评论 6
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值