目录
摘要
DETRs with Hybrid Matching针对DETR一对一匹配导致的正样本训练效率低下,并导致大量查询未被有效利用的问题。提出了一种混合匹配策略,在训练过程中结合原始的一对一匹配分支和辅助的一对多匹配分支。该方法允许每个真实标签与多个查询进行匹配,从而增加了正样本的数量,提高了训练效率。在预测过程中,只使用原始的一对一匹配分支,既保持了DETR端到端的优点和相同的推理效率,同时也提高了模型的精度。该方法被命名为H-DETR,在目标检测、实例分割、全景分割、姿态估计、目标跟踪等,都显示出了有效性,并且能够提升一系列DETR方法的性能。H-DETR在COCO数据集上,相比于Deformable-DETR,提升了1.7%的mAP。
Abstract
DETRs with Hybrid Matching addresses the issue of low positive sample training efficiency caused by one-to-one matching in DETR, which leads to a large number of queries not being effectively utilized. DETRs proposes a hybrid matching strategy that combines the original one-to-one matching branch with an auxiliary one-to-many matching branch during the training process. This method allows each ground truth label to match with multiple queries, thereby increasing the number of positive samples and improving training efficiency. During the inference process, only the original one-to-one matching branch is used, which not only maintains the end-to-end advantages of DETR and the same inference efficiency but also enhances the model's accuracy. This method is named H-DETR and has shown effectiveness in tasks such as object detection, instance segmentation, panoptic segmentation, pose estimation, and object tracking, and can improve the performance of a series of DETR methods. H-DETR has improved the mAP by 1.7% on the COCO dataset compared to Deformable-DETR.
DETRs
论文地址:[2207.13080] DETRs with Hybrid Matching
项目地址:H-DETR
我们可以认为DETRs是一种运用一对多匹配的训练方法,可以广泛用于各种代表性DETR方法,如:Deformable-DETR、PETRv2、PETR、TransTrack等,以克服一对一匹配的缺点并提高训练效率。改进效果如下图所示:
在了解DETRs如何改进之前,我们需要先知道Transformer是如何应用到目标检测任务中,可以查看之前有关DETR的博客。
网络结构
DETRs网络结构图,如下图所示:
首先,输入图像I,DETRs通过主干网络和Transformer编码器提取一系列增强的像素嵌入。其次,将上述像素嵌入和一组默认的对象查询嵌入
送入Transformer解码器。第三,DETRs在每个Transformer解码器层之后,使用任务特定的预测头更新对象查询嵌入Q,生成一组独立的预测
。最后,在预测P和真实边界框及标签
之间执行一对一的二分匹配。
DETRs针对于DETR的一对多匹配改进就在查询Q处。
三种混合分支的方案
Hybrid-Branch
一对一查询:;一对多查询:
。
-
One-to-one matching
使用L层Transformer解码器处理第一组查询Q,并分别对每个解码器层的输出进行预测。然后,在每一层上对{预测,真实标签}执行二分匹配,损失函数如下:
P表示由第 l 层Transformer解码器输出的预测结果。
- One-to-many matching
同样,使用L层Transformer解码器处理第二组查询,并得到 L 组预测结果。为了执行一对多匹配,将真实标签重复K次,得到一个扩展的真实目标集合
,
。并在每一层上对{预测,扩展真实标签}执行二分匹配,损失函数如下:
Hybrid-Rpoch
在个周期中使用一对多匹配,在
个周期中使用一对一匹配,查询都使用
。
如Hybrid-Branch可知:
- One-to-one matching
- One-to-many matching
Hybrid-Layer
Hybrid-Layer与前两种方法不同之处在于它是将之前的L层Transformer解码器分为两个部分和
。
- One-to-many matching
前层Transformer解码器输出执行一对多匹配,损失函数如下:
- One-to-one matching
后层执行Transformer解码器输出执行一对一匹配,损失函数如下:
以上三种方法,图中颜色相同的部分参数共享。
二分匹配
采用二分图匹配的形式与ground truth框进行一对一的匹配,就无需非极大值抑制处理。
假设a、b、c点到达X、Y、Z点分别有着不同的代价,而它们分别到达每一点的代价图称为cost matrix。在scipy中的linear-sum-assignment函数能够计算出最优化匹配,使得abc到达XYZ的总价值最小。
我们可以理解为a、b、c代表着N个预测框,而X、Y、Z代表ground truth框。遍历所有预测框和ground truth框计算cost,得到最终的cost matrix。cost计算公式如下所示:
然后,利用scipy中的linear-sum-assignment函数计算出cost matrix的最优化匹配。这样就实现了预测框和真实框的一对一匹配,没有出现冗余的框。
最后,在将预测框和真实框进行类别预测和框预测的损失计算,即可反向传播优化模型。损失函数公式如下所示:
通过以上方法结合一对一匹配方案和一对多匹配方案的优势,其中一对一匹配对于去除NMS是必要的,而一对多匹配则增加了与真实标签匹配的查询数量,提高了训练效率。
消融实验
在COCO验证集上评估,H-Deformable-DETR在COCO验证集上达到了59.4%的AP,超过了DINO-DETR方法,以及其他表现歌更好的方法。如下图所示:
代码
混合匹配损失函数:
骨干网络采用ResNet-50,该骨干网络在ImageNet上预训练,DETRs模型训练PyTorch代码如下:
#-------------------------------------#
# 对数据集进行训练
#-------------------------------------#
import datetime
import os
from functools import partial
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch import nn
from torch.utils.data import DataLoader
from nets.detr import DETR
from nets.detr_training import (build_loss, get_lr_scheduler, set_optimizer_lr,
weights_init)
from utils.callbacks import EvalCallback, LossHistory
from utils.dataloader import DetrDataset, detr_dataset_collate
from utils.utils import (get_classes, seed_everything, show_config,
worker_init_fn)
from utils.utils_fit import fit_one_epoch
if __name__ == "__main__":
#---------------------------------#
# Cuda 是否使用Cuda
# 没有GPU可以设置成False
#---------------------------------#
Cuda = True
#----------------------------------------------#
# Seed 用于固定随机种子
# 使得每次独立训练都可以获得一样的结果
#----------------------------------------------#
seed = 11
#---------------------------------------