江大白 | 终于把RT-DETR搞懂!替代YOLO的更快实时目标检测算法及Pytorch实现【附论文及源码,建议收藏!】

本文来源公众号“江大白”,仅用于学术分享,侵权删,干货满满。

原文链接:终于把RT-DETR搞懂!替代YOLO的更快实时目标检测算法及Pytorch实现【附论文及源码】

导读

YOLO通过非极大值抑制过滤重叠边界框,增加了计算延迟。RT-DETR的出现改变了这一现状,RT-DETR取消了NMS后处理,结合强大的主干网络、混合编码器和创新的查询选择器,提供了快速且高效的端到端目标检测解决方案。

论文链接:https://arxiv.org/pdf/2304.08069

源码链接:https://github.com/lyuwenyu/RT-DETR

引言

目标检测一直面临着一个重大挑战-平衡速度和准确性。像YOLO这样的传统模型速度很快,但需要一个名为非极大值抑制(NMS)的后处理步骤,这会减慢检测速度。NMS过滤重叠的边界框,但这会引入额外的计算时间,影响整体速度。这就是DETR(https://arxiv.org/pdf/2005.12872)(检测Transformer)的用武之地。

RT-DETR是基于DETR架构的端到端对象检测器,完全消除了对NMS的需求。通过这样做,RT-DETR显着减少了之前基于卷积神经网络(CNN)的对象检测器(如YOLO系列)的延迟。它结合了强大的主干、混合编码器和独特的查询选择器,可以快速准确地处理特征。

RT-DETR架构概述,来源:https://arxiv.org/pdf/2304.08069

RT-DETR架构的关键组成

  1. 主干:主干从输入图像中提取特征,最常见的配置是使用ResNet-50或ResNet-101。从主干,RT-DETR提取三个级别的特征- S3,S4和S5。这些多尺度特征有助于模型理解图像的高级和细粒度细节。

残差构建块,来源:https://arxiv.org/pdf/1512.03385

2.混合编码器

img

混合编码器

该混合编码器包括两个主要部分:基于注意力的尺度内特征交互(AIFI)和跨尺度特征融合(CCFF)。以下是每个部分的工作原理:

  • AIFI:这一层只处理主干中的S5特征图。由于S5代表最深的层,因此它包含关于图像中完整对象及其上下文的最丰富的语义信息-这使得它非常适合基于变换的注意力来捕捉对象之间有意义的关系。虽然S3和S4包含有用的低级别特征,如边缘和对象部分,但对这些层施加注意力将在计算上昂贵且效率较低,因为它们尚未表示需要彼此相关的完整对象。

以下是AIFI的PyTorch实现。代码分为三个主要部分:

  1. 处理关注后要素的前馈层(***FFLayer***)

  2. 实现多头自关注的Transformer编码器层

  3. 主要的AIFI模块,通过适当的特征投影和位置编码将所有内容联系在一起

import torch
import torch.nn as nn

class FFLayer(nn.Module):
    '''
    Feed-forward network: two linear layers with a non-linear activation
    in between
    '''
    def __init__(self,
                 embedd_dim:int,
                 ffn_dim:int,
                 dropout:float,
                 activation:str) -> nn.Module:
        super().__init__()
        self.feed_forward = nn.Sequential(
                nn.Linear(embedd_dim, ffn_dim),
                getattr(nn, activation)(),
                nn.Dropout(dropout),
                nn.Linear(ffn_dim, embedd_dim)
            )

        self.norm = nn.LayerNorm(embedd_dim)    
    
    def forward(self,x:torch.Tensor)->torch.Tensor:
        residual = x
        x = self.feed_forward(x)
        out = self.norm(residual + x)
        return out

class TransformerEncoderLayer(nn.Module):
    
    def __init__(self,
                 hidden_dim:int,
                 n_head:int,
                 ffn_dim:int,
                 dropout:float,
                 activation:str = "ReLU"):  
            
        super().__init__()    
        self.mh_self_attn = nn.MultiheadAttention(hidden_dim, n_head, dropout, batch_first=True)
        self.feed_foward_nn = FFLayer(hidden_dim,ffn_dim,dropout,activation)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(hidden_dim)

    def forward(self,x:torch.Tensor,mask:torch.Tensor=None,pos_emb:torch.Tensor=None) -> torch.Tensor:
        # Save the input as the residual for the skip connection
        residual = x
        
        # Add positional embeddings (if provided) to the input for spatial encoding
        q = k = x + pos_emb if pos_emb is not None else x

        # Apply multi-head self-attention
        x, attn_score = self.mh_self_attn(q, k, value=x, attn_mask=mask)
        
        # Add residual connection and apply dropout
        x= residual + self.dropout(x)
        # Normalize the result to avoid distribution shift during training
        x = self.n
### RT-DETR-n 代码实现源码解析 RT-DETR 是一种基于 Transformer 的目标检测模型,其在实时性和准确性之间取得了良好的平衡。Ultralytics 提供了 RT-DETR实现,而 RT-DETR-n 则是该系列模型的一个变种,通常表示较小的模型尺寸以适应轻量级应用场景。 以下是对 RT-DETR-n 的代码实现源码解析: #### 1. 模型结构概述 RT-DETR-n 的核心架构基于 Transformer 编码器-解码器结构,并结合了 DETR 中的 Set Prediction 方法。具体来说,RT-DETR-n 包含以下几个关键组件: - **骨干网络(Backbone)**:用于提取图像特征。RT-DETR-n 可能使用了更小的骨干网络(如 EfficientNet 或 MobileNet),以减少计算复杂度[^1]。 - **Transformer 编码器**:对骨干网络提取的特征进行全局上下文建模。 - **Transformer 解码器**:生成目标框预测和类别预测。 - **匹配策略**:采用 Hungarian Matching 算法来优化预测结果与真实标签之间的匹配[^2]。 #### 2. 源码解析 以下是 RT-DETR-n 在 Ultralytics 实现中的关键代码模块解析: ##### (1) 骨干网络定义 RT-DETR-n 使用轻量化骨干网络,例如 MobileNetV3 或 EfficientNet-B0。以下是一个示例代码片段: ```python from ultralytics.nn.backbones import MobileNetV3 def build_backbone(model_size="n"): if model_size == "n": return MobileNetV3(out_stages=[2, 3, 4]) # 输出不同阶段的特征图 ``` ##### (2) Transformer 编码器与解码器 Transformer 的实现基于 PyTorch 的 `nn.Transformer` 模块。以下是简化版的代码: ```python import torch.nn as nn class Transformer(nn.Module): def __init__(self, d_model=256, nhead=8, num_encoder_layers=6, num_decoder_layers=6): super().__init__() self.encoder = nn.TransformerEncoder( nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead), num_layers=num_encoder_layers ) self.decoder = nn.TransformerDecoder( nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead), num_layers=num_decoder_layers ) def forward(self, src, tgt): memory = self.encoder(src) output = self.decoder(tgt, memory) return output ``` ##### (3) 匹配与损失函数 RT-DETR-n 使用 Hungarian Matching 来优化目标框与预测框的匹配,并结合多种损失函数(如 L1 Loss 和 GIoU Loss)来训练模型。以下是匹配逻辑的简化实现: ```python from scipy.optimize import linear_sum_assignment class HungarianMatcher: def __init__(self, cost_class=1, cost_bbox=5, cost_giou=2): self.cost_class = cost_class self.cost_bbox = cost_bbox self.cost_giou = cost_giou def __call__(self, outputs, targets): out_prob = outputs["pred_logits"].softmax(-1) # [batch_size, num_queries, num_classes] out_bbox = outputs["pred_boxes"] # [batch_size, num_queries, 4] indices = [] for i in range(len(targets)): cost_class = -out_prob[i][:, targets[i]["labels"]].cpu() cost_bbox = torch.cdist(out_bbox[i], targets[i]["boxes"], p=1).cpu() cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox[i]), box_cxcywh_to_xyxy(targets[i]["boxes"])).cpu() C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou indices.append(linear_sum_assignment(C)) return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] ``` ##### (4) 训练流程 训练 RT-DETR-n 的主要步骤包括数据预处理、模型初始化、损失计算和优化器配置。以下是训练脚本的核心部分: ```python from ultralytics.engine.trainer import Trainer def train_rt_detr_n(data_config, model_size="n", epochs=100): model = RT_DETR(model_size=model_size) trainer = Trainer(model=model, data=data_config, epochs=epochs) trainer.train() ``` #### 3. 改进策略 对于 RT-DETR-n,可以参考 Mamba-YOLO 的改进策略,通过替换骨干网络为 MambaNet,进一步提升性能。MambaNet 是一种高效的卷积神经网络,能够在保持精度的同时降低计算成本。 --- ###
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值