问题背景
验证集约10000张高分辨率图片。考虑调试场景,训练/验证阶段batch size均为1。训练阶段显存占用正常,约2GB; 验证阶段显存占用异常,从2GB开始逐渐上升直到Out-Of-Memory.
原因分析
在新版(1.x)的MMSegmentation中,BaseSegmentor(mmseg.models.segmentors.base.BaseSegmentor)里面有一个postprocess_result方法,这个方法的作用是对网络分割的结果进行后处理(resize成与gt相同的尺寸并argmax)。
问题在于,正常来讲计算评价指标(如mIoU)的话,只需要返回argmax后的结果即可,而mmseg同时还返回了argmax之前的原始logits,如下所示:
def postprocess_result(self,
seg_logits: Tensor,
data_samples: OptSampleList = None) -> SampleList:
# ...
# ...
# ...
data_samples[i].set_data({
'seg_logits':
PixelData(**{'data': i_seg_logits}),
'pred_sem_seg':
PixelData(**{'data': i_seg_pred})
})
这就会导致一个问题,如果任务的分割类别比较多(如本文场景中的150类)且分辨率较高(1080P以上),那么这个seg_logits将会占用很大的显存,且不会被自动释放(seg_logits包装在SegDataSample对象中,该对象中还有一些其他重要的元数据)。
解决方案
由于这个seg_logits根本就没用,因此直接将其删除即可。修改部分的代码如下:
def postprocess_result(self,
seg_logits: Tensor,
data_samples: OptSampleList = None) -> SampleList:
# ...
# ...
# ...
del i_seg_logits
torch.cuda.empty_cache()
data_samples[i].set_data({
# 'seg_logits':
# PixelData(**{'data': i_seg_logits}),
'pred_sem_seg':
PixelData(**{'data': i_seg_pred})
})