DINO对比去噪训练代码分析

目标检测算法DINO(DETR with Improved DeNoising Anchor Boxes for End-to-End Object Detection)对DN-DETR的去噪训练进行扩展——对GT box进行采用较大的扰动以获得负样本(DN-DETR中对GT box添加了细微扰动只构造了正样本)。接下来结合DINO代码看一下对比去噪训练分支中正、负样本是如何生成的。

代码地址:DINO代码


用于生成对比去噪正负样本的函数在DINO类的forward方法中被调用,名为prepare_for_dn,代码在项目中的路径为:DINO/models/dino/dn_components.py

 一、prepare_for_dn的调用

以下代码是prepare_for_dn()在DINO类的forward方法中的调用,每个参数的含义如下:

if self.dn_number > 0 and targets is not None:
    input_query_label, input_query_bbox, attn_mask, dn_meta = \
        prepare_for_cdn(
        dn_args=(targets, self.dn_number, self.dn_label_noise_ratio, self.dn_box_noise_scale), 
        training=self.training, 
        num_queries=self.num_queries, 
        num_classes=self.num_classes, 
        hidden_dim=self.hidden_dim, 
        label_enc=self.label_enc)
  • dn_args:去噪训练的相关参数,包括标图像中的标注信息targets、去噪组数量self.dn_number(默认值100)、标签翻转概率self.dn_label_noise_ratio(默认值0.5)、box扰动尺度self.dn_label_noise_ratio(默认值0.4);
  • training:是否是训练过程,只有训练过程才有去噪部分;
  • num_queries:query的数量,默认值为900,用于生成attention mask,以阻止自注意力操作中去噪部分和匹配部分的信息泄露;
  • hidden_dim:嵌入向量的维度,默认为256维;
  • label_enc:nn.Embedding(dn_labelbook_size + 1, hidden_dim)初始化的实例,用于将label编码为向量;

二、prepare_for_dn代码解析

代码逻辑主要分为训练推理两部分,由一个if-else语句构成,之前提到过training参数用于判断是训练过程还是推理过程,推理过程的代码很简单,如下所示:

device = dn_args[0][0]['boxes'].device
    # 训练阶段,在此省略
    if training:
        pass
    
    # 推理阶段
    else:

        input_query_label = None
        input_quer
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值