【代码实践】dice loss及其变形原理详细讲解和图像分割实践(含源码)
Dice Loss 是一种用于图像分割任务的损失函数,旨在解决类别不平衡的问题。它的名字来源于Dice系数,一种用于度量两个集合相似性的统计指标。Dice Loss的优势在于它对类别不平衡问题具有一定的鲁棒性。它能够有效地应对分割任务中正负样本数量差异较大的情况,使模型更关注小类别的分割效果。因此,Dice Loss在医学图像分割等领域广泛应用,其中正样本(感兴趣区域)通常较小。
需要注意的是,Dice Loss的计算方式使其对预测和标签的相对位置不敏感,因此在某些情况下,模型可能更容易收敛。然而,Dice Loss也有一些缺点,例如在存在类别间重叠时,它可能不够精确。
文章目录
1. Dice Loss定义
在图像分割中,Dice Loss的计算涉及到预测的二值化掩码和实际标签的二值化掩码。设定预测的二值掩码为A,实际标签的二值掩码为B,则Dice系数的计算公式为:
D
i
c
e
(
A
,
B
)
=
∣
A
∣
+
∣
B
∣
2
⋅
∣
A
∩
B
∣
Dice(A,B)= \frac{∣A∣+∣B∣}{2⋅∣A∩B∣}
Dice(A,B)=2⋅∣A∩B∣∣A∣+∣B∣
Dice Loss则是将Dice系数转化为损失的形式,其计算公式为:
D
i
c
e
L
o
s
s
(
A
,
B
)
=
1
−
∣
A
∣
+
∣
B
∣
2
⋅
∣
A
∩
B
∣
(
1
)
DiceLoss(A,B)=1− \frac{∣A∣+∣B∣}{2⋅∣A∩B∣}\quad(1)
DiceLoss(A,B)=1−2⋅∣A∩B∣∣A∣+∣B∣(1)
在这个公式中,分子部分表示预测掩码和实际标签的交集的两倍,分母部分则是它们的并集。Dice Loss的值越小,表示预测结果越接近实际标签。
2. Dice loss 结合 cross entropy loss
在目前的研究中,常常单独使用 focal loss,交叉熵或者 dice loss 作为损失函数进行
网络训练,但是因为硬件限制和数据量限制,使得在个人训练时使用 focal loss 和 dice loss作为损失函数,模型难以收敛到较优状态。而使用交叉熵损失函数则有着无法起到平衡正负样本的问题,因此本文采用了 dice loss 和交叉熵损失函数结合的方法来解决难以收敛和无法平衡正负样本的问题。
2.1 Log-Dice_Ce Loss公式
使用
l
o
g
log
log函数将交叉熵损失函数和dice loss损失函数同步到同一量级 ,同时根据作者验证, 使用
l
o
g
log
log函数有利于训练收敛和正负样本平衡。
公式如下:
D
i
c
e
C
e
L
o
s
s
=
L
o
g
(
D
i
c
e
L
o
s
s
)
+
C
e
L
o
s
s
DiceCeLoss=Log(DiceLoss)+CeLoss
DiceCeLoss=Log(DiceLoss)+CeLoss
其中
D
i
c
e
L
o
s
s
DiceLoss
DiceLoss见公式(1),CeLoss具体可以关注博主另外一篇文章【代码实践】focal loss损失函数及其变形原理详细讲解和图像分割实践(含源码)
3.代码编写
3.1 Log-Dice_Ce Loss 代码实现
class DiceCeLoss(nn.Module):
def __init__(self):
super(DiceCeLoss, self).__init__()
def forward(self, preds, labels):
assert preds.shape == labels.shape, "predict & target shape do not match"
# 使用pytorch的交叉熵损失函数接口
ce_loss = nn.CrossEntropyLoss()
ce_total_loss = ce_loss(preds, labels)
preds = F.softmax(preds, dim=1)
# nums = labels.shape[1]
# 计算dice系数时去掉背景类,因为背景类不是我们预测的目标,如果是,需要加上
preds = preds[:, 1:] # 去掉背景类
labels = labels[:, 1:]
# 将预测和标签转换为一维向量
preds = preds.reshape(-1)
labels = labels.reshape(-1)
# 计算交集
intersection = torch.sum(preds * labels)
# 计算 Dice 系数
dice = (2. * intersection + 1e-8) / \
(torch.sum(preds) + torch.sum(labels) + 1e-8)
if dice >= 1:
dice = 1
# 这里使用的Log函数来加权
dice_ce_loss = -1 * torch.log(dice) + ce_total_loss
return dice_ce_loss
3.2 pytorch具体使用示例
def train(args):
model = TransRes1Unet(1, 10).to(args.device) # 初始化自己模型
batch_size = args.batch_size # 初始化batch-size
criterion = my_loss.DiceCeLoss() # 初始化这里定义的DiceCeLoss
optimizer = optim.Adam(model.parameters(), lr=0.001) # 初始化优化器
# 初始化自己的数据集
ms_dataset = MSDataset(
args.train_data_folder, transform=x_transforms, target_transform=y_transforms)
# 构建dataloader
dataloaders = DataLoader(
ms_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
# 开始训练
train_model(args, model, criterion, optimizer, dataloaders)
可以看到,在训练代码一开始初始化时使用了focal loss作为损失函数参与模型训练。
for x, y in dataload:
with torch.autocast(device_type='cuda', dtype=torch.float32):
step += 1
inputs = x.to(args.device)
labels = y.to(args.device)
# forward
outputs = model(inputs)
# 这里便使用了初始化好的loss函数来计算loss
loss = criterion(outputs, labels)
scaler.scale(loss).backward() # 反向传播
scaler.step(optimizer=optimizer)
scaler.update()
# zero the parameter gradients
optimizer.zero_grad()
4. 应用Dice loss的分割效果
本文评估了 focal loss 和 LogDiceCeLoss(log 后的 dice loss 结合交叉熵损失函数)
在同一个神经网络结构上的表现。两种组合在本研究采用的数据集上迭代了 100 轮,batch-size 均为 32,平局 Dice 系数计算采用微平均方法。
4.1 Dice系数柱状图
在测试集上的结果表明,focal loss 和 Res1Unet 的组合达到了 0.9419 的 dice 系数,
LogDiceCeLoss 达 到 了 0.9398 的 dice 系 数 , 整 体 上 来 说 , focal loss 的 表 现 强 于LogDiceCeLoss 两个千分点,具体表现在使用 focal loss 的情况下每预测 1000 个像素点要比使用 LogDiceCeLoss 多预测正确 2 个像素点,这是一个很小的差距,因此可以认为 focalloss 和 LogDiceCeLoss 在 Res1Unet 上的总体表现基本相同。由 dice 系数直方图可以看到,focal loss 在优良表现上更好一点,表现在低 dice 预测更少,高 dice 预测更多,但是LogDiceCeLoss 在中端部分表现更加良好,不会出现中低表现聚集的情况。
4.2 不同类别预测结果
由图可以看出,不同类别预测准确度柱状图可见,在上颌窦空腔和上颌窦黏膜的表现上,两者表现基本相同,上颌窦空腔因其特征明显,数据集多,因此分割效果非常好,上颌窦黏膜数据集较少,在整体样本中占比较小,特征也不太明显,因此分割效果中等。在样本占比极其不平衡的二号牙上,LogDiceCeLoss 的表现远强于 focal loss,可以看出 dice loss 平衡正负样本的能力略强于 focal loss。
关注博主,收获更多干货
需要全部源代码或者更多支持的请关注点赞收藏博主并在评论区评论噢!
往期精彩干货
基于mmdetection3d的单目3D目标检测模型,效果远超CenterNet3D
SSH?Termius?一篇文章教你使用远程服务器训练
Jetson nano开机自启动python程序
【代码实践】focal loss损失函数及其变形原理详细讲解和图像分割实践(含源码)