目标检测7-DETR算法剖析与实现


欢迎访问个人网络日志🌹🌹知行空间🌹🌹


端到端目标检测框架DETR

背景介绍

DETRFacebook AINicolas Carion等于202005月提交的论文中提出的。

论文地址: https://arxiv.org/abs/2005.12872
开源代码: https://github.com/facebookresearch/detr

DETR(DEtection TRansformer)将目标检测问题看成是集合预测的问题,所谓集合预测set prediction是指一次输出一张图像中的所有待检测对象。

DETR使用transformer来做目标检测,直接预测检测框到检测框中心点归一化的距离。在模型训练时,Proposal Assignment使用的算法是一对一的匈牙利算法,通过query的方式获取最后的输出。以上介绍的策略,使得DETR实现了目标检测算法的端到端训练,不需要使用NMS和先验anchor

模型结构

从上面这个图可以看到DETR的架构相当简单,输入一张图像,直接输出的就是所有的检测框,不需要复杂的编解码,不需要NMS

模块解析

数据

官方源码中数据定义在CocoDetection类中,这个类继承自torchvision.datasets.CocoDetection只需要传入COCO格式数据集的图像和json标注文件即可,

COCO格式数据集文件夹路径:

.
├── annotations
│   ├── train.json
│   └── val.json
└── images
    ├── train
    └── val

其中,标签文件bounding box的格式为:

left top width height

CoCoDetection类中有一个self.prepare属性,这是一个函数,其中会将ltwh格式的检测框变换成x1y1x2y2格式的检测框。

DETR源码中使用的变换函数不是从torchvision中导入的,而是自定义的,可以看到在Normalize中,不仅处理了图像数据,还将检测框从x1y1x2y2格式变换成了cxcywh格式,并相对于图像的宽高进行了归一化,其值变换到了[0,1]

class Normalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image, target=None):
        image = F.normalize(image, mean=self.mean, std=self.std)
        if target is None:
            return image, None
        target = target.copy()
        h, w = image.shape[-2:]
        if "boxes" in target:
            boxes = target["boxes"]
            boxes = box_xyxy_to_cxcywh(boxes)
            boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
            target["boxes"] = boxes
        return image, target

模型结构

DETR的模型结构其实很简单,先是将图像输入到几层卷积神经网络中得到特征图feature map,然后使用src = src.flatten(2).permute(2, 0, 1)将特征图WH维度拉平将图像变换成长度为L=W*H的序列数据。

根据序列的长度和每个Token的通道数生成位置编码。

feature map生成的序列和位置编码信息相加作为transformer的输入src

除了输入的特征序列之外,还输入了图像数据的掩码src_mask。原因是因为一个batch输入的图像宽高不一定相同,源码中的处理方式是取一个batch中尺寸最大的图像尺寸,其余图像往右下方向补0,最后变成尺寸一致的图像用于计算。这是为了避免padding-0参与计算,需要将src_mask输入到transformer中。

DETR使用的位置编码是针对图像的带mask的二维位置编码

class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            <
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值