ShuffleAttention Module
Paper link: https://arxiv.org/abs/2102.00240
Adding the ShuffleAttention module to MMYOLO
- Copy the open-source ShuffleAttention.py file into the mmyolo/models/plugins directory.
- Import the package MMYOLO uses for registering modules: from mmyolo.registry import MODELS
- Make sure the input-channel argument of class ShuffleAttention is named in_channels (MMYOLO passes the input dimension in automatically, so the parameter name must match).
- Register class ShuffleAttention(nn.Module) with the @MODELS.register_module() decorator.
- Modify the mmyolo/models/plugins/__init__.py file so that it imports and exports the new class (see the sketch after this list).
- Run in the terminal:
  python setup.py install
  (a quick check that the registration took effect is also sketched after this list)
- Modify the corresponding config file and set the "type" field of the plugins entry to "ShuffleAttention" (an example snippet follows this list); for reference, see the CSDN blog post "[YOLO Improvement] Inserting the CBAM Attention Module into the Backbone (Based on MMYOLO)".
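A minimal sketch of what mmyolo/models/plugins/__init__.py could look like after the change, assuming the file previously exported only the built-in CBAM plugin and that the copied file is named ShuffleAttention.py:

from .cbam import CBAM
from .ShuffleAttention import ShuffleAttention

__all__ = ['CBAM', 'ShuffleAttention']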
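After running python setup.py install, one way to confirm that the registration took effect is to build the module from the registry, much as MMYOLO does when parsing a config. This is only a sanity-check sketch: the input shape is arbitrary, and it assumes that importing mmyolo.models is enough to trigger plugin registration in your MMYOLO version.

import torch
import mmyolo.models  # importing the models package registers the plugins (assumption)
from mmyolo.registry import MODELS

# Build ShuffleAttention by its registered name, as the config system would.
sa = MODELS.build(dict(type='ShuffleAttention', in_channels=256, G=8))
out = sa(torch.randn(2, 256, 32, 32))
print(out.shape)  # an attention plugin keeps the feature-map shape: (2, 256, 32, 32)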
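For the config change, the sketch below shows how a backbone plugins entry might look in a YOLOv5-style MMYOLO config. The base config file name and the stages selection are illustrative assumptions, not taken from the original post; in_channels is deliberately omitted because MMYOLO fills it in automatically, as noted in the steps above.

_base_ = './yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py'  # illustrative base config

model = dict(
    backbone=dict(
        plugins=[
            dict(
                # 'type' must match the registered class name
                cfg=dict(type='ShuffleAttention', G=8),
                # apply the plugin only after the last two backbone stages (illustrative)
                stages=(False, False, True, True))
        ]))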
The modified ShuffleAttention.py:
import torch
from torch import nn
from torch.nn import init
from torch.nn.parameter import Parameter

from mmyolo.registry import MODELS


@MODELS.register_module()
class ShuffleAttention(nn.Module):

    def __init__(self, in_channels=512, reduction=16, G=8):
        super().__init__()
        self.G = G
        self.channel = in_channels
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.gn = nn.GroupNorm(in_channels // (2 * G), in_channels // (2 * G))
        self.cweight = Parameter(torch.zeros(1, in_channels // (2 * G), 1, 1))
        self.cbias = Parameter(torch.ones(1, in_channels // (2 * G), 1, 1))
        self.sweight = Parameter(torch.zeros(1, in_channels // (2 * G), 1, 1))
        self.sbias = Parameter(torch.ones(1, in_channels // (2 * G), 1, 1))
        self.sigmoid = nn.Sigmoid()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    @staticmethod
    def channel_shuffle(x, groups):
        b, c, h, w = x.shape
        x = x.reshape(b, groups, -1, h, w)
        x = x.permute(0, 2, 1, 3, 4)
        # flatten
        x = x.reshape(b, -1, h, w)
        return x

    def forward(self, x):
        b, c, h, w = x.size()
        # group into subfeatures
        x = x.view(b * self.G, -1, h, w)  # bs*G, c//G, h, w
        # channel split
        x_0, x_1 = x.chunk(2, dim=1)  # bs*G, c//(2*G), h, w
        # channel attention
        x_channel = self.avg_pool(x_0)  # bs*G, c//(2*G), 1, 1