背景
使用PyTorch Dataset组建了以字典为核心的管线式预处理逻辑:
class BaseDataset(Dataset):
...
def _preprocess(self, sample:dict):
for transform in self.pipeline:
sample = transform(sample)
return sample
@abstractmethod
def __getitem__(self, index) -> dict:
...
@abstractmethod
def __len__(self) -> int:
...
class MhaDataset(BaseDataset):
...
def __getitem__(self, index):
return self._preprocess(self.available_series[index])
dataset = MhaDataset(
...,
pipeline = [
LoadMHAFile(),
WindowNorm(),
AutoPad(size=patch_size, dim='3d'),
TypeConvert(key='image', dtype=np.float32),
TypeConvert(key='label', dtype=np.uint8),
BatchAugment(
num_samples = patch_augment,
pipeline = [RandomPatch3D(patch_size=patch_size),
ToTensor(key=['image', 'label'])]
),
]
)
现象
内存使用持续上升,不论是否启用DataLoader多进程都会最终导致OOM。gc.collect()无效。
tracamalloc提示每个预处理步骤对字典进行写入操作时都会出现大量内存占用,并累计增大。
原因
Dataset类中,self.available_series保持了对后续所有字典对象的引用。每一个字典在处理完之后,由DataLoader送出后,由于其本身依旧存在一个被self.available_series的引用,因此不会被自动删除。每一个样本在处理之后,样本字典的体积都会变大,因为字典项因各个预处理过程变多了。
原本的设计理念是一个字典在经过一轮预处理之后就应当送入train loop被一次性使用,然后被自动丢弃即可。
解决方法
对于每个样本字典,需要解除Dataset中self.available_series对其的引用。
class MhaDataset(BaseDataset):
...
def __getitem__(self, index):
return self._preprocess(self.available_series[index].copy())
如上,从Dataset中取出初始字典时,使用copy获得副本而不是引用。
这样,新获得的样本字典并不会被Dataset保持引用,在train_loop对其使用结束之后,其引用计数能够自动降为零,从而被python内核自动删除。
1756

被折叠的 条评论
为什么被折叠?



