项目场景
使用 Stitcher 增强小目标检测效果。在每个 iter 训练结束后,根据小目标贡献的 loss 占总 loss 的比重,决定下一个 iter 的数据是否需要 stitch 。
环境:PyTorch 1.3
问题描述
实现上需要使用 Hook 统计 loss 占比,然后设置 Dataset 的标志位。 Dataset.__getitem__ 会首先查询标志位,决定是否返回 stitch 后的结果。运行时发现标志位设置失败, stitch 一直没有激活。
from mmcv.runner import BaseRunner, Hook, HOOKS
from mmdet.datasets.builder import DATASETS
@HOOKS.register_module()
class SmallFeedbackHook(Hook):
def before_train_epoch(self, runner: BaseRunner):
data_loader: DataLoader = runner.data_loader
dataset: StitchedDataset = data_loader.dataset
dataset.active = False
self.dataset = dataset
def after_train_iter(self, runner: BaseRunner):
ratio = .</

本文探讨了在PyTorch环境下使用Stitcher增强小目标检测的方法。通过自定义Hook来调整训练流程,并分析了DataLoader多进程机制下Dataset状态更新的问题及解决方案。
最低0.47元/天 解锁文章
9554





