详解多模态(红外-可见光图像)目标检测模型SuperYOLO源码,真正搞清代码逻辑!

1. 文章主要内容

       本文主要是详细分析SuperYOLO多模态源代码,包括如何启动,以及详细代码部分如何改进,从而让单模态的检测支持为多模态的检测!。基于YOLO的单模态检测赛道已经非常卷,很难出好的论文,这个时候入门多模态检测是非常有必要的!所以,本篇代码分析论文则是入门基于YOLO的多模态目标检测的基础之一。

2. 相关说明

       本篇博文代码来源于github地址:SuperYOLO源代码)
       本人推荐安装运行自己数据集的博客为:SuperYOLO安装以及训练自定义数据集,这篇博客我粗略看了一些,总结了一些运行SuperYOLO项目常见的错误,非常推荐!
       SuperYOLO的模型是基于YOLOv5框架,本文主要是分析源码如何让支持单模态检测的YOLO框架支持多模态检测。能够深刻理解源码是入门多模态的必经之路,不然之后有点子也写不出来。如果我们理解了源码分析,完全可以将多模态的逻辑迁移到YOLOv8、YOLOv11等优秀的模型。
       SuperYOLO的模型使用的数据集为VEDAI,为遥感领域红外-可见光的多模态目标检测数据集。需要注意的是:SuperYOLO原论文明确说是进行像素级别的融合,也就是早期融合,如果不知道多模态图像融合类别的同学,自己先去了解一下。
       运行的命令我使用的是这个:python train.py --cfg models/SRyolo_MF.yaml --super --train_img_size 1024 --hr_input --data data/SRvedai.yaml --ch 64 --input_mode RGB+IR+MF
       注意到:早期融合一般比较简单,因为在送进主干网络之前,两种模态一般就已经进行了融合,也就是说模型的yaml文件不会出现RGB可见光和IR红外两种主干。这也说明早期融合一般比中期融合(也叫做特征级融合)简单,如果对特征级融合感兴趣的话,博主这里写了一篇关于基于YOLOv8的多模态特征级别融合的源代码分析:一文详解YOLOv8多模态目标检测(可见光+红外图像,基于Ultralytics官方代码实现),轻松入门多模态检测领域!

3. 基于SuperYOLO的多模态目标检测

       这一块分为两个部分,第一块是启动运行部分,启动运行部分我不讲,上述我推荐了一个运行启动的博客,如果有问题可以评论区多多交流。第二块是多模态代码的分析)第二块是重点,因为后续想改进代码必须搞懂如何进行模型改进和前向传播的改进等。

3.1 详解代码流程(重点)

       这一块的内容主要是从train.py文件,一步步去分析如何构造多模态目标检测的,这里先给出一张函数的流程图,下面的内容就是根据这张图来说明,注意我不会讲所有的代码,只会讲牵扯多模态相关需要修改的代码部分。
在这里插入图片描述

3.1.1 train.py文件(入口)

       从源代码的train.py着手,可以看到这一行代码:train(hyp, opt, device, tb_writer),于是我们进去这个train函数里面。然后可以看到这一行代码: model = Model(opt.cfg, input_mode = opt.input_mode ,ch_steam=opt.ch_steam,ch=opt.ch, nc=nc, anchors=hyp.get('anchors'),config=None,sr=opt.super,factor=down_factor).to(device),这个时候我们要进入Model这个类中,这个类就是通过yaml构造model结构的。需要注意一下,这行代码的ch=opt.ch,这个ch代表的是yaml第一行,也就是第一行module的输出通道数。

3.1.2 SRyolo.py文件

       Model类来自SRyolo.py文件,根据yaml解析模型的逻辑在self.model, self.save = parse_model代码中,于是我们进去parse_model这个函数中,没有什么特别之处。后续回来会看其前向传播算法。然后回到train.py文件中,看到这两行代码:from utils.datasets import create_dataloader_sr as create_dataloaderdataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt, hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank, #world_size=opt.world_size, workers=opt.workers, image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))。于是我们来到create_dataloader_sr这个函数。

3.1.3 datasets.py文件

       create_dataloader_sr来自datasets.py文件中,可以看到如下代码:这个函数是用来根据路径去加载image和label的,非常关键,我们进去看看。首先看其__init__构造函数方法,看到这些代码:这里面的img_files即为可见光RGB图像的路径,接下来label_files 获取img_files的标签,证明我们使用的标签是RGB可见光图像的标签,然后ir_files 则为获取IR红外图像的路径。

 with open(path, "r") as file:
            self.img_files = file.readlines()
            # for i in dele:
            #     if i+'\n' in self.img_files:
            #         self.img_files.remove(i+'\n')
            for j in range(len(self.img_files)):
                self.img_files[j] = self.img_files[j].rstrip() + '_co.png'
                
 self.label_files = img2label_paths(self.img_files)  # labels
        self.ir_files = img2ir_paths(self.img_files)

       然后我们来看看__getitem__函数,这个函数的作用是遍历dataloader的时候,有类似这样的代码self.dataset[i]这样的代码,那么就会执行这个函数的代码。可以看到通过索引值分别加载了img和ir的图片,最后以return的形式返回去赋值给dataset。

img, (h0, w0), (h, w) = load_image(self, index)
            ir = load_ir(self, index) #zjq
 return torch.from_numpy(img), torch.from_numpy(ir), labels_out, self.img_files[index], shapes
 dataset = LoadImagesAndLabels_sr(path, imgsz, batch_size,
                                      augment=augment,  # augment images
                                      hyp=hyp,  # augmentation hyperparameters
                                      rect=rect,  # rectangular training
                                      cache_images=cache,
                                      single_cls=opt.single_cls,
                                      stride=int(stride),
                                      pad=pad,
                                      image_weights=image_weights,
                                      prefix=prefix)

       得到dataset后,来到如下的dataloader代码。这里注意到collate_fn这个函数,作用是对loader函数返回值进行批次打包,就是对dataset返回值进行打包,一个batch一个batch的,下面在训练的时候会再次讲到,这里先提示一下。OK,我们得到了dataloader函数。于是回到train.py文件中。

dataloader = loader(dataset,
                        batch_size=batch_size,
                        num_workers=nw,
                        sampler=sampler,
                        pin_memory=True,
                        collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
    return dataloader, dataset

3.1.4 再次回到train.py文件

       看到这两行代码:pbar = enumerate(dataloader)for i, (imgs, irs, targets, paths, _) in pbar:。对于第二个代码,我们在这里就开始遍历pbar,也就是dataloader,这里是一个批次一个批次遍历。这里的imgs, irs, targets, paths, _即为上述__getitem__的返回类型,而遍历dataloader就会触发__getitem__方法,也就会进行load_image方法。之后,我们看到这行代码:model(imgs,irs,opt.input_mode),于是发现我们将img、irs传进去,这个时候我们就得去看Model类的前向传播函数forward,看看其对输入的执行逻辑是什么。

3.1.5 再次回到SRyolo.py文件

       可以看到这个函数:def forward(self, x, ir=torch.randn(1,3,512,512), input_mode='RGB+IR', augment=False, profile=False):,我们在最开始的输入启动命令的input_mode为RGB+IR+MF,所以会执行如下两行代码:这个steam值代表首先对应ir的**第二个通道(也就是C通道维度)**取第一个通道值,也就是说ir本来通过load_ir的cv.read函数得到的是三通道,这里需要改为1通道。之后,再与可将光x组成一个长度为2的list数组。之后看到这一行代码:y,features = self.forward_once(steam,'yolo', profile) #zjq,我们来到forward_once这个方法,然后会执行 x = m(x) # run(yaml文件第一层只是将Conv改为MF),之后也就会跳到common.py中的MF方法,如下代码所示:MF的forward会对RGB和IR进行分别处理,然后进行融合,后面的yaml文件无需修改,因为经过这个融合之后就只有一个输入了。

                if input_mode == 'RGB+IR+MF':
                    steam = [x,ir[:,0:1,:,:]] #[:,0:1,:,:]
class MF(nn.Module):# stereo attention block
    def __init__(self, channels):
        super(MF, self).__init__()
        self.mask_map_r = nn.Conv2d(channels, 1, 1, 1, 0, bias=True)
        self.mask_map_i = nn.Conv2d(1, 1, 1, 1, 0, bias=True)
        self.softmax = nn.Softmax(-1)
        self.bottleneck1 = nn.Conv2d(1, 16, 3, 1, 1, bias=False)
        self.bottleneck2 = nn.Conv2d(channels, 48, 3, 1, 1, bias=False)
        self.se = SE_Block(64,16)
        # self.se_r = SE_Block(3,3)
        # self.se_i = SE_Block(1,1)


    def forward(self, x):# B * C * H * W #x_left, x_right
        x_left_ori, x_right_ori = x[0],x[1]
        b, c, h, w = x_left_ori.shape
        # x_left = self.se_r(x_left_ori)
        # x_right = self.se_i(x_right_ori)
        x_left = x_left_ori*0.5
        x_right = x_right_ori*0.5

        x_mask_left = torch.mul(self.mask_map_r(x_left).repeat(1,3,1,1),x_left)
        x_mask_right = torch.mul(self.mask_map_i(x_right),x_right)
       

        out_IR = self.bottleneck1(x_mask_right+x_right_ori)
        out_RGB = self.bottleneck2(x_mask_left+x_left_ori) #RGB
        out = self.se(torch.cat([out_RGB,out_IR],1))
        # import scipy.io as sio
        # sio.savemat('features/output.mat', mdict={'data':out.cpu().numpy()})

        return out

3. 总结

       大概就是这些代码,可能有一些细节没讲解,这是属于早期融合,也就是像素级别的融合,希望大家能有收获,如果有任何疑问,可以评论区交流!如果可以的话,希望大家多多点赞,收藏,后续会更新相关代码和论文的解读!

### 双模态目标检测数据集下载 对于双模态目标检测,特别是涉及红外可见光图像的任务,存在多个公开可用的数据集可以用于研究和开发工作。以下是几个常用的数据集及其获取方式: #### 1. KAIST Multispectral Pedestrian Dataset 该数据集由韩国科学技术院(KAIST)提供,包含了大量行人标注的RGB-thermal(可见光-热成像)配对图片。这些图像是在不同天气条件、时间和场景下采集得到的。 访问地址:[KAIST Multi-spectral Pedestrian Detection Benchmark](http://www.kitech.re.kr/) 注册并同意使用条款后即可下载完整的数据包[^1]。 #### 2. FLIR ADAS Dataset FLIR公司发布的这个自动驾驶辅助系统专用数据集中也含有丰富的RGB-T(彩色视觉与长波红外线)同步视频片段以及对应的边界框标签信息。它非常适合用来训练基于深度学习的目标识别算法。 官方链接:[FLIR Publicly Available Thermal Camera Dataset](https://github.com/flir/adas) 可以直接克隆仓库来获得所需资源。 #### 3. ODIHT (Occluded-Divided Infrared Human Tracking) ODIHT是一个专注于遮挡情况下人体追踪挑战性的数据库,其中同样提供了大量的可见光加远红外组合样本供科研人员测试自己的方法效果如何。 项目主页:[ODIHT Project Page](http://cvlab.cse.msu.edu/project-odiht.html) 按照页面指示完成申请流程之后便能取得授权以下载全部资料。 为了更好地利用上述任何一个或多于一个数据源来进行实验,在实际操作前建议先仔细阅读各自附带文档说明中的细节描述部分,确保所选方案能够满足具体应用场景的需求特点。 ```bash # 使用git命令行工具拉取FLIR ADAS数据集 $ git clone https://github.com/flir/adas.git ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

弗兰随风小欢

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值