解决mmsegmentation日志无法输出验证损失

一、新版更新mmengine、mmcv后无法正常输出验证损失

1、继续使用最新版参考[Feature] Support calculating loss during validation #1503

  • 因为mmengine目前没有可以选择输入验证损失的入口。根据mmengine/runner/loops.py中该PR提交的代码需要将验证损失计算放在(model.forwardmode='predict'predict)中,并将通过将val_loss添加到BaseDataElement中,输出在其他参数之后。
  • mmsegmentation将模式的选择放在mmseg/models/segmentors/base.py的forward函数中,可以选择在这里的predict中添加代码
        if mode == 'loss':
            return self.loss(inputs, data_samples)
        elif mode == 'predict':
            return self.predict(inputs, data_samples)
        elif mode == 'tensor':
            return self._forward(inputs, data_samples)
        else:
            raise RuntimeError(f'Invalid mode "{mode}". '
                               'Only supports loss, predict and tensor mode')
  • 我之前因为修改过程中有点问题直接在mmseg/models/segmentors/encoder_decoder.py中进行了修改(建议在base.py中修改)
  •     def predict(self,
                    inputs: Tensor,
                    data_samples: OptSampleList = None) -> SampleList:
            """Predict results from a batch of inputs and data samples with post-
            processing.
    
            Args:
                inputs (Tensor): Inputs with shape (N, C, H, W).
                data_samples (List[:obj:`SegDataSample`], optional): The seg data
                    samples. It usually includes information such as `metainfo`
                    and `gt_sem_seg`.
    
            Returns:
                list[:obj:`SegDataSample`]: Segmentation results of the
                input images. Each SegDataSample usually contain:
    
                - ``pred_sem_seg``(PixelData): Prediction of semantic segmentation.
                - ``seg_logits``(PixelData): Predicted logits of semantic
                    segmentation before normalization.
            """
            if data_samples is not None:
                batch_img_metas = [
                    data_sample.metainfo for data_sample in data_samples
                ]
            else:
                batch_img_metas = [
                    dict(
                        ori_shape=inputs.shape[2:],
                        img_shape=inputs.shape[2:],
                        pad_shape=inputs.shape[2:],
                        padding_size=[0, 0, 0, 0])
                ] * inputs.shape[0]
    
            seg_logits = self.inference(inputs, batch_img_metas)
            predictions = self.postprocess_result(seg_logits, data_samples)
    
            # 添加验证损失
            if data_samples is not None and all(hasattr(ds, 'gt_sem_seg') for ds in data_samples): # 避免测试过程被调用
                losses = self.loss(inputs, data_samples)  # 调用 loss 方法计算损失
                val_loss = BaseDataElement(loss=losses)
                return predictions + [val_loss]  # 将损失附加在结果后
            return predictions
    
            # return self.postprocess_result(seg_logits, data_samples)

2、降级mmcv1.x版本参考Fix validation loss logging #1494

内置验证损失

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

fulangxu

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值