不知道DETR怎么训练,来看看Query匹配GT的可视化过程

作者 | Mr.Jian  编辑 | 汽车人

原文链接:https://zhuanlan.zhihu.com/p/592381828

点击下方卡片,关注“自动驾驶之心”公众号

ADAS巨卷干货,即可获取

点击进入→自动驾驶之心【目标检测】技术交流群

DETR利用object query在feature map上全局预测bbox的坐标,并借助匈牙利算法(Hungarian algorithm)完成预测bbox与gt的匹配过程,整体结构还是和Transformer类似。

接下来我们来看看从初期-中期-后期的整个训练过程中,object query预测的bbox是如何一步步在全局坐标上靠近gt。

示意图含义,图像上的bbox分为两个部分,1. gt bbox,2. object query预测的pred bbox。

  1. gt bbox:用不同颜色来表示不同的gt bbox,在gt bbox的左上角用 Gt.i 表示当前的gt bbox是第 i 个gt,i从0开始。比如 Gt.1 表示第1个gt bbox。

  2. object query预测的pred bbox:pred bbox的颜色与gt bbox一一对应,相同颜色表示该pred bbox经过匈牙利算法匹配后负责预测这个gt bbox。在pred bbox的右上角用 Qu.i(取query的前2个字母)表示当前的pred bbox是第 i 个object query输出的。比如 Qu.34 表示当前的pred bbox是第 34 个object query输出的bbox。

先看个目标数量比较少的giraffe场景:有2只giraffe,用 Gt.0 和 Gt.1 表示。

刚初始化模型后object query输出的pred bbox,经过匈牙利算法匹配后:分别为 Qu.12 和 Qu.89 负责预测。

4ea0f208590f70f85ba2c28f5840fbf4.png
刚开始训练

经过45个step后,黄色绿色的pred bbox都朝右边的长颈鹿靠近,并且由 Qu.45 和 Qu.42 负责预测。

c606690ffa50cc33260b084986b234e8.png
step=40

经过80step后,黄色框的 Qu.31 开始向 Gt.1 靠近。

8612bcecbfc1f4f2d049c1b70d04e225.png
step=80

经过135个step后,黄色框的 Qu.86 越来越靠近 Gt.1,并且绿色框的 Qu.94 也更加接近 Gt.0 。

96276a91c4d25a0d29a1805a54c38020.png
step=135

上面的过程用GIF图看:可以发现每个step,分别负责预测 Gt.0 和 Gt.1 的object query是不断变化的。

模型训练后期到收敛状态:分别负责预测 Gt.0 和 Gt.1 的object query变化频率很小,Gt.0 基本由 Qu.98 负责,Gt.1由 Qu.2、Qu.61、Qu.67 负责,并且pred bbox更加稳定。

 训练后期,模型接近收敛

再看一个目标数量比较多的场景:多个person,和一个小目标frisbee,用 Gt.0 ~ Gt.7 表示。

模型初始化后,训练前期object query分别去找各自负责的gt bbox,对于中大目标person,object query在250个step后基本能找到,但是对于小目标frisbee,就比较难找了,会出现pred bbox在frisbee附近震荡。

 训练前期

训练中后期,pred bbox在各自的gt附近晃动,负责预测小目标frisbee的object query输出的bbox开始偏向稳定,Qu.i 频繁跳动。

训练中后期

训练到收敛阶段,负责预测各个gt bbox的object query变得稳定。

训练到收敛阶段

视频课程来了!

自动驾驶之心为大家汇集了毫米波雷达视觉融合、高精地图、BEV感知、传感器标定、自动驾驶协同感知、语义分割、自动驾驶仿真、L4感知等多个方向学习视频,欢迎大家自取(扫码进入学习)

14d9b60f6c47e531525bb7b62e96c10f.png

(扫码学习最新视频)

国内首个自动驾驶学习社区

近1000人的交流社区,和20+自动驾驶技术栈学习路线,想要了解更多自动驾驶感知(分类、检测、分割、关键点、车道线、3D目标检测、Occpuancy、多传感器融合、目标跟踪、光流估计、轨迹预测)、自动驾驶定位建图(SLAM、高精地图)、自动驾驶规划控制、领域技术方案、AI模型部署落地实战、行业动态、岗位发布,欢迎扫描下方二维码,加入自动驾驶之心知识星球,这是一个真正有干货的地方,与领域大佬交流入门、学习、工作、跳槽上的各类难题,日常分享论文+代码+视频,期待交流!

25502116e872eb7e3c25b92ecd47df78.jpeg

自动驾驶之心】全栈技术交流群

自动驾驶之心是首个自动驾驶开发者社区,聚焦目标检测、语义分割、全景分割、实例分割、关键点检测、车道线、目标跟踪、3D目标检测、BEV感知、多传感器融合、SLAM、光流估计、深度估计、轨迹预测、高精地图、NeRF、规划控制、模型部署落地、自动驾驶仿真测试、产品经理、硬件配置、AI求职交流等方向;

577bf3e36820dc8a738456ec574ec63e.jpeg

添加汽车人助理微信邀请入群

备注:学校/公司+方向+昵称

<think&gt;我们正在讨论DETR(DetectionTransformer)模型的训练结果可视化工具或方法。DETR是一种基于Transformer的目标检测模型,由FacebookAI提出。可视化训练结果通常包括损失曲线、准确率曲线、预测结果的可视化等。1.**训练过程中的指标可视化**:-使用TensorBoard或Weights&Biases(WandB)等工具来可视化训练过程中的损失和指标。这些工具可以实时监控训练过程,并绘制损失函数值、学习率、评价指标(如mAP)等随时间(或迭代次数)的变化曲线。-在训练脚本中,通常需要将日志写入特定的目录,然后通过TensorBoard读取该目录。2.**预测结果的可视化**:-在训练完成后,可以使用训练好的模型对验证集或测试集的图像进行预测,并将预测结果(边界框和类别)绘制在图像上。-DETR官方代码库(https://github.com/facebookresearch/detr)中提供了可视化预测结果的示例脚本。通常,我们可以编写一个脚本,加载训练好的模型,然后对单张或多张图像进行预测,并使用OpenCV或Matplotlib将预测的边界框和类别标签绘制在图像上。3.**注意力图的可视化**:-DETR使用Transformer的自注意力机制,因此我们可以可视化解码器中的注意力图,以观察模型在做出预测时关注图像的哪些区域。-官方代码库中也提供了可视化解码器注意力图的工具。4.**使用DETR提供的工具**:-在DETR的官方代码中,`main.py`训练脚本支持TensorBoard日志记录。此外,`engine.py`中包含了训练和验证的代码,其中在验证时计算评价指标并记录日志。-在`util`目录下,有`plot_utils.py`等工具,可用于绘制边界框和注意力图。具体步骤:1.**训练日志可视化(TensorBoard)**:假设训练时使用了`--output_dir`参数指定输出目录,训练日志会保存在该目录下。启动TensorBoard:```bashtensorboard--logdir=<path_to_output_dir&gt;```然后在浏览器中打开TensorBoard的地址(默认为localhost:6006)即可查看各种指标。2.**预测结果可视化**:可以参考以下步骤:-加载预训练模型-对一张图像进行预处理(转换为模型输入的格式)-运行模型得到预测结果-将预测结果(边界框和类别)绘制在图像上官方提供的示例代码片段(需要根据实际情况调整):```pythonimporttorchfrommodelsimportbuild_modelfromPILimportImageimportmatplotlib.pyplotaspltimporttorchvision.transformsasT#加载模型model,_,_=build_model(args)#需要根据训练时的参数构建模型checkpoint=torch.load('path/to/checkpoint.pth',map_location='cpu')model.load_state_dict(checkpoint['model'])model.eval()#图像预处理transform=T.Compose([T.Resize(800),T.ToTensor(),T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])#加载图像img=Image.open('image.jpg')img_tensor=transform(img).unsqueeze(0)#预测withtorch.no_grad():outputs=model(img_tensor)#将输出转换为边界框和标签#注意:DETR的输出需要经过后处理(如使用softmax获取概率,并应用阈值过滤)#具体后处理可参考官方代码中的postprocessing部分#可视化函数(可参考官方代码中的plot_utils)defplot_results(pil_img,prob,boxes):plt.figure(figsize=(16,10))plt.imshow(pil_img)ax=plt.gca()forp,(xmin,ymin,xmax,ymax)inzip(prob,boxes):ax.add_patch(plt.Rectangle((xmin,ymin),xmax-xmin,ymax-ymin,fill=False,color='red',linewidth=3))cl=p.argmax()text=f'{cl}:{p[cl]:0.2f}'ax.text(xmin,ymin,text,fontsize=15,bbox=dict(facecolor='yellow',alpha=0.5))plt.axis('off')plt.show()#假设已经得到prob和boxes,调用plot_results#plot_results(img,probabilities,boxes)```3.**注意力图可视化**:官方提供了`visualize_attention.py`脚本(在detr的util目录下),可以用于可视化解码器的注意力图。使用方式大致如下:```bashpythonvisualize_attention.py--image_path<path_to_image&gt;--resume<path_to_checkpoint&gt;--threshold0.5```4.**其他可视化**:还可以使用第三方工具,比如Netron(https://github.com/lutzroeder/netron)来可视化模型结构。注意:以上代码和脚本可能需要根据DETR的具体版本和训练配置进行调整。引用说明:DETR官方代码库提供了丰富的工具和示例,我们可以直接使用或参考这些工具进行可视化[^1]。相关问题:1.DETR模型训练过程中如何监控损失函数的变化?2.如何可视化DETR模型在单张图像上的预测结果?3.DETR模型的注意力图可视化有什么作用?4.除了官方工具,还有哪些第三方工具可以用于DETR模型的可视化?</think&gt;### DETR模型训练结果可视化方法 DETR(DEtection TRansformer)的训练结果可视化主要包含以下几个层面: #### 1. **训练指标可视化** 使用日志工具记录并可视化训练过程中的关键指标: - **损失曲线**:分类损失、边界框损失、GIoU损失等 - **评估指标**:mAP(平均精度)、召回率等 - **学习率变化曲线** ```python # 使用TensorBoard记录日志 from torch.utils.tensorboard import SummaryWriter writer = SummaryWriter('logs/') for epoch in range(epochs): # 训练过程... writer.add_scalar('Loss/train', train_loss, epoch) writer.add_scalar('mAP/val', val_map, epoch) ``` 启动可视化: ```bash tensorboard --logdir=logs ``` #### 2. **预测结果可视化** 对验证集图像进行预测并可视化检测结果: ```python import matplotlib.pyplot as plt from detr.util.plot_utils import plot_results def visualize_predictions(model, image_path): # 预处理图像 transform = T.Compose([ T.Resize(800), T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) img = Image.open(image_path).convert("RGB") img_tensor = transform(img).unsqueeze(0) # 模型预测 with torch.no_grad(): outputs = model(img_tensor) # 可视化预测结果 plt.figure(figsize=(16,10)) plt.imshow(img) ax = plt.gca() plot_results(ax, img, outputs['pred_boxes'][0], outputs['pred_logits'][0]) plt.axis('off') plt.savefig('predictions.jpg') ``` #### 3. **注意力图可视化** 可视化Transformer解码器的注意力权重: ```python # 使用官方提供的注意力可视化工具 from detr.models.matcher import HungarianMatcher from detr.util.misc import nested_tensor_from_tensor_list # 获取注意力权重 attentions = model(images, return_attn=True)['attentions'] # 可视化特定解码器层的注意力 plt.imshow(attentions[5][0, 3].cpu()) # 第6层第4个头 plt.colorbar() ``` #### 4. **模型结构可视化** 使用工具可视化模型架构: - Netron(https://github.com/lutzroeder/netron) - PyTorchViz(`torchviz`库) ```python from torchviz import make_dot # 生成计算图 x = torch.randn(1, 3, 800, 800) outputs = model(x) make_dot(outputs, params=dict(model.named_parameters())).render("detr_arch") ``` #### 5. **部署可视化工具** - **Detectron2**:集成DETR可视化接口 - **FiftyOne**:交互式数据集和结果可视化 ```python import fiftyone as fo # 创建可视化会话 dataset = fo.Dataset("validation") session = fo.launch_app(dataset) session.view = fo.load_validation_predictions(model) ``` #### 关键注意事项 1. 可视化前需对模型输出进行后处理(使用`postprocess`方法) 2. 边界框需从$[cx, cy, w, h]$格式转换为$[x_min, y_min, x_max, y_max]$ 3. 注意力图可视化需注意归一化处理 4. 目标查询(object queries)的可视化可揭示解码器工作机理[^1] [^1]: 参考DETR官方仓库可视化示例:https://github.com/facebookresearch/detr/blob/main/notebooks/DETR_predictions_visualization.ipynb
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值