EMA模块
论文链接:https://arxiv.org/abs/2305.13563v1
将注意力机制模块添加到MMYOLO中
-
将开源代码EMA.py文件复制到mmyolo/models/plugins目录下
-
导入MMYOLO用于注册模块的包: from mmyolo.registry import MODELS
-
确保 class EMA 中的输入维度为in_channels(因为MMYOLO会提前传入输入维度参数,所以要保持参数名的一致)
-
利用@MODELS.register_module()将“class EMA(nn.Module)”注册:
-
修改mmyolo/models/plugins/__init__.py文件
-
在终端运行:
python setup.py install
-
修改对应的配置文件,并且将plugins的参数“type”设置为“EMA”,可参考【YOLO改进】主干插入注意力机制模块CBAM(基于MMYOLO)-优快云博客
修改后的EMA.py
import torch
from torch import nn
from mmyolo.registry import MODELS
@MODELS.register_module()
class EMA(nn.Module):
def __init__(self, in_channels, factor=8):
super(EMA, self).__init__()
self.groups = factor
assert in_channels // self.groups > 0
self.softmax = nn.Softmax(-1)
self.agp = nn.AdaptiveAvgPool2d((1, 1))
self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
self.pool_w = nn.AdaptiveAvgPool2d((1, None))
self.gn = nn.GroupNorm(in_channels // self.groups, in_channels // self.groups)
self.conv1x1 = nn.Conv2d(in_channels // self.groups, in_channels // self.groups, kernel_size=1, stride=1, padding=0)
self.conv3x3 = nn.Conv2d(in_channels // self.groups, in_channels // self.groups, kernel_size=3, stride=1, padding=1)
def forward(self, x):
b, c, h, w = x.size()
group_x = x.reshape(b * self.groups, -1, h, w) # b*g,c//g,h,w
x_h = self.pool_h(group_x)
x_w = self.pool_w(group_x).permute(0, 1, 3, 2)
hw = self.conv1x1(torch.cat([x_h, x_w], dim=2))
x_h, x_w = torch.split(hw, [h, w], dim=2)
x1 = self.gn(group_x * x_h.sigmoid() * x_w.permute(0, 1, 3, 2).sigmoid())
x2 = self.conv3x3(group_x)
x11 = self.softmax(self.agp(x1).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
x12 = x2.reshape(b * self.groups, c // self.groups, -1) # b*g, c//g, hw
x21 = self.softmax(self.agp(x2).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
x22 = x1.reshape(b * self.groups, c // self.groups, -1) # b*g, c/