一例dataset get_item对字典进行管线式预处理时导致的内存泄露

背景

使用PyTorch Dataset组建了以字典为核心的管线式预处理逻辑:

class BaseDataset(Dataset):
    ...

    def _preprocess(self, sample:dict):
        for transform in self.pipeline:
            sample = transform(sample)
        return sample
    
    @abstractmethod
    def __getitem__(self, index) -> dict:
        ...
    
    @abstractmethod
    def __len__(self) -> int:
        ...

class MhaDataset(BaseDataset):
	...
	
    def __getitem__(self, index):
        return self._preprocess(self.available_series[index])

dataset = MhaDataset(
	...,
    pipeline = [
        LoadMHAFile(),
        WindowNorm(),
        AutoPad(size=patch_size, dim='3d'),
        TypeConvert(key='image', dtype=np.float32),
        TypeConvert(key='label', dtype=np.uint8),
        BatchAugment(
            num_samples = patch_augment,
            pipeline = [RandomPatch3D(patch_size=patch_size),
                        ToTensor(key=['image', 'label'])]
        ),
    ]
)

现象

内存使用持续上升,不论是否启用DataLoader多进程都会最终导致OOM。gc.collect()无效。
tracamalloc提示每个预处理步骤对字典进行写入操作时都会出现大量内存占用,并累计增大。

原因

Dataset类中,self.available_series保持了对后续所有字典对象的引用。每一个字典在处理完之后,由DataLoader送出后,由于其本身依旧存在一个被self.available_series的引用,因此不会被自动删除。每一个样本在处理之后,样本字典的体积都会变大,因为字典项因各个预处理过程变多了。

原本的设计理念是一个字典在经过一轮预处理之后就应当送入train loop被一次性使用,然后被自动丢弃即可。

解决方法

对于每个样本字典,需要解除Dataset中self.available_series对其的引用。

class MhaDataset(BaseDataset):
	...
	
    def __getitem__(self, index):
        return self._preprocess(self.available_series[index].copy())

如上,从Dataset中取出初始字典时,使用copy获得副本而不是引用。
这样,新获得的样本字典并不会被Dataset保持引用,在train_loop对其使用结束之后,其引用计数能够自动降为零,从而被python内核自动删除。

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值