文章目录
欢迎访问个人网络日志🌹🌹知行空间🌹🌹
端到端目标检测框架DETR
背景介绍
DETR
是Facebook AI
的Nicolas Carion
等于2020
年05
月提交的论文中提出的。
论文地址: https://arxiv.org/abs/2005.12872
开源代码: https://github.com/facebookresearch/detr
DETR(DEtection TRansformer)
将目标检测问题看成是集合预测的问题,所谓集合预测set prediction
是指一次输出一张图像中的所有待检测对象。
DETR
使用transformer
来做目标检测,直接预测检测框到检测框中心点归一化的距离。在模型训练时,Proposal Assignment
使用的算法是一对一的匈牙利算法,通过query
的方式获取最后的输出。以上介绍的策略,使得DETR
实现了目标检测算法的端到端训练,不需要使用NMS
和先验anchor
。
模型结构
从上面这个图可以看到DETR
的架构相当简单,输入一张图像,直接输出的就是所有的检测框,不需要复杂的编解码,不需要NMS
。
模块解析
数据
官方源码中数据定义在CocoDetection
类中,这个类继承自torchvision.datasets.CocoDetection
只需要传入COCO
格式数据集的图像和json
标注文件即可,
COCO
格式数据集文件夹路径:
.
├── annotations
│ ├── train.json
│ └── val.json
└── images
├── train
└── val
其中,标签文件bounding box
的格式为:
left top width height
在CoCoDetection
类中有一个self.prepare
属性,这是一个函数,其中会将ltwh
格式的检测框变换成x1y1x2y2
格式的检测框。
在DETR
源码中使用的变换函数不是从torchvision
中导入的,而是自定义的,可以看到在Normalize
中,不仅处理了图像数据,还将检测框从x1y1x2y2
格式变换成了cxcywh
格式,并相对于图像的宽高进行了归一化,其值变换到了[0,1]
。
class Normalize(object):
def __init__(self, mean, std):
self.mean = mean
self.std = std
def __call__(self, image, target=None):
image = F.normalize(image, mean=self.mean, std=self.std)
if target is None:
return image, None
target = target.copy()
h, w = image.shape[-2:]
if "boxes" in target:
boxes = target["boxes"]
boxes = box_xyxy_to_cxcywh(boxes)
boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
target["boxes"] = boxes
return image, target
模型结构
DETR
的模型结构其实很简单,先是将图像输入到几层卷积神经网络中得到特征图feature map
,然后使用src = src.flatten(2).permute(2, 0, 1)
将特征图WH
维度拉平将图像变换成长度为L=W*H
的序列数据。
根据序列的长度和每个Token
的通道数生成位置编码。
将feature map
生成的序列和位置编码信息相加作为transformer
的输入src
。
除了输入的特征序列之外,还输入了图像数据的掩码src_mask
。原因是因为一个batch
输入的图像宽高不一定相同,源码中的处理方式是取一个batch
中尺寸最大的图像尺寸,其余图像往右下方向补0,最后变成尺寸一致的图像用于计算。这是为了避免padding-0
参与计算,需要将src_mask
输入到transformer
中。
DETR
使用的位置编码是针对图像的带mask
的二维位置编码,
class PositionEmbeddingSine(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
"""
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
<