Softmax vs. Softmax-Loss: Numerical Stability

This article looks at the difference between the Softmax and Softmax-Loss layers in a neural network and its impact on numerical stability. Through a comparative analysis, it explains the advantages of combining the two into a single layer, in particular that doing so avoids the numerical instability caused by accumulated floating-point rounding error.

Reference post: "Softmax vs. Softmax-Loss: Numerical Stability"

It contains a few key passages:

  • The gradients of all of the network's parameters are computed starting from the top layer and working backwards. At layer $k$ we receive from layer $k+1$ the quantity $\partial\ell/\partial x^{k+1}$ (the derivative of the loss $\ell$ with respect to layer $k+1$'s input), which, since layer $k+1$'s input is exactly layer $k$'s output $y^{k}$, is the same thing as $\partial\ell/\partial y^{k}$. Two computations are then needed. First, the gradients of the layer's own parameters: for an ordinary fully-connected inner-product layer these are the weight matrix $W$ and the bias $b$, and they follow directly from the chain-rule formula given earlier; if the layer has no parameters (the Softmax and Softmax-Loss layers, for example, have none), this step can be skipped. Second is the "backward pass": again by the chain rule we compute $\partial\ell/\partial x^{k}$, the derivative with respect to the layer's own input (the bottom). This derivative exists only to be passed backwards; it takes no part in any gradient-descent update of the layer's parameters, since a parameterless layer has nothing to update (a sketch of this generic backward step for a fully-connected layer appears after these quoted passages).

  • Back to the Softmax-Loss layer: since it has no parameters, we only need to compute the derivative that gets passed backwards, and since it is the topmost layer, the derivative with respect to the final output (the loss) can be computed directly, without the chain rule. Recalling the notation: we write the combined Softmax-Loss layer as $\tilde{L}(z, y)$. It has two inputs. One is the true label $y$, which comes straight from the data layer at the very bottom; we never perform gradient-descent updates on the data layer, so there is no need to back-propagate towards that input. The other input $z$ comes from the computation layer below, which for Logistic Regression or an ordinary DNN would be a fully-connected linear inner-product layer; what exactly it is does not matter here, though: we just compute $\partial\ell/\partial z$ and hand it down, and the layers below take care of the rest.

  • The result is the same as before, so it seems the differentiation was done correctly. Although the final result is identical, splitting the computation into two layers clearly takes many more steps. Beyond the slightly larger amount of computation, what we care about more is numerical stability: floating-point numbers have finite precision, and every extra operation accumulates a little more error. Notice that the two-step computation needs the quantity $1/\sigma_y(z)$, where $\sigma(z)$ is the Softmax output and $y$ the correct class; if this particular prediction happens to be very poor, i.e. the probability $\sigma_y(z)$ assigned to the correct class is very small (close to zero), there is a danger of overflow. Let us try this out in practice, first defining the two different computation functions:
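The referenced post's own code is not reproduced here, so the following is a minimal NumPy sketch of the comparison being set up (the function names `grad_two_step` and `grad_fused` and the test values are illustrative assumptions, not the post's code). Both functions compute the gradient of the cross-entropy loss $-\log \sigma_y(z)$ with respect to $z$: the first goes through the separate loss layer's factor $-1/\sigma_y(z)$ and then the Softmax Jacobian, the second uses the fused Softmax-Loss form $\sigma(z) - \mathrm{onehot}(y)$.

```python
import numpy as np

def softmax(z):
    """Numerically-stabilized softmax (subtract the max before exponentiating)."""
    e = np.exp(z - np.max(z))
    return e / e.sum()

def grad_two_step(z, y):
    """Gradient of -log(softmax(z)[y]) w.r.t. z, computed as two separate layers:
    first d(loss)/d(sigma), then back through the Softmax Jacobian."""
    sigma = softmax(z)
    dloss_dsigma = np.zeros_like(sigma)
    dloss_dsigma[y] = -1.0 / sigma[y]              # blows up when sigma[y] underflows to 0
    jac = np.diag(sigma) - np.outer(sigma, sigma)  # d(sigma_i)/d(z_j)
    return jac @ dloss_dsigma

def grad_fused(z, y):
    """The same gradient, computed by the fused Softmax-Loss layer: sigma - onehot(y)."""
    grad = softmax(z)
    grad[y] -= 1.0
    return grad

z = np.array([0.0, 800.0])  # correct class (index 0) gets probability ~exp(-800), which underflows to 0.0
y = 0
print(grad_two_step(z, y))  # contains nan: the -inf from -1/sigma[y] poisons the Jacobian product
print(grad_fused(z, y))     # [-1.  1.], the correct limiting value
```

When the predicted probability of the correct class underflows to zero, the two-step version divides by zero and propagates inf/nan, while the fused version still returns the well-defined limit; with less extreme inputs both agree, but the two-step path accumulates more rounding error because of its extra operations.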

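As a side note to the first quoted passage, the generic per-layer backward step it describes (first the layer's own parameter gradients, then the gradient handed to the layer below) can be sketched for a fully-connected inner-product layer as follows; the shapes and names here are assumptions for illustration, not code from the referenced post.

```python
import numpy as np

def fc_backward(x, W, dL_dy):
    """Backward pass of a fully-connected layer y = W @ x + b
    (the bias b itself is not needed to compute any of the gradients).

    dL_dy is the gradient received from the layer above via the chain rule.
    Returns the parameter gradients and the gradient passed further back."""
    dL_dW = np.outer(dL_dy, x)  # gradient for the layer's own weight matrix W
    dL_db = dL_dy               # gradient for the bias b
    dL_dx = W.T @ dL_dy         # the "backward pass": handed to the layer below
    return dL_dW, dL_db, dL_dx
```

A parameterless layer such as Softmax or Softmax-Loss skips the first two gradients and only produces `dL_dx`.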