目标检测算法DINO(DETR with Improved DeNoising Anchor Boxes for End-to-End Object Detection)对DN-DETR的去噪训练进行扩展——对GT box进行采用较大的扰动以获得负样本(DN-DETR中对GT box添加了细微扰动只构造了正样本)。接下来结合DINO代码看一下对比去噪训练分支中正、负样本是如何生成的。
用于生成对比去噪正负样本的函数在DINO类的forward方法中被调用,名为prepare_for_dn,代码在项目中的路径为:DINO/models/dino/dn_components.py
一、prepare_for_dn的调用
以下代码是prepare_for_dn()在DINO类的forward方法中的调用,每个参数的含义如下:
if self.dn_number > 0 and targets is not None:
input_query_label, input_query_bbox, attn_mask, dn_meta = \
prepare_for_cdn(
dn_args=(targets, self.dn_number, self.dn_label_noise_ratio, self.dn_box_noise_scale),
training=self.training,
num_queries=self.num_queries,
num_classes=self.num_classes,
hidden_dim=self.hidden_dim,
label_enc=self.label_enc)
- dn_args:去噪训练的相关参数,包括标图像中的标注信息targets、去噪组数量self.dn_number(默认值100)、标签翻转概率self.dn_label_noise_ratio(默认值0.5)、box扰动尺度self.dn_label_noise_ratio(默认值0.4);
- training:是否是训练过程,只有训练过程才有去噪部分;
- num_queries:query的数量,默认值为900,用于生成attention mask,以阻止自注意力操作中去噪部分和匹配部分的信息泄露;
- hidden_dim:嵌入向量的维度,默认为256维;
- label_enc:nn.Embedding(dn_labelbook_size + 1, hidden_dim)初始化的实例,用于将label编码为向量;
二、prepare_for_dn代码解析
代码逻辑主要分为训练和推理两部分,由一个if-else语句构成,之前提到过training参数用于判断是训练过程还是推理过程,推理过程的代码很简单,如下所示:
device = dn_args[0][0]['boxes'].device
# 训练阶段,在此省略
if training:
pass
# 推理阶段
else:
input_query_label = None
input_quer