〖MMDetection〗解析文件:mmdet/models/task_modules/samplers/random_sampler.py

深入解析 MMDetection 中的随机采样器(RandomSampler)

在目标检测任务中,数据采样是一个关键环节,它对模型的训练效果有着重要影响。MMDetection 提供了多种采样器,其中RandomSampler以其简单有效的随机采样策略备受关注。本文将详细解析这个模块的代码。

一、模块概述

RandomSampler是 MMDetection 中的一种采样器,继承自BaseSampler。它的主要作用是随机选取一定数量的正样本和负样本,用于模型的训练。

二、代码解析

1. 导入模块和定义类

from typing import Union

import torch
from numpy import ndarray
from torch import Tensor

from mmdet.registry import TASK_UTILS
from..assigners import AssignResult
from.base_sampler import BaseSampler

这里导入了必要的模块,包括类型提示的工具、PyTorch 的相关模块、NumPy 的数组类型、MMDetection 中的注册器、分配结果模块和基础采样器模块。

@TASK_UTILS.register_module()
class RandomSampler(BaseSampler):
    # 构造函数参数说明和初始化
    def __init__(self,
                 num: int,
                 pos_fraction: float,
                 neg_pos_ub: int = -1,
                 add_gt_as_proposals: bool = True,
                 **kwargs):
        from.sampling_result import ensure_rng
        super().__init__(
            num=num,
            pos_fraction=pos_fraction,
            neg_pos_ub=neg_pos_ub,
            add_gt_as_proposals=add_gt_as_proposals)
        self.rng = ensure_rng(kwargs.get('rng', None))

在构造函数中,接收了一些关键参数,如采样数量num、正样本比例pos_fraction、负样本和正样本数量上界neg_pos_ub以及是否添加真实框作为建议框add_gt_as_proposals。通过调用父类的构造函数进行初始化,并使用ensure_rng函数确保随机数生成器的存在。

2. random_choice函数

    def random_choice(self, gallery: Union[Tensor, ndarray, list],
                      num: int) -> Union[Tensor, ndarray]:
        # 判断 gallery 长度是否满足采样数量要求
        assert len(gallery) >= num
        # 判断 gallery 类型并进行处理
        is_tensor = isinstance(gallery, torch.Tensor)
        if not is_tensor:
            if torch.cuda.is_available():
                device = torch.cuda.current_device()
            else:
                device = 'cpu'
            gallery = torch.tensor(gallery, dtype=torch.long, device=device)
        # 使用临时修复的方法获取随机索引
        perm = torch.randperm(gallery.numel())[:num].to(device=gallery.device)
        rand_inds = gallery[perm]
        # 如果不是 Tensor 类型,转换为 NumPy 数组
        if not is_tensor:
            rand_inds = rand_inds.cpu().numpy()
        return rand_inds

一、函数概述

这个函数名为random_choice,其目的是从给定的gallery中随机选择一些元素。gallery可以是 PyTorch 的张量(Tensor)、NumPy 的数组(ndarray)或 Python 的列表(list),函数会根据输入的类型返回相应类型的结果,即如果输入是张量则返回张量,否则返回 NumPy 数组。

二、参数说明

  • gallery (Tensor | ndarray | list):这是一个索引池,可以从中随机选择元素。它可以是三种不同的数据类型之一。
  • num (int):期望采样的数量,表示要从gallery中随机选择出多少个元素。

三、函数执行过程

  1. 长度检查

    assert len(gallery) >= num
    

    首先确保gallery的长度至少要大于等于期望采样的数量num,如果不满足这个条件,程序会抛出一个断言错误。

  2. 判断输入类型

    is_tensor = isinstance(gallery, torch.Tensor)
    

    判断输入的gallery是否是 PyTorch 的张量类型,如果是,则is_tensorTrue,否则为False

  3. 非张量类型的处理
    如果gallery不是张量类型,会进行以下处理:

    if not is_tensor:
        if torch.cuda.is_available():
            device = torch.cuda.current_device()
        else:
            device = 'cpu'
        gallery = torch.tensor(gallery, dtype=torch.long, device=device)
    
    • 首先检查是否有可用的 GPU,如果有,则获取当前的 GPU 设备编号;如果没有,则将设备设置为cpu
    • 然后将gallery转换为 PyTorch 的张量类型,数据类型设置为长整型(torch.long),并放置在相应的设备上。
  4. 随机采样

    perm = torch.randperm(gallery.numel())[:num].to(device=gallery.device)
    rand_inds = gallery[perm]
    
    • torch.randperm(gallery.numel())会生成一个从 0 到gallery总元素数量减 1 的随机排列。然后取这个随机排列的前num个元素,得到一个长度为num的随机索引序列perm
    • 接着,使用这个随机索引序列从gallery中选取相应的元素,得到随机采样的结果rand_inds
  5. 结果类型转换

    if not is_tensor:
        rand_inds = rand_inds.cpu().numpy()
    

    如果输入的gallery不是张量类型,那么将采样结果从 PyTorch 的张量转换为 NumPy 数组。

  6. 返回结果

    return rand_inds
    

    最后,函数返回随机采样的结果,其类型与输入的gallery类型相对应。

3. _sample_pos函数

    def _sample_pos(self, assign_result: AssignResult, num_expected: int,
                    **kwargs) -> Union[Tensor, ndarray]:
        # 获取正样本的索引
        pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False)
        if pos_inds.numel()!= 0:
            pos_inds = pos_inds.squeeze(1)
        # 根据正样本数量与期望数量的关系进行采样
        if pos_inds.numel() <= num_expected:
            return pos_inds
        else:
            return self.random_choice(pos_inds, num_expected)

一、函数概述

这个函数名为_sample_pos,其作用是随机采样一些正样本。它接收分配结果和期望的正样本数量作为参数,并返回采样后的正样本索引,索引的类型可以是 PyTorch 的张量(Tensor)或 NumPy 的数组(ndarray)。

二、参数说明

  • assign_result (:obj:AssignResult):这是一个分配结果对象,通常包含了关于边界框分配的信息,例如哪些样本被分配为正样本(与真实边界框关联)、哪些被分配为负样本等。
  • num_expected (int):期望采样的正样本数量。

三、函数执行过程

  1. 获取正样本索引

    pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False)
    

    这里使用torch.nonzero函数从assign_result.gt_inds中找出值大于 0 的元素的索引。assign_result.gt_inds可能是一个张量,其中大于 0 的元素表示对应的样本被分配为正样本。

  2. 处理索引张量

    if pos_inds.numel()!= 0:
        pos_inds = pos_inds.squeeze(1)
    

    如果获取到的正样本索引张量pos_inds不为空(即存在正样本),则使用squeeze(1)函数对其进行处理。这一步可能是为了去除多余的维度,使索引张量的形状更加简洁。

  3. 根据正样本数量进行采样

    if pos_inds.numel() <= num_expected:
        return pos_inds
    else:
        return self.random_choice(pos_inds, num_expected)
    
    • 如果正样本索引的数量小于或等于期望的正样本数量,直接返回这些正样本索引。
    • 如果正样本索引的数量大于期望的正样本数量,则调用self.random_choice函数从这些正样本索引中随机选择num_expected个索引并返回。

4. _sample_neg函数

    def _sample_neg(self, assign_result: AssignResult, num_expected: int,
                    **kwargs) -> Union[Tensor, ndarray]:
        # 获取负样本的索引
        neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
        if neg_inds.numel()!= 0:
            neg_inds = neg_inds.squeeze(1)
        # 根据负样本数量与期望数量的关系进行采样
        if len(neg_inds) <= num_expected:
            return neg_inds
        else:
            return self.random_choice(neg_inds, num_expected)

一、函数概述

这个函数名为_sample_neg,其主要功能是随机采样一些负样本。它接收分配结果和期望的正样本数量(在这个上下文中可能被用作对负样本数量的一种间接参考)作为参数,并返回采样后的负样本索引,索引的类型可以是 PyTorch 的张量(Tensor)或 NumPy 的数组(ndarray)。

二、参数说明

  • assign_result (:obj:AssignResult):分配结果对象,包含了边界框分配的信息,用于确定哪些样本是负样本。
  • num_expected (int):期望的正样本数量,虽然在函数注释中是“期望正样本数量”,但在这个函数中可能被用作对负样本数量的参考或者限制。

三、函数执行过程

  1. 获取负样本索引

    neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
    

    这里使用torch.nonzero函数从assign_result.gt_inds中找出值等于 0 的元素的索引。在这个上下文中,assign_result.gt_inds中值为 0 的元素对应的样本被认为是负样本。

  2. 处理索引张量

    if neg_inds.numel()!= 0:
        neg_inds = neg_inds.squeeze(1)
    

    如果获取到的负样本索引张量neg_inds不为空(即存在负样本),则使用squeeze(1)函数对其进行处理,可能是为了去除多余的维度,使索引张量的形状更加简洁。

  3. 根据负样本数量进行采样

    if len(neg_inds) <= num_expected:
        return neg_inds
    else:
        return self.random_choice(neg_inds, num_expected)
    
    • 如果负样本索引的数量小于或等于期望的正样本数量(这里可能是因为没有明确的负样本数量参数,而用期望正样本数量作为参考),直接返回这些负样本索引。
    • 如果负样本索引的数量大于期望的正样本数量(同样作为参考),则调用self.random_choice函数从这些负样本索引中随机选择num_expected个索引并返回。

三、结语

RandomSampler通过随机采样的方式为目标检测模型提供了一种简单而有效的数据采样策略。它可以根据给定的参数随机选取正样本和负样本,有助于平衡数据集中不同类别的样本数量,提高模型的训练效果。理解这个模块的代码有助于我们更好地掌握 MMDetection 中的数据采样机制,并在实际应用中根据需求进行调整和扩展。

(IS) nie@nie-G3-3590:~/IS-Fusion$ bash tools/train_demo.sh Traceback (most recent call last): File "tools/train.py", line 17, in <module> from mmdet3d.apis import train_model File "/home/nie/anaconda3/envs/IS/lib/python3.8/site-packages/mmdet3d/apis/__init__.py", line 2, in <module> from .inference import (convert_SyncBN, inference_detector, File "/home/nie/anaconda3/envs/IS/lib/python3.8/site-packages/mmdet3d/apis/inference.py", line 12, in <module> from mmdet3d.core import (Box3DMode, CameraInstance3DBoxes, Coord3DMode, File "/home/nie/anaconda3/envs/IS/lib/python3.8/site-packages/mmdet3d/core/__init__.py", line 2, in <module> from .anchor import * # noqa: F401, F403 File "/home/nie/anaconda3/envs/IS/lib/python3.8/site-packages/mmdet3d/core/anchor/__init__.py", line 2, in <module> from mmdet.core.anchor import build_prior_generator File "/home/nie/anaconda3/envs/IS/lib/python3.8/site-packages/mmdet/core/__init__.py", line 3, in <module> from .bbox import * # noqa: F401, F403 File "/home/nie/anaconda3/envs/IS/lib/python3.8/site-packages/mmdet/core/bbox/__init__.py", line 8, in <module> from .samplers import (BaseSampler, CombinedSampler, File "/home/nie/anaconda3/envs/IS/lib/python3.8/site-packages/mmdet/core/bbox/samplers/__init__.py", line 12, in <module> from .score_hlr_sampler import ScoreHLRSampler File "/home/nie/anaconda3/envs/IS/lib/python3.8/site-packages/mmdet/core/bbox/samplers/score_hlr_sampler.py", line 3, in <module> from mmcv.ops import nms_match File "/home/nie/anaconda3/envs/IS/lib/python3.8/site-packages/mmcv/ops/__init__.py", line 2, in <module> from .active_rotated_filter import active_rotated_filter File "/home/nie/anaconda3/envs/IS/lib/python3.8/site-packages/mmcv/ops/active_rotated_filter.py", line 10, in <module> ext_module = ext_loader.load_ext( File "/home/nie/anaconda3/envs/IS/lib/python3.8/site-packages/mmcv/utils/ext_loader.py", line 13, in load_ext ext = importli
04-03
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值