目录
深入解析 MMDetection 中的随机采样器(RandomSampler)
在目标检测任务中,数据采样是一个关键环节,它对模型的训练效果有着重要影响。MMDetection 提供了多种采样器,其中RandomSampler
以其简单有效的随机采样策略备受关注。本文将详细解析这个模块的代码。
一、模块概述
RandomSampler
是 MMDetection 中的一种采样器,继承自BaseSampler
。它的主要作用是随机选取一定数量的正样本和负样本,用于模型的训练。
二、代码解析
1. 导入模块和定义类
from typing import Union
import torch
from numpy import ndarray
from torch import Tensor
from mmdet.registry import TASK_UTILS
from..assigners import AssignResult
from.base_sampler import BaseSampler
这里导入了必要的模块,包括类型提示的工具、PyTorch 的相关模块、NumPy 的数组类型、MMDetection 中的注册器、分配结果模块和基础采样器模块。
@TASK_UTILS.register_module()
class RandomSampler(BaseSampler):
# 构造函数参数说明和初始化
def __init__(self,
num: int,
pos_fraction: float,
neg_pos_ub: int = -1,
add_gt_as_proposals: bool = True,
**kwargs):
from.sampling_result import ensure_rng
super().__init__(
num=num,
pos_fraction=pos_fraction,
neg_pos_ub=neg_pos_ub,
add_gt_as_proposals=add_gt_as_proposals)
self.rng = ensure_rng(kwargs.get('rng', None))
在构造函数中,接收了一些关键参数,如采样数量num
、正样本比例pos_fraction
、负样本和正样本数量上界neg_pos_ub
以及是否添加真实框作为建议框add_gt_as_proposals
。通过调用父类的构造函数进行初始化,并使用ensure_rng
函数确保随机数生成器的存在。
2. random_choice
函数
def random_choice(self, gallery: Union[Tensor, ndarray, list],
num: int) -> Union[Tensor, ndarray]:
# 判断 gallery 长度是否满足采样数量要求
assert len(gallery) >= num
# 判断 gallery 类型并进行处理
is_tensor = isinstance(gallery, torch.Tensor)
if not is_tensor:
if torch.cuda.is_available():
device = torch.cuda.current_device()
else:
device = 'cpu'
gallery = torch.tensor(gallery, dtype=torch.long, device=device)
# 使用临时修复的方法获取随机索引
perm = torch.randperm(gallery.numel())[:num].to(device=gallery.device)
rand_inds = gallery[perm]
# 如果不是 Tensor 类型,转换为 NumPy 数组
if not is_tensor:
rand_inds = rand_inds.cpu().numpy()
return rand_inds
一、函数概述
这个函数名为random_choice
,其目的是从给定的gallery
中随机选择一些元素。gallery
可以是 PyTorch 的张量(Tensor
)、NumPy 的数组(ndarray
)或 Python 的列表(list
),函数会根据输入的类型返回相应类型的结果,即如果输入是张量则返回张量,否则返回 NumPy 数组。
二、参数说明
gallery (Tensor | ndarray | list)
:这是一个索引池,可以从中随机选择元素。它可以是三种不同的数据类型之一。num (int)
:期望采样的数量,表示要从gallery
中随机选择出多少个元素。
三、函数执行过程
-
长度检查:
assert len(gallery) >= num
首先确保
gallery
的长度至少要大于等于期望采样的数量num
,如果不满足这个条件,程序会抛出一个断言错误。 -
判断输入类型:
is_tensor = isinstance(gallery, torch.Tensor)
判断输入的
gallery
是否是 PyTorch 的张量类型,如果是,则is_tensor
为True
,否则为False
。 -
非张量类型的处理:
如果gallery
不是张量类型,会进行以下处理:if not is_tensor: if torch.cuda.is_available(): device = torch.cuda.current_device() else: device = 'cpu' gallery = torch.tensor(gallery, dtype=torch.long, device=device)
- 首先检查是否有可用的 GPU,如果有,则获取当前的 GPU 设备编号;如果没有,则将设备设置为
cpu
。 - 然后将
gallery
转换为 PyTorch 的张量类型,数据类型设置为长整型(torch.long
),并放置在相应的设备上。
- 首先检查是否有可用的 GPU,如果有,则获取当前的 GPU 设备编号;如果没有,则将设备设置为
-
随机采样:
perm = torch.randperm(gallery.numel())[:num].to(device=gallery.device) rand_inds = gallery[perm]
torch.randperm(gallery.numel())
会生成一个从 0 到gallery
总元素数量减 1 的随机排列。然后取这个随机排列的前num
个元素,得到一个长度为num
的随机索引序列perm
。- 接着,使用这个随机索引序列从
gallery
中选取相应的元素,得到随机采样的结果rand_inds
。
-
结果类型转换:
if not is_tensor: rand_inds = rand_inds.cpu().numpy()
如果输入的
gallery
不是张量类型,那么将采样结果从 PyTorch 的张量转换为 NumPy 数组。 -
返回结果:
return rand_inds
最后,函数返回随机采样的结果,其类型与输入的
gallery
类型相对应。
3. _sample_pos
函数
def _sample_pos(self, assign_result: AssignResult, num_expected: int,
**kwargs) -> Union[Tensor, ndarray]:
# 获取正样本的索引
pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False)
if pos_inds.numel()!= 0:
pos_inds = pos_inds.squeeze(1)
# 根据正样本数量与期望数量的关系进行采样
if pos_inds.numel() <= num_expected:
return pos_inds
else:
return self.random_choice(pos_inds, num_expected)
一、函数概述
这个函数名为_sample_pos
,其作用是随机采样一些正样本。它接收分配结果和期望的正样本数量作为参数,并返回采样后的正样本索引,索引的类型可以是 PyTorch 的张量(Tensor
)或 NumPy 的数组(ndarray
)。
二、参数说明
assign_result (:obj:
AssignResult)
:这是一个分配结果对象,通常包含了关于边界框分配的信息,例如哪些样本被分配为正样本(与真实边界框关联)、哪些被分配为负样本等。num_expected (int)
:期望采样的正样本数量。
三、函数执行过程
-
获取正样本索引:
pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False)
这里使用
torch.nonzero
函数从assign_result.gt_inds
中找出值大于 0 的元素的索引。assign_result.gt_inds
可能是一个张量,其中大于 0 的元素表示对应的样本被分配为正样本。 -
处理索引张量:
if pos_inds.numel()!= 0: pos_inds = pos_inds.squeeze(1)
如果获取到的正样本索引张量
pos_inds
不为空(即存在正样本),则使用squeeze(1)
函数对其进行处理。这一步可能是为了去除多余的维度,使索引张量的形状更加简洁。 -
根据正样本数量进行采样:
if pos_inds.numel() <= num_expected: return pos_inds else: return self.random_choice(pos_inds, num_expected)
- 如果正样本索引的数量小于或等于期望的正样本数量,直接返回这些正样本索引。
- 如果正样本索引的数量大于期望的正样本数量,则调用
self.random_choice
函数从这些正样本索引中随机选择num_expected
个索引并返回。
4. _sample_neg
函数
def _sample_neg(self, assign_result: AssignResult, num_expected: int,
**kwargs) -> Union[Tensor, ndarray]:
# 获取负样本的索引
neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
if neg_inds.numel()!= 0:
neg_inds = neg_inds.squeeze(1)
# 根据负样本数量与期望数量的关系进行采样
if len(neg_inds) <= num_expected:
return neg_inds
else:
return self.random_choice(neg_inds, num_expected)
一、函数概述
这个函数名为_sample_neg
,其主要功能是随机采样一些负样本。它接收分配结果和期望的正样本数量(在这个上下文中可能被用作对负样本数量的一种间接参考)作为参数,并返回采样后的负样本索引,索引的类型可以是 PyTorch 的张量(Tensor
)或 NumPy 的数组(ndarray
)。
二、参数说明
assign_result (:obj:
AssignResult)
:分配结果对象,包含了边界框分配的信息,用于确定哪些样本是负样本。num_expected (int)
:期望的正样本数量,虽然在函数注释中是“期望正样本数量”,但在这个函数中可能被用作对负样本数量的参考或者限制。
三、函数执行过程
-
获取负样本索引:
neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
这里使用
torch.nonzero
函数从assign_result.gt_inds
中找出值等于 0 的元素的索引。在这个上下文中,assign_result.gt_inds
中值为 0 的元素对应的样本被认为是负样本。 -
处理索引张量:
if neg_inds.numel()!= 0: neg_inds = neg_inds.squeeze(1)
如果获取到的负样本索引张量
neg_inds
不为空(即存在负样本),则使用squeeze(1)
函数对其进行处理,可能是为了去除多余的维度,使索引张量的形状更加简洁。 -
根据负样本数量进行采样:
if len(neg_inds) <= num_expected: return neg_inds else: return self.random_choice(neg_inds, num_expected)
- 如果负样本索引的数量小于或等于期望的正样本数量(这里可能是因为没有明确的负样本数量参数,而用期望正样本数量作为参考),直接返回这些负样本索引。
- 如果负样本索引的数量大于期望的正样本数量(同样作为参考),则调用
self.random_choice
函数从这些负样本索引中随机选择num_expected
个索引并返回。
三、结语
RandomSampler
通过随机采样的方式为目标检测模型提供了一种简单而有效的数据采样策略。它可以根据给定的参数随机选取正样本和负样本,有助于平衡数据集中不同类别的样本数量,提高模型的训练效果。理解这个模块的代码有助于我们更好地掌握 MMDetection 中的数据采样机制,并在实际应用中根据需求进行调整和扩展。