《深入剖析 MMEngine 中的 InstanceData》
在计算机视觉和深度学习领域中,高效地管理和操作实例级别的数据是至关重要的。mmengine/structures/instance_data.py
模块中的InstanceData
类为我们提供了一个强大的数据结构来处理这类数据。
一、引入依赖和基础类型定义
首先,代码根据当前设备类型定义了一些基础数据类型,如BoolTypeTensor
和LongTypeTensor
,以确保在不同设备(如 CPU、GPU、NPU、MLU、MUSA)上都能正确处理数据。同时,定义了IndexType
,用于表示在数据索引操作中可能出现的各种类型。
二、自定义数据结构示例(TmpObject
)
在深入理解InstanceData
类之前,代码中给出了一个自定义数据结构TmpObject
的示例,这个结构在后续理解InstanceData
的一些操作中起到了辅助说明的作用。
-
初始化:
class TmpObject:
定义了一个名为TmpObject
的类。在初始化方法__init__
中,它接收一个参数tmp
,并断言这个参数是一个列表。这样确保了创建TmpObject
实例时传入的是合法的数据类型。
-
实现特殊方法:
__len__
:这个方法返回内部列表tmp
的长度。通过实现这个方法,TmpObject
实例可以像内置的序列类型一样使用len()
函数来获取长度。__getitem__
:这个方法允许通过索引来访问TmpObject
实例中的元素。如果输入是整数,会检查其范围,若超出范围则抛出IndexError
。否则,将其转换为切片操作以保持维度一致,并返回对应元素的TmpObject
实例。这样可以实现对TmpObject
进行切片操作,类似于对列表等序列类型的操作。cat
:这是一个静态方法,用于连接多个TmpObject
实例。它接收一个TmpObject
实例列表,检查所有实例的类型正确性后,将它们内部的列表连接起来,并返回一个新的TmpObject
实例。这个方法实现了类似列表拼接的功能,但对于自定义的TmpObject
类型。__repr__
:这个方法用于返回TmpObject
实例的字符串表示,方便在调试和输出时查看其内容。
三、InstanceData 类详解
InstanceData
是一个继承自BaseDataElement
的数据结构,专门用于管理实例级别的数据。
-
属性设置(
__setattr__
和__setitem__
):__setattr__
:这个方法重写了默认的属性设置行为。当尝试设置一个新的属性时,它首先检查属性名是否为特殊的内部属性名(_metainfo_fields
和_data_fields
)。如果是这些特殊属性名,且该属性尚未被设置,则允许设置;否则,抛出AttributeError
,表示这些特殊属性是不可变的。- 对于其他属性名,它要求设置的值必须是一个具有
__len__
属性的对象,即可以获取其长度。如果当前InstanceData
实例已经有数据(长度大于 0),那么新设置的值的长度必须与现有数据的长度一致。这确保了所有数据字段的长度一致,方便后续的操作和处理。 __setitem__
:这个方法被设置为与__setattr__
相同的功能,使得可以通过索引赋值的方式来设置属性。
-
索引操作(
__getitem__
):- 这个方法允许通过多种类型的索引来获取
InstanceData
实例中的数据。 - 参数
item
可以是字符串、整数、列表、切片、numpy.ndarray
、torch.LongTensor
或torch.BoolTensor
等类型。 - 如果
item
是字符串,直接返回对应的属性值。例如,如果有一个属性名为det_labels
,那么instance_data['det_labels']
将返回这个属性的值。 - 如果
item
是整数,会检查其范围。如果超出范围,即item >= len(self) or item < -len(self)
,则抛出IndexError
。否则,将其转换为切片操作以保持维度一致,即将整数转换为slice(item, None, len(self))
,然后继续处理。 - 如果
item
是列表,先将其转换为numpy.ndarray
类型。如果列表中的元素是整数,并且在转换为numpy.ndarray
时是int32
类型(在某些平台上numpy
的默认整数类型是int32
),为了与torch.Tensor
的索引要求一致(通常要求是int64
),会将其转换为int64
类型。然后,将numpy.ndarray
转换为torch.Tensor
类型。 - 如果
item
是torch.Tensor
类型,首先确保它是一维张量,因为这里只支持在第一维度上进行索引操作。如果item
是布尔张量(BoolTypeTensor
),还会检查其形状是否与当前InstanceData
实例的长度一致。 - 对于每个属性,根据其值的类型进行不同的处理:
- 如果是
torch.Tensor
类型,直接使用item
进行索引操作,即v[item]
。 - 如果是
numpy.ndarray
类型,先将item
转换为numpy
数组的索引,即item.cpu().numpy()
,然后进行索引操作。 - 如果是
str
、list
、tuple
类型或者具有__getitem__
和cat
属性的自定义类型:- 如果
item
是布尔张量,将其转换为索引列表,即torch.nonzero(item).view(-1).cpu().numpy().tolist()
。 - 否则,将
item
转换为索引列表,即item.cpu().numpy().tolist()
。 - 根据索引列表创建切片列表,对于空索引列表,创建一个长度为 0 的切片。
- 对于每个切片,从属性值中获取相应的子序列。如果属性值是
str
、list
或tuple
类型,将这些子序列连接起来;如果属性值具有cat
方法,调用cat
方法将这些子序列连接起来。
- 如果
- 如果是
- 最后,返回一个新的
InstanceData
实例,包含索引后的数据。
- 这个方法允许通过多种类型的索引来获取
-
连接操作(
cat
方法):- 这是一个静态方法,用于连接多个
InstanceData
实例。 - 首先,它检查输入列表中的所有元素是否都是
InstanceData
类型。如果不是,会抛出异常。 - 然后,确保所有实例具有完全相同的键。它通过获取每个实例的所有键列表,检查这些列表的长度是否一致,并且通过将所有键列表展开为一个集合,检查集合的长度是否与第一个列表的长度一致。如果不满足这些条件,说明实例之间的键不一致,可能会导致连接操作失败,因此抛出异常。
- 接着,对于每个键,根据值的类型进行不同的连接操作:
- 如果是
torch.Tensor
类型,使用torch.cat
在第一维度上进行连接操作,即torch.cat(values, dim=0)
。 - 如果是
numpy.ndarray
类型,使用np.concatenate
在第一维度上进行连接操作,即np.concatenate(values, axis=0)
。 - 如果是
str
、list
或tuple
类型,将所有实例的对应值连接起来,即new_values = v0[:]
,然后对于每个后续的值,进行连接操作,即new_values += v
。 - 如果具有
cat
属性,调用其cat
方法进行连接操作,即new_values = v0.cat(values)
。
- 如果是
- 最后,返回一个新的
InstanceData
实例,包含连接后的所有数据。
- 这是一个静态方法,用于连接多个
-
长度获取(
__len__
方法):- 这个方法用于获取
InstanceData
实例的长度。 - 如果有数据字段(即
_data_fields
不为空),返回第一个非空数据字段的长度。这里假设所有数据字段的长度是一致的,所以只需要获取第一个非空字段的长度即可。 - 如果没有数据字段,返回 0。
- 这个方法用于获取
四、使用示例
代码最后给出了一个详细的使用示例,展示了如何创建InstanceData
实例、设置各种属性、进行索引操作和连接操作。
-
创建
InstanceData
实例并设置元信息:img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
创建了一个包含图像形状和填充形状的元信息字典。instance_data = InstanceData(metainfo=img_meta)
创建了一个InstanceData
实例,并将元信息设置为img_meta
。
-
设置各种数据字段:
instance_data.det_labels = torch.LongTensor([2, 3])
设置了检测标签属性为一个长整型张量。instance_data["det_scores"] = torch.Tensor([0.8, 0.7])
设置了检测分数属性为一个张量。instance_data.bboxes = torch.rand((2, 4))
设置了边界框属性为一个随机生成的张量。instance_data.polygons = TmpObject([[1, 2, 3, 4], [5, 6, 7, 8]])
设置了多边形属性为一个自定义的TmpObject
实例。
-
索引操作:
sorted_results = instance_data[instance_data.det_scores.sort().indices]
根据检测分数进行排序,并使用排序后的索引对InstanceData
实例进行索引操作,得到一个新的实例,其中的数据按照检测分数排序。print(instance_data[instance_data.det_scores > 0.75])
筛选出检测分数大于 0.75 的数据,创建一个新的InstanceData
实例并打印。print(instance_data[instance_data.det_scores > 1])
筛选出检测分数大于 1 的数据,由于没有满足条件的数据,创建一个空的InstanceData
实例并打印。
-
连接操作:
print(instance_data.cat([instance_data, instance_data]))
将两个InstanceData
实例连接起来,创建一个新的实例并打印。
总之,mmengine/structures/instance_data.py
中的InstanceData
类为我们在计算机视觉任务中处理实例级别的数据提供了强大而灵活的工具。通过其丰富的功能和灵活的索引、连接操作,可以方便地管理和处理各种类型的实例数据。