代码理解DETR和DeformableDETR

以下内容记录我在阅读DETR和Deformable DETR论文及源码过程中对网络结构、以及自注意力、可变形注意力的理解,可能存在一些错误,欢迎讨论~

代码理解DETR

重绘网络图

image-20241114154934049

理解

encoder:

  1. Q/K/V的内容信息是一样的,都是输入的图像特征,Q/K还增加了位置信息;
  2. Q/K增加的位置信息是一样的,都是利用Sin-Cos生成的固定位置信息;
  3. 各个encoder layer的输出值都是施加了自注意力后的、各位置关注了其他所有位置信息后而优化出的、更好的feature representations;(体现Transformer的全局依赖建模能力)
  4. 新的feature representations作为下一层encoder layer的输入特征,再结合Sin-Cos生成的固定位置信息pos_embed,一起作为下一层encoder layer的输入;
  5. 各个encoder layerSin-Cos固定位置信息pos_embed都是一样的,只在最初生成一次;
  6. 最后整个encoder的输出值(即memory),就是充分施加了自注意力后的、各位置充分关注了其他所有位置信息后而优化出的、更好的feature representations
  7. memory作为输入decoderK/V的内容信息,K还需要加上Sin-Cos的固定位置信息pos_embed,二者一起作为cross-attention输入

decoder:

  1. query_embed代表object query的位置信息,是可学习的;
  2. 初始化时object query被初始化为(100, 256)的全零向量;
  3. self-attention中,Q/K/V拥有同样的内容信息,K/V还增加了相同的可学习的位置信息query_embed
  4. self-attention的输出值是施加了自注意力后的、各object query关注了其他所有位置object query信息后而优化的、更好的object query(内容信息更优质的object query);
  5. cross-attention中,Q是self-attetion输出的object query再加上可学习的位置信息query_embedK/V包含memory的内容信息,K还增加了Sin-Cos的固定位置信息pos_embed
  6. cross-attention的输出值代表:根据object query信息(由Q代表)去图像特征(由memory得到的K/V代表)查找到最相关的图像特征,这个图像特征会作为下一个decoder layerobject query进行输入

代码理解Deformable DETR

重绘网络图

DeformableDETR-pure.drawio

理解

encoder:

  1. Q/K/V的内容信息是一样的,都是输入的多尺度图像特征,Q/K还增加了位置信息;
  2. Q/K增加的位置信息是一样的,即level_pos_embed,它包括利用Sin-Cos生成的固定位置信息pos_embed_per_level和可学习的区分不同尺度特征层的位置信息scale_level_embed
  3. sampling_locations = references_points + sampling_offsets;其中references_points代表各特征层的grid center point位置全部归一化后再汇总,最终得到的全部特征点的坐标(所有特征层的会一起投影到每个特征层),sampling_offsets是网络预测的采样点的偏移位置(个数为num_levels*num_samping_points),sampling_locations 为最终deformable attention关注的采样点位置;
  4. Multi Scale Deformable Attention中,Q通过Linear得到针对sampling points的注意力权重attention weights,V通过Linear得到Value(代表所有reference pointsvalue),要想加权求和,还需要从V中精选出sampling points位置的value,所以这里的输入包括3个新增红线;
  5. 各个encoder layer的输出值都是施加了可变形注意力后的、各位置关注了其他几个重要位置信息后而优化出的、更好的feature representations;(体现Deformable Attention的关注重点能力)
  6. 新的feature representations作为下一层encoder layer的输入特征,再结合位置信息level_pos_embed,一起作为下一层encoder layer的输入;
  7. 在一次前向传播过程中,各个encoder layer在输入时添加的位置信息level_pos_embed都是一样的;
  8. 最后整个encoder的输出值(即memory),就是充分施加了可变形注意力后的、各位置充分关注了其他几个重要位置信息后而优化出的、更好的feature representations
  9. memory作为输入decoder的V的内容信息,作为cross-attention的输入

decoder:

  1. query position embeding代表object query的位置信息,是可学习的;
  2. object queryquery position embeding都是可学习的,维度都为(300, 256);
  3. self-attention中,Q/K/V拥有同样的内容信息,K/V还增加了相同的可学习的位置信息query position embedding
  4. self-attention的输出值是施加了自注意力后的、各object query关注了其他所有位置object query信息后而优化的、更好的object query(内容信息更优质的object query);
  5. cross-attention(Multi Scale Deformable Attention)中,Q是self-attention输出的object query再加上可学习的位置信息query position embedding,V包含memory的内容信息;为了执行可变形注意力,还需要找到采样点sampling points,所以还需要输入reference pointsspatial_shapeslevel_start_index
  6. 注意这里的reference points是可学习的,和encoder中固定位置的reference points不同;
  7. cross-attention的输出值代表:根据object query信息(由Q代表)去图像特征(由memory得到的V代表)查找到最相关的图像特征,这个图像特征会作为下一个decoder layerobject query进行输入;
  8. 在一次前向传播过程中,每一个decoder layer层输入的reference pointslevel_start_indexspatial_shapes都是相同的;其中level_start_indexspatial_shapesencoderdecoder中都是相同的

Self-Attention和Deformable Attention

  • 正如前面黄色高亮的部分,Self-Attention关注序列中一个位置的token和其他所有位置(包括自身)token的相互依赖关系,这体现了Transformer的全局建模能力(relation modeling capability);
  • Deformable Attention关注序列中一个位置的token和其他几个重要位置(即采样点sampling points)的token的相互依赖关系,这体现了Deformable Attention的稀疏空间采样(sparse spatial sampling)能力
### DETR代码实现 DETR (End-to-End Object Detection with Transformers) 是一种基于 Transformer 的端到端目标检测方法。其官方实现由 Facebook 提供,主要采用 PyTorch 编写[^1]。 以下是 DETR 官方代码库中的核心部分: #### 依赖环境设置 为了运行 DETR代码实现,首先需要安装必要的开发工具 Python 库。推荐使用 Visual Studio Code 进行开发调试。可以通过以下命令安装基础依赖项: ```bash pip install torch torchvision cython pip install -r requirements.txt ``` #### 数据集准备 DETR 支持 COCO 数据集作为训练数据源。下载并解压 COCO 数据集后,需配置路径以便模型能够访问这些文件[^3]。 #### 主要模块解析 1. **Transformer Encoder Decoder** - 在 DETR 中,Encoder 负责提取全局上下文信息,通过在每个注意力机制中加入位置编码来增强特征的相关性。 - Decoder 则用于生成对象查询(Object Queries),并通过多头注意力机制逐步细化预测结果。 2. **损失函数设计** - 使用匈牙利算法匹配预测框与真实标注框,从而最小化整体代价矩阵。 #### 训练脚本示例 下面是一个简单的训练脚本模板: ```python import argparse from detr.models import build_model from datasets.coco import CocoDetection, make_coco_transforms from engine import train_one_epoch, evaluate def main(args): dataset_train = CocoDetection( image_set='train', args=args, transforms=make_coco_transforms('train') ) model, criterion, postprocessors = build_model(args) model.to(device) optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) for epoch in range(args.start_epoch, args.epochs): train_stats = train_one_epoch( model, criterion, data_loader_train, optimizer, device, epoch ) if __name__ == '__main__': parser = argparse.ArgumentParser() # 添加参数... args = parser.parse_args() main(args) ``` 对于 Meta-DETR,这是一个针对少样本目标检测的改进版本,具体实现可参考该项目仓库[^2]。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值