Conditional DETR解读---带anchor的DETR

DETR存在的问题

1.收敛速度慢

2.对小目标物体检测效果不好,因为transformer计算量大,受限于计算规模,CNN提取特征时只采取了最后一层特征,没有用FPN等结构。所以对于小目标检测效果不好。

论文主要观点

  • 通过对DETRdecoder中的attentionmap进行可视化,发现query查询到的区域都是物体的extremity末端区域。所以论文认为attention尝试找到物体的边界区域。

  • 论文中认为DETRtransofmer结构中的信息主要可以分为两部分,一部分是与图像的特征(颜色纹理等)相关的信息,称为content,比如encoder或decoder的输出信息。另一部分是代表空间上的信息,称为spatial,比如position embedding等。

  • detr中的CNN与encoder只涉及图像特征向量提取;decoder中的self-attn只涉及query之间的交互去重;所以收敛慢的最可能原因发生在cross attn

  • Cross attention中的K包含encoder输出信息(content key Ck)与position embedding(spatial Key Pk),Q包含self attention的输出(content query Cq)和object query(spatial query Pq)信息。论文中发现去掉cross attention中的object基本不掉点,所以收敛慢很可能是content query难学习导致的。

  • 提出了reference point的概念,为每个query设定一个检测范围,使得匹配更加稳定,加快了收敛

  • 原始detr混合两者学习,使得content query难学习。所以将content与spatial进行解耦

在这里插入图片描述

变为

在这里插入图片描述

网络结构

在这里插入图片描述

对于object query生成了一个2D坐标embedding(上图中的s),用于限定当前query的预测范围。最终decoder的输出的是相对与s的偏移量

bbox回归输出

在这里插入图片描述

其中f是decoer的输出,S表示x,y的坐标。最终b是[x,y,w,h]的向量。

classifier分类输出

在这里插入图片描述

f是decoder的输出,输出每个候选框的类别

decoder Pq生成:

提出了reference point的概念,即图中的s,是一个2d的坐标(q_num,B,2),由object queries经过一个线性层生成,代表了每个query的预测范围。

s经过sigmoid和position embedding后(图中的Ps),跟FFN(decoder embedding)(即图中的T)做内积。得到空间特征Pq

在这里插入图片描述

在这里插入图片描述

代码spatial query这一部分的实现:

# query_pos [num_query,batch,d_model]
# reference_points_before_sigmoid [num_query,batch,2]  从query预测一个坐标,代表了这个query预测的大概范围
reference_points_before_sigmoid = self.ref_point_head(query_pos)    # [num_queries, batch_size, 2]
reference_points = reference_points_before_sigmoid.sigmoid().transpose(0, 1
### 条件DETR模型的优化与改进 条件DETR是一种基于Transformer架构的目标检测方法,它通过引入条件解码器来增强目标查询的质量。为了进一步提升其性能和效率,可以从以下几个方面考虑对其进行改进: #### 1. **动态锚框机制** 传统的DETR模型采用固定数量的目标查询(object queries),这可能导致冗余计算或无法适应不同场景下的目标分布。可以借鉴YOLOv5等算法中的动态锚框思想,在训练过程中自适应调整目标查询的数量和位置[^1]。 #### 2. **多尺度特征融合** 当前版本的条件DETR主要依赖单层或多层卷积网络提取图像特征。然而,对于复杂背景或者小尺寸物体而言,单一层次的信息可能不足以提供足够的上下文支持。因此可以通过FPN(Feature Pyramid Network)或其他先进的多尺度特征融合技术加强跨层级间信息交互能力[^2]。 #### 3. **轻量化设计** 尽管取得了良好效果,但原版DETR及其变体通常具有较高的参数量以及推理延迟时间。针对资源受限设备上的应用需求,可探索剪枝、量化压缩等方式降低整体开销;同时也可以尝试更高效的注意力实现形式比如Linformer,Sinkhorn等等[^3]。 #### 4. **强化学习策略** 利用强化学习框架指导anchor point的选择过程可能会来额外增益。具体来说就是让agent学会根据不同输入样本特性灵活决定最佳候选区域设置方案从而减少不必要的搜索空间范围并提高定位精度[^4]。 ```python import torch.nn as nn class OptimizedConditionalDETR(nn.Module): def __init__(self, backbone, transformer, num_classes=91): super(OptimizedConditionalDETR,self).__init__() self.backbone = backbone # 使用更强力的骨干网结构如Swin Transformer代替ResNet系列 self.transformer = transformer ... # 其他初始化操作保持不变 def forward(self,x): ... ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值