一、新版更新mmengine、mmcv后无法正常输出验证损失
1、继续使用最新版参考[Feature] Support calculating loss during validation #1503
- 因为mmengine目前没有可以选择输入验证损失的入口。根据mmengine/runner/loops.py中该PR提交的代码需要将验证损失计算放在(
model.forward
用mode='predict'
或predict
)中,并将通过将val_loss添加到BaseDataElement
中,输出在其他参数之后。 - mmsegmentation将模式的选择放在mmseg/models/segmentors/base.py的forward函数中,可以选择在这里的predict中添加代码
if mode == 'loss':
return self.loss(inputs, data_samples)
elif mode == 'predict':
return self.predict(inputs, data_samples)
elif mode == 'tensor':
return self._forward(inputs, data_samples)
else:
raise RuntimeError(f'Invalid mode "{mode}". '
'Only supports loss, predict and tensor mode')
- 我之前因为修改过程中有点问题直接在mmseg/models/segmentors/encoder_decoder.py中进行了修改(建议在base.py中修改)
-
def predict(self, inputs: Tensor, data_samples: OptSampleList = None) -> SampleList: """Predict results from a batch of inputs and data samples with post- processing. Args: inputs (Tensor): Inputs with shape (N, C, H, W). data_samples (List[:obj:`SegDataSample`], optional): The seg data samples. It usually includes information such as `metainfo` and `gt_sem_seg`. Returns: list[:obj:`SegDataSample`]: Segmentation results of the input images. Each SegDataSample usually contain: - ``pred_sem_seg``(PixelData): Prediction of semantic segmentation. - ``seg_logits``(PixelData): Predicted logits of semantic segmentation before normalization. """ if data_samples is not None: batch_img_metas = [ data_sample.metainfo for data_sample in data_samples ] else: batch_img_metas = [ dict( ori_shape=inputs.shape[2:], img_shape=inputs.shape[2:], pad_shape=inputs.shape[2:], padding_size=[0, 0, 0, 0]) ] * inputs.shape[0] seg_logits = self.inference(inputs, batch_img_metas) predictions = self.postprocess_result(seg_logits, data_samples) # 添加验证损失 if data_samples is not None and all(hasattr(ds, 'gt_sem_seg') for ds in data_samples): # 避免测试过程被调用 losses = self.loss(inputs, data_samples) # 调用 loss 方法计算损失 val_loss = BaseDataElement(loss=losses) return predictions + [val_loss] # 将损失附加在结果后 return predictions # return self.postprocess_result(seg_logits, data_samples)
2、降级mmcv1.x版本参考Fix validation loss logging #1494
内置验证损失