【YOLO改进】主干插入SKAttention模块(基于MMYOLO)

SKAttention模块

论文链接:https://arxiv.org/pdf/1903.06586.pdf

将SKAttention模块添加到MMYOLO中

  1. 将开源代码SK.py文件复制到mmyolo/models/plugins目录下

  2. 导入MMYOLO用于注册模块的包: from mmyolo.registry import MODELS

  3. 确保 class SKAttention中的输入维度为in_channels(因为MMYOLO会提前传入输入维度参数,所以要保持参数名的一致)

  4. 利用@MODELS.register_module()将“class SKAttention(nn.Module)”注册:

  5. 修改mmyolo/models/plugins/__init__.py文件

  6. 在终端运行:

    python setup.py install
  7. 修改对应的配置文件,并且将plugins的参数“type”设置为“BiLevelRoutingAttention”,可参考【YOLO改进】主干插入注意力机制模块CBAM(基于MMYOLO)-优快云博客

修改后的SK.py

from collections import OrderedDict
import torch
from torch import nn
from mmyolo.registry import MODELS

@MODELS.register_module()
class SKAttention(nn.Module):

    def __init__(self, in_channels=512, kernels=[1, 3, 5, 7], reduction=16, group=1, L=32):
        super().__init__()
        self.d = max(L, in_channels // reduction)
        self.convs = nn.ModuleList([])
        for k in kernels:
            self.convs.append(
                nn.Sequential(OrderedDict([
                    ('conv', nn.Conv2d(in_channels, in_channels, kernel_size=k, padding=k // 2, groups=group)),
                    ('bn', nn.BatchNorm2d(in_channels)),
                    ('relu', nn.ReLU())
                ]))
            )
        self.fc = nn.Linear(in_channels, self.d)
        self.fcs = nn.ModuleList([])
        for i in range(len(kernels)):
            self.fcs.append(nn.Linear(self.d, in_channels))
        self.softmax = nn.Softmax(dim=0)

    def forward(self, x):
        bs, c, _, _ = x.size()
        conv_outs = []
        ### split
        for conv in self.convs:
            conv_outs.append(conv(x))
        feats = torch.stack(conv_outs, 0)  # k,bs,channel,h,w

        ### fuse
        U = sum(conv_outs)  # bs,c,h,w

        ### reduction channel
        S = U.mean(-1).mean(-1)  # bs,c
        Z = self.fc(S)  # bs,d

        ### calculate attention weight
        weights = []
        for fc in self.fcs:
            weight = fc(Z)
            weights.append(weight.view(bs, c, 1, 1))  # bs,channel
        attention_weughts = torch.stack(weights, 0)  # k,bs,channel,1,1
        attention_weughts = self.softmax(attention_weughts)  # k,bs,channel,1,1

        ### fuse
        V = (attention_weughts * feats).sum(0)
        return V


if __name__ == '__main__':
    input = torch.randn(50, 512, 7, 7)
    se = SKAttention
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值