【NIPS2023】Rank-DETR for High Quality Object Detection
机构:清华大学、北京大学、剑桥大学、微软亚洲研究院
论文地址:https://arxiv.org/abs/2310.08854
代码地址:https://github.com/LeapLabTHU/Rank-DETR
作者简介:黄高,清华大学博士学位,康奈尔大学计算机系博士后,清华大学自动化系助理教授、博士生导师,获阿里巴巴“达摩院青橙奖”、2019年吴文俊人工智能优秀青年奖等。代表作DenseNet获得CVPR2017年最佳论文、Stochastic Depth。研究方向包括动态神经网络、高效深度学习。
本文考虑到DETR模型中query的重要性存在差异,致力于改进高IoU情况下(例如AP@75)的检测性能,首次提出基于排序思想的Rank-DETR,在Transformer中引入排序相关的网络层、排序导向的损失函数和匈牙利匹配损失。在COCO数据集上的性能高于DINO、Align-DETR、GroupDETR等baseline,与Stable-DINO、MS-DETR、Salience-DETR相当,弱于DDQ-DETR、Co-DETR、Relation-DETR等SOTA方法。
文章贡献/创新点
- 在Transformer Decoder中提出了基于rank机制改进的分类头和query排序层。
- 在损失函数(网络损失和匈牙利匹配损失)中对分类和回归分支进行对齐,使得高置信度的query也具有高IoU。
- 实验验证了所提方法的有效性,并将rank机制引入到已有DETR方法中验证了有效性。
排序相关的结构设计
排序自适应的分类头
常规的DETR方法中,backbone提取多尺度特征,transformer将其映射为6层Decoder输出(两阶段方法还会多1层Encoder输出),每层的输出都是 n n n个query(原始DETR中 n = 100 n=100 n=100、DeformableDETR中 n = 300 n=300 n=300、DINO中 n = 900 n=900 n=900),针对第 l l l层的每个query,表示为 q i l q_i^l qil,head将其映射为分类结果 p i l \boldsymbol p_i^l pil+回归结果,其中分类头是单层全连接:
p i l = S i g m o i d ( r i l ) , t i l = MLP cls ( q i l ) \boldsymbol p_i^l=\mathrm{Sigmoid}(\boldsymbol r_i^l), \boldsymbol t_i^l=\text{MLP}_\text{cls}(\boldsymbol q_i^l) pil=Sigmoid(ril),til=MLPcls(qil)
本文提出的排序自适应分类头其实就是为每个query增加了对应的embedding,两者加起来再进行分类:
p i l = S i g m o i d ( t i l + s i l ) , t i l = MLP cls ( q i l ) \boldsymbol p_i^l=\mathrm{Sigmoid}(\boldsymbol t_i^l+\boldsymbol s_i^l), \boldsymbol t_i^l=\text{MLP}_\text{cls}(\boldsymbol q_i^l) pil=Sigmoid(til+sil),til=MLPcls(qil)
这里的 s i l \boldsymbol s_i^l sil表示第 l l l层第 i i i个query的embedding。所有的 s \boldsymbol s s都作为网络参数自适应学习。
Query排序层
Transformer Decoder输入的query包含两部分,分别是content query和position query。其中content query作为网络可以学习的参数进行初始化,position query来自Transformer Encoder输出的候选框。RankDETR在每个Transformer解码层后增加了一个query rank layer来对这两种query进行排序。
对内容query进行排序
排序依据是每层输出的分类结果 P ^ l − 1 = M L P cls ( Q c l − 1 ) \hat{\mathcal P}^{l-1}=\mathrm{MLP}_\text{cls}(\mathcal Q_c^{l-1}) P^l−1=MLPcls(Qcl−1),排序层会按照该置信度对内容query Q c \mathcal Q_c Qc进行降序排序,排序后的 Q ^ c l \hat{\mathcal Q}_c^l Q^cl与 C l \mathcal C^l Cl进行拼接,这里的 C l \mathcal C^l Cl同样是作为网络参数自适应学习。拼接后的结果在channel维度就变成了2倍,因此再经过全连接进行降维,得到下一层的query。整个流程:
Q ‾ c l − M L P fuse ( Q ^ c l − 1 ∣ ∣ C l ) , Q ^ c l − 1 = S o r t ( Q c l − 1 ; P ^ l − 1 ) \overline{\mathcal Q}_c^l-\mathrm{MLP}_\text{fuse}(\hat{\mathcal Q}_c^{l-1}||\mathcal C^l), \hat{\mathcal Q}_c^{l-1}=\mathrm{Sort}(\mathcal Q_c^l-1;\hat{\mathcal P}^{l-1}) Qcl−MLPfuse(Q^cl−1∣∣Cl),Q^cl−1=Sort(Qcl−1;P^l−1)
对位置query进行排序
位置query的排序方式在各个DETR方法中有所不同,针对H-DETR和Deformable DETR,作者也是按照 P ^ l − 1 \hat{\mathcal P}^{l-1} P^l−1对每一层位置query进行降序排序:
Q ‾ p l = S o r t ( Q ‾ p l − 1 ; P ^ l − 1 ) \overline{\mathcal Q}_p^l=\mathrm{Sort}(\overline{\mathcal Q}_p^{l-1};\hat{\mathcal P}^{l-1}) Qpl=Sort(Qpl−1;P^l−1)
由于DINO-DETR的位置query是上一层的位置query经过bounding box微调的结果,作者没有直接对 Q ^ p l \hat{\mathcal Q}_p^l Q^pl进行排序,而是先对检测框进行排序,再由排序后的检测框生成下一层的query:
Q ‾ p l = P E ( B ‾ l − 1 ) , B ‾ l − 1 = S o r t ( B l − 1 ; P ^ l − 1 ) \overline{\mathcal Q}_p^l=\mathrm{PE}(\overline{\mathcal B}^{l-1}),\overline{B}^{l-1}=\mathrm{Sort}(\mathcal B^{l-1};\hat{\mathcal P}^{l-1}) Qpl=PE(Bl−1),Bl−1=Sort(Bl−1;P^l−1)
其中 P E \mathrm{PE} PE表示正余弦编码和多层全连接。
代码实现
作者主要是基于H-DETR实现的Rank-DETR,因此代码没给出DINO-DETR的排序方式。
# 对内容进行排序
output = torch.gather(
output, 1, rank_indices.unsqueeze(-1).repeat(1, 1, output.shape[-1])
)
# 排序后与C进行拼接,然后经过MLP
concat_term = self.pre_racq_trans[layer_idx - 1](
self.rank_aware_content_query[layer_idx - 1].weight[:output.shape[1]].unsqueeze(0).expand(output.shape[0], -1, -1)
)
output = torch.cat((output, concat_term), dim=2)
output = self.post_racq_trans[layer_idx - 1](output)
# 对未知进行排序
query_pos = torch.gather(
query_pos, 1, rank_indices.unsqueeze(-1).repeat(1, 1, query_pos.shape[-1])
)
# 省略中间代码......
# 获得排序依据:训练时有one2one query和one2many query,要分别对排序,推理时只有one2one query
if self.training:
rank_indices_one2one = torch.argsort(rank_basis[:, : self.num_queries_one2one], dim=1, descending=True) # tensor shape: [bs, num_queries_one2one]
rank_indices_one2many = torch.argsort(rank_basis[:, self.num_queries_one2one :], dim=1, descending=True) # tensor shape: [bs, num_queries_one2many]
rank_indices = torch.cat(
(
rank_indices_one2one,
rank_indices_one2many + torch.ones_like(rank_indices_one2many) * self.num_queries_one2one
),
dim=1,
) # tensor shape: [bs, num_queries_one2one+num_queries_one2many]
else:
rank_indices = torch.argsort(rank_basis[:, : self.num_queries_one2one], dim=1, descending=True)
排序相关的损失设计
损失函数设计
一般DETR的损失包括三部分,分类损失、定位损失和GIoU损失:
− λ 1 G I o U ( b ^ , b ) + λ 2 ℓ 1 ( b ^ , b ) + λ 3 F L ( p ^ [ c ] ) -\lambda_1\mathrm{GIoU}(\hat{\boldsymbol b},\boldsymbol b)+\lambda_2\ell_1(\hat{\boldsymbol b},\boldsymbol b)+\lambda_3\mathrm{FL}(\hat{\boldsymbol p}[c]) −λ1GIoU(b^,b)+λ2ℓ1(b^,b)+λ3FL(p^[c])
作者提出的改进就在于其中的 F L \mathrm{FL} FL分类损失上,作者将分类目标从原始的二分类0-1目标替换为了基于 I o U \mathrm{IoU} IoU的分类目标:
F L GIoU ( p ^ [ c ] ) = − ∣ t − p ^ [ c ] ∣ γ ⋅ [ t ⋅ log ( p ^ ) ] + ( 1 − t ) ⋅ log ( 1 − p ^ [ c ] ) \mathrm{FL}^{\text{GIoU}}(\hat{\boldsymbol p}[c])=-|t-\hat{\boldsymbol p}[c]|^\gamma\cdot[t\cdot\log(\hat{\boldsymbol p})]+(1-t)\cdot\log(1-\hat{\boldsymbol p}[c]) FLGIoU(p^[c])=−∣t−p^[c]∣γ⋅[t⋅log(p^)]+(1−t)⋅log(1−p^[c])
其中 t = ( G I o U ( b ^ , b ) + 1 ) / 2 t=(\mathrm{GIoU}(\hat{\boldsymbol b}, \boldsymbol b)+1)/2 t=(GIoU(b^,b)+1)/2。文章还对比了VFL的损失函数:
V F L ( p ^ [ c ] ) = − t ⋅ [ t ⋅ log ( p ^ [ c ] ) + ( 1 − t ) ⋅ log ( 1 − p ^ [ c ] ) ] \mathrm{VFL}(\hat{\boldsymbol p}[c])=-t\cdot[t\cdot\log(\hat{\boldsymbol p}[c])+(1-t)\cdot\log(1-\hat{\boldsymbol p}[c])] VFL(p^[c])=−t⋅[t⋅log(p^[c])+(1−t)⋅log(1−p^[c])]
可以看到,两者区别就在于对正样本监督的权值从 t t t变成了 ∣ t − p ^ [ c ] ∣ γ |t-\hat{\boldsymbol p}[c]|^\gamma ∣t−p^[c]∣γ。
思考:一般来说DETR正样本监督式很稀缺的,因此在设计损失函数的时候(例如VFL和TOOD损失),正样本通常不进行难例挖掘。这里作者对正样本同样进行了类似GFL的难例挖掘设计,可能是考虑到H-DETR本身通过one2many的匹配设计弥补了正样本监督稀缺的问题,在此基础上进行难例挖掘则可能提升性能。
匹配损失函数设计
常规的匹配损失是对分类、回归和IoU损失进行加权,加权比例通常是2、5、2,文章提出了高阶的匹配损失:
L Hungarian high-order = p ^ [ c ] ⋅ IoU α \mathcal L_\text{Hungarian}^\text{high-order}=\hat{\boldsymbol p}[c]\cdot\text{IoU}^\alpha LHungarianhigh-order=p^[c]⋅IoUα
其中 p ^ [ c ] \hat{\boldsymbol p}[c] p^[c]是分类头的输出。在代码实现中,前期仍然是使用常规的匹配损失,后期则是换成作者提出的高阶损失。
实验结果
从消融实验中来看,改进损失函数和匹配损失对性能提升的效果最大:
文章另外对比了自己提出的损失和VFL,发现使用自己提出的排序损失能够达到49.8%的AP,而VFL则只有49.5%AP,说明作者提出的排序损失能够更好地建模boxes之间地距离。
Rank-DETR虽然引入了2个结构上的改进,但都比较轻量化;而损失函数的改进不会对推理性能有影响,因此总体FPS基本没有太大下降:
和主流方法的性能对比实验可以看出,Rank-DETR在1 × \times ×下的性能达到了50.2,高于DINO,但相比其他方法(DDQ-DETR、Relation-DETR、Co-DETR),性能还是弱一些,但胜在对推理速度没有太大影响。