语义分割之dice loss深度分析(梯度可视化)

Dice损失函数源自V-Net,用于处理医学图像语义分割中的正负样本不平衡问题。它基于Dice系数,强调了像素间的相关性,使得在网络训练中更关注前景区域。相较于交叉熵损失,Dice损失在正样本上的梯度更大,有利于减少假阴性。然而,它存在梯度饱和问题,且训练过程中损失可能不稳定。为改善这种情况,通常会与其他损失函数结合使用,如Dice+Celoss或Dice+Focal Loss。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

dice loss 来自文章VNet(V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation),旨在应对语义分割中正负样本强烈不平衡的场景。本文通过理论推导和实验验证的方式对dice loss进行解析,帮助大家去更好的理解和使用。

dice loss 定义

dice loss 来自 dice coefficient,是一种用于评估两个样本的相似性的度量函数,取值范围在0到1之间,取值越大表示越相似。dice coefficient定义如下:

 

 

def dice_loss(target,predictive,ep=1e-8):
    intersection = 2 * torch.sum(predictive * target) + ep
    union = torch。sum(predictive) + torch.sum(target) + ep
    loss = 1 - intersection / union
    return loss

梯度分析

从dice loss的定义可以看出,dice loss 是一种区域相关的loss。意味着某像素点的loss以及梯度值不仅和该点的label以及预测值相关,和其他点的label以及预测值也相关,这点和ce (交叉熵cross entropy) loss 不同。因此分析起来比较复杂,这里我们简化一下,首先从loss曲线和求导曲线对单点输出方式分析。然后对于多点输出的情况,利用模拟预测输出来分析其梯度。而多分类softmax是sigmoid的一种推广,本质一样,所以这里只考虑sigmoid输出的二分类问题,首先sigmoid函数定义如下:

 

 

 

 

可以看出:

  • 一般情况下,dice loss 正样本的梯度大于背景样本的; 尤其是刚开始网络预测接近0.5的时候,这点和单点输出的现象一致。说明 dice loss 更具有指向性,更加偏向于正样本,保证有较低的FN。
  • 负样本(背景区域)也会产生梯度。
  • 极端情况下,网络预测接近0或1时,对应点梯度值极小,dice loss 存在梯度饱和现象。此时预测失败(FN,FP)的情况很难扭转回来。不过该情况出现的概率较低,因为网络初始化输出接近0.5,此时具有较大的梯度值。而网络通过梯度下降的方式更新参数,只会逐渐削弱预测失败的像素点。
  • 对于ce loss,当前的点的梯度仅和当前预测值与label的距离相关,预测越接近label,梯度越小。当网络预测接近0或1时,梯度依然保持该特性。
  • 对比发现, 训练前中期,dice loss下正样本的梯度值相对于ce loss,颜色更亮,值更大。说明dice loss 对挖掘正样本更加有优势。

dice loss为何能够解决正负样本不平衡问题?

因为dice loss是一个区域相关的loss。区域相关的意思就是,当前像素的loss不光和当前像素的预测值相关,和其他点的值也相关。dice loss的求交的形式可以理解为mask掩码操作,因此不管图片有多大,固定大小的正样本的区域计算的loss是一样的,对网络起到的监督贡献不会随着图片的大小而变化。从上图可视化也发现,训练更倾向于挖掘前景区域,正负样本不平衡的情况就是前景占比较小。而ce loss 会公平处理正负样本,当出现正样本占比较小时,就会被更多的负样本淹没。

总结

dice loss 对正负样本严重不平衡的场景有着不错的性能,训练过程中更侧重对前景区域的挖掘。但训练loss容易不稳定,尤其是小目标的情况下。另外极端情况会导致梯度饱和现象。因此有一些改进操作,主要是结合ce loss等改进,比如: dice+ce loss,dice + focal loss等,本文不再论述。


 

 

 

### CamVid 数据集在语义分割任务中的方法和实现 #### 数据集概述 CamVid数据集提供了丰富的图像资源,特别适用于智能驾驶场景下的语义分割研究。该数据集包含超过700张高质量标注的图片,并分为训练集、验证集和测试集[^3]。每幅图像是从车辆视角拍摄而来,涵盖了多种复杂的交通环境。 #### 类别定义 为了便于评估模型性能,在实际操作过程中通常采用简化后的11种类别进行预测结果分析,具体如下: - 汽车 (Car) - 天空 (Sky) - 行人道 (Sidewalk) - 电线杆 (Pole) - 围墙 (Fence) - 行人 (Pedestrian) - 建筑物 (Building) - 自行车骑行者 (Bicyclist) - 树木 (Tree) 这些类别通过CSV文件记录着各自对应的RGB颜色编码值,方便后续处理时快速查找并应用于可视化展示环节中。 #### U-Net与DeepLabv3+的应用实例 针对上述提到的数据特点,可以选用U-Net或DeepLabv3+这两种先进的神经网络架构来完成具体的语义分割工作: ##### U-Net结构设计 U-Net是一种经典的编解码器框架,其特点是具有跳跃连接机制能够有效保留原始输入的空间信息。对于CamVid这样的高分辨率街景影像来说非常适合。以下是基于PyTorch平台构建的一个简单版本U-Net模型代码片段: ```python import torch.nn as nn class UNet(nn.Module): def __init__(self, n_channels=3, n_classes=11): super(UNet, self).__init__() # Encoder部分省略... # Decoder部分同样省略... def forward(self, x): # 正向传播逻辑也已省去 return output ``` ##### DeepLabv3+改进之处 而相比之下,DeepLabv3+则进一步优化了特征提取能力,特别是在多尺度上下文中表现更为出色。它引入ASPP模块(Atrous Spatial Pyramid Pooling),可以在不增加太多计算量的情况下捕捉到更广泛的感受野范围内的细节变化情况。下面给出一段创建DeepLabv3+模型对象的例子: ```python from torchvision.models.segmentation import deeplabv3_resnet50 model = deeplabv3_resnet50(pretrained=False, progress=True, num_classes=11) ``` #### 训练流程概览 无论是哪种类型的网络都需要经历以下几个阶段才能最终得到满意的分割效果: - **准备数据加载器**:读取本地存储路径下经过预处理过的样本集合; - **设定超参数配置**:比如批次大小(batch size),迭代次数(epoch number)等; - **初始化权重矩阵**:可选地使用预训练好的基础骨干网作为起点加速收敛过程; - **选择合适的损失函数**:交叉熵(Cross Entropy Loss)是较为常用的一种方式; - **执行反向传播更新梯度**:借助Adam或其他自适应学习率调整策略提高效率; - **定期保存最佳状态字典**:以便随时恢复最优参数组合继续调优或者部署上线; #### 性能评价指标考量 当完成了整个训练周期之后,则需依赖一些量化标准衡量所建立系统的准确性高低。常见的有IoU(intersection over union), Dice Coefficient以及其他衍生出来的变种形式等等。它们都能够在一定程度上反映出算法识别各类物体边界的精确程度如何。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值