〖MMEngine〗解析文件:mmengine/structures/instance_data.py

《深入剖析 MMEngine 中的 InstanceData》

在计算机视觉和深度学习领域中,高效地管理和操作实例级别的数据是至关重要的。mmengine/structures/instance_data.py模块中的InstanceData类为我们提供了一个强大的数据结构来处理这类数据。

一、引入依赖和基础类型定义

首先,代码根据当前设备类型定义了一些基础数据类型,如BoolTypeTensorLongTypeTensor,以确保在不同设备(如 CPU、GPU、NPU、MLU、MUSA)上都能正确处理数据。同时,定义了IndexType,用于表示在数据索引操作中可能出现的各种类型。

二、自定义数据结构示例(TmpObject

在深入理解InstanceData类之前,代码中给出了一个自定义数据结构TmpObject的示例,这个结构在后续理解InstanceData的一些操作中起到了辅助说明的作用。

  1. 初始化

    • class TmpObject:定义了一个名为TmpObject的类。在初始化方法__init__中,它接收一个参数tmp,并断言这个参数是一个列表。这样确保了创建TmpObject实例时传入的是合法的数据类型。
  2. 实现特殊方法

    • __len__:这个方法返回内部列表tmp的长度。通过实现这个方法,TmpObject实例可以像内置的序列类型一样使用len()函数来获取长度。
    • __getitem__:这个方法允许通过索引来访问TmpObject实例中的元素。如果输入是整数,会检查其范围,若超出范围则抛出IndexError。否则,将其转换为切片操作以保持维度一致,并返回对应元素的TmpObject实例。这样可以实现对TmpObject进行切片操作,类似于对列表等序列类型的操作。
    • cat:这是一个静态方法,用于连接多个TmpObject实例。它接收一个TmpObject实例列表,检查所有实例的类型正确性后,将它们内部的列表连接起来,并返回一个新的TmpObject实例。这个方法实现了类似列表拼接的功能,但对于自定义的TmpObject类型。
    • __repr__:这个方法用于返回TmpObject实例的字符串表示,方便在调试和输出时查看其内容。

三、InstanceData 类详解

InstanceData是一个继承自BaseDataElement的数据结构,专门用于管理实例级别的数据。

  1. 属性设置(__setattr____setitem__

    • __setattr__:这个方法重写了默认的属性设置行为。当尝试设置一个新的属性时,它首先检查属性名是否为特殊的内部属性名(_metainfo_fields_data_fields)。如果是这些特殊属性名,且该属性尚未被设置,则允许设置;否则,抛出AttributeError,表示这些特殊属性是不可变的。
    • 对于其他属性名,它要求设置的值必须是一个具有__len__属性的对象,即可以获取其长度。如果当前InstanceData实例已经有数据(长度大于 0),那么新设置的值的长度必须与现有数据的长度一致。这确保了所有数据字段的长度一致,方便后续的操作和处理。
    • __setitem__:这个方法被设置为与__setattr__相同的功能,使得可以通过索引赋值的方式来设置属性。
  2. 索引操作(__getitem__

    • 这个方法允许通过多种类型的索引来获取InstanceData实例中的数据。
    • 参数item可以是字符串、整数、列表、切片、numpy.ndarraytorch.LongTensortorch.BoolTensor等类型。
    • 如果item是字符串,直接返回对应的属性值。例如,如果有一个属性名为det_labels,那么instance_data['det_labels']将返回这个属性的值。
    • 如果item是整数,会检查其范围。如果超出范围,即item >= len(self) or item < -len(self),则抛出IndexError。否则,将其转换为切片操作以保持维度一致,即将整数转换为slice(item, None, len(self)),然后继续处理。
    • 如果item是列表,先将其转换为numpy.ndarray类型。如果列表中的元素是整数,并且在转换为numpy.ndarray时是int32类型(在某些平台上numpy的默认整数类型是int32),为了与torch.Tensor的索引要求一致(通常要求是int64),会将其转换为int64类型。然后,将numpy.ndarray转换为torch.Tensor类型。
    • 如果itemtorch.Tensor类型,首先确保它是一维张量,因为这里只支持在第一维度上进行索引操作。如果item是布尔张量(BoolTypeTensor),还会检查其形状是否与当前InstanceData实例的长度一致。
    • 对于每个属性,根据其值的类型进行不同的处理:
      • 如果是torch.Tensor类型,直接使用item进行索引操作,即v[item]
      • 如果是numpy.ndarray类型,先将item转换为numpy数组的索引,即item.cpu().numpy(),然后进行索引操作。
      • 如果是strlisttuple类型或者具有__getitem__cat属性的自定义类型:
        • 如果item是布尔张量,将其转换为索引列表,即torch.nonzero(item).view(-1).cpu().numpy().tolist()
        • 否则,将item转换为索引列表,即item.cpu().numpy().tolist()
        • 根据索引列表创建切片列表,对于空索引列表,创建一个长度为 0 的切片。
        • 对于每个切片,从属性值中获取相应的子序列。如果属性值是strlisttuple类型,将这些子序列连接起来;如果属性值具有cat方法,调用cat方法将这些子序列连接起来。
    • 最后,返回一个新的InstanceData实例,包含索引后的数据。
  3. 连接操作(cat方法)

    • 这是一个静态方法,用于连接多个InstanceData实例。
    • 首先,它检查输入列表中的所有元素是否都是InstanceData类型。如果不是,会抛出异常。
    • 然后,确保所有实例具有完全相同的键。它通过获取每个实例的所有键列表,检查这些列表的长度是否一致,并且通过将所有键列表展开为一个集合,检查集合的长度是否与第一个列表的长度一致。如果不满足这些条件,说明实例之间的键不一致,可能会导致连接操作失败,因此抛出异常。
    • 接着,对于每个键,根据值的类型进行不同的连接操作:
      • 如果是torch.Tensor类型,使用torch.cat在第一维度上进行连接操作,即torch.cat(values, dim=0)
      • 如果是numpy.ndarray类型,使用np.concatenate在第一维度上进行连接操作,即np.concatenate(values, axis=0)
      • 如果是strlisttuple类型,将所有实例的对应值连接起来,即new_values = v0[:],然后对于每个后续的值,进行连接操作,即new_values += v
      • 如果具有cat属性,调用其cat方法进行连接操作,即new_values = v0.cat(values)
    • 最后,返回一个新的InstanceData实例,包含连接后的所有数据。
  4. 长度获取(__len__方法)

    • 这个方法用于获取InstanceData实例的长度。
    • 如果有数据字段(即_data_fields不为空),返回第一个非空数据字段的长度。这里假设所有数据字段的长度是一致的,所以只需要获取第一个非空字段的长度即可。
    • 如果没有数据字段,返回 0。

四、使用示例

代码最后给出了一个详细的使用示例,展示了如何创建InstanceData实例、设置各种属性、进行索引操作和连接操作。

  1. 创建InstanceData实例并设置元信息:

    • img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))创建了一个包含图像形状和填充形状的元信息字典。
    • instance_data = InstanceData(metainfo=img_meta)创建了一个InstanceData实例,并将元信息设置为img_meta
  2. 设置各种数据字段:

    • instance_data.det_labels = torch.LongTensor([2, 3])设置了检测标签属性为一个长整型张量。
    • instance_data["det_scores"] = torch.Tensor([0.8, 0.7])设置了检测分数属性为一个张量。
    • instance_data.bboxes = torch.rand((2, 4))设置了边界框属性为一个随机生成的张量。
    • instance_data.polygons = TmpObject([[1, 2, 3, 4], [5, 6, 7, 8]])设置了多边形属性为一个自定义的TmpObject实例。
  3. 索引操作:

    • sorted_results = instance_data[instance_data.det_scores.sort().indices]根据检测分数进行排序,并使用排序后的索引对InstanceData实例进行索引操作,得到一个新的实例,其中的数据按照检测分数排序。
    • print(instance_data[instance_data.det_scores > 0.75])筛选出检测分数大于 0.75 的数据,创建一个新的InstanceData实例并打印。
    • print(instance_data[instance_data.det_scores > 1])筛选出检测分数大于 1 的数据,由于没有满足条件的数据,创建一个空的InstanceData实例并打印。
  4. 连接操作:

    • print(instance_data.cat([instance_data, instance_data]))将两个InstanceData实例连接起来,创建一个新的实例并打印。

总之,mmengine/structures/instance_data.py中的InstanceData类为我们在计算机视觉任务中处理实例级别的数据提供了强大而灵活的工具。通过其丰富的功能和灵活的索引、连接操作,可以方便地管理和处理各种类型的实例数据。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值