YOLO系列:改进YOLOv8——以添加Gam注意力模块为例

本文介绍了如何在YOLOv8模型中添加Gam注意力机制,包括源码实现、模块注册、调用以及如何在训练过程中配置和应用。作者提醒注意注意力对不同数据集可能带来的效果差异。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

一、Gam注意力源码

import torch.nn as nn
import torch
 
class GAM_Attention(nn.Module):
    def __init__(self, in_channels,c2, rate=4):
        super(GAM_Attention, self).__init__()
 
        self.channel_attention = nn.Sequential(
            nn.Linear(in_channels, int(in_channels / rate)),
            nn.ReLU(inplace=True),
            nn.Linear(int(in_channels / rate), in_channels)
        )
 
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(in_channels, int(in_channels / rate), kernel_size=7, padding=3),
            nn.BatchNorm2d(int(in_channels / rate)),
            nn.ReLU(inplace=True),
            nn.Conv2d(int(in_channels / rate), in_channels, kernel_size=7, padding=3),
            nn.BatchNorm2d(in_channels)
        )
 
    def forward(self, x):
        b, c, h, w = x.shape
        x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
        x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
        x_channel_att = x_att_permute.permute(0, 3, 1, 2).sigmoid()
        x = x * x_channel_att
        x_spatial_att = self.spatial_attention(x).sigmoid()
        out = x * x_spatial_att
 
        return out
 
if __name__ == '__main__':
    x = torch.randn(1, 64, 20, 20)
    b, c, h, w = x.shape
    net = GAM_Attention(in_channels=c)
    y = net(x)
    print(y.size())

二、添加方法

此方法仅适用于新版YOLOv8,旧版YOLOv8添加方法略有不同

1、添加注意力源码

在ultralytics/nn/modules/conv.py文件内添加注意力源码

 2、注册并引用注意力

在ultralytics/nn/modules/__init__.py文件内,按下图标识的地方添加注意力名

第一处:在from .conv import()处最后,添加注意力名称

第二处:在__all__={}处最后,添加注意力名称

 3、调用注意力

在ultralytics/nn/tasks.py文件内,键盘点击CTRL+shift+F打开查找界面,搜索

def parse_model(d, ch, verbose=True):

在该函数下方有一堆的elif m in XXX,在某一个elif下方添加如下代码:

        elif m in {GAM_Attention}:
            c1, c2 = ch[f], args[0]
            if c2 != nc:  # if not output
                c2 = make_divisible(min(c2, max_channels) * width, 8)
            args = [c1, c2, *args[1:]]

4、完成配置

在ultralytics/cfg/models/v8文件下,复制yolov8.yaml,并改成自己的名字,复制对应注意力的代码,这里我以Gam注意力为例(不同注意力的配置代码不同,请读者自行修改)

图中nc代表着你自己数据集标签的数量

5、进行训练

在YOLOv8源文件夹下,新建train.py,

from ultralytics import YOLO
if __name__ == '__main__':
    # 加载模型
    model = YOLO("yolov8-NAMAttention.yaml")  # 从头开始构建新模型
    #model = YOLO("yolov8x.pt")  # 加载预训练模型(推荐用于训练)

    # Use the model
    results = model.train(data="data/detect_plane.yaml", epochs=500, batch=8, workers=1, close_mosaic=0, name='cfg')  # 训练模型
    # results = model.val()  # 在验证集上评估模型性能
    # results = model("https://ultralytics.com/images/bus.jpg")  # 预测图像
    # success = model.export(format="onnx")  # 将模型导出为 ONNX 格式

其中model代表着你刚刚新建立的yaml文件名,也就是模型的名称,results代表着你数据集的配置文件,我的配置文件是上一篇博客讲的计挑赛的数据集配置文件。

最后,用命令行开始训练

python train.py

三、附言

注意力不一定会在所有数据集均有精度或者速度的提升,有些注意力只会在特定数据集有小幅度的数据提升,所以读者需要根据自己数据集的特点进行注意力的选择!

### 改进YOLOv8中的GAM模块 #### 修改配置文件 为了改进YOLOv8中的GAM模块,在`yolov8-NAMAttention.yaml`中定义新的网络架构,确保引入了GAM注意力机制的相关参数设置[^2]。 ```yaml # yolov8-NAMAttention.yaml example snippet backbone: ... - from: [-1] module: models.common.C3 args: [64, 64, 3, False, GAM_Attention()] ... ``` #### 添加自定义层 在`ultralytics/nn/modules/GAM_attention.py`创建并实现GAM注意力类。此操作基于已有研究工作进行了适当调整以适应YOLOv8框架的需求[^4]。 ```python import torch.nn as nn class GAM_Attention(nn.Module): def __init__(self, channels_in, rate=4): super(GAM_Attention, self).__init__() self.channel_attention = nn.Sequential( nn.Linear(channels_in, int(channels_in / rate)), nn.ReLU(inplace=True), nn.Linear(int(channels_in / rate), channels_in) ) def forward(self, x): b, c, _, _ = x.size() y = self.channel_attention(x.view(b,c,-1).mean(-1)) y = y.sigmoid().view(b,c,1,1) return x * y.expand_as(x) ``` #### 更新导入路径 确认已在`ultralytics/nn/__init__.py`内更新`__all__`列表,加入新组件名称以便后续调用: ```python from .modules.GAM_attention import GAM_Attention __all__ += ["GAM_Attention"] ``` #### 调整训练脚本 最后一步是在主程序入口处加载含有GAM注意力的新模型配置,并指定数据集及其他超参来启动训练过程。 ```python if __name__ == '__main__': # 加载带有GAM注意力YOLOv8模型 model = YOLO("yolov8-NAMAttention.yaml") # 开始训练 results = model.train(data="path/to/data.yaml", epochs=500, batch=8, ...) ``` 通过上述改动,能够有效地将GAM注意力集成到YOLOv8当中,从而提升其对于复杂场景下的目标识别能力[^1]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值