GatherExcite(GE)模块
论文链接:https://arxiv.org/abs/1810.12348(Gather-Excite: Exploiting Feature Context in Convolutional Neural Networks, NeurIPS 2018)
将GatherExcite(GE)模块添加到MMYOLO中
-
将开源代码GE.py文件复制到mmyolo/models/plugins目录下
-
导入MMYOLO用于注册模块的包: from mmyolo.registry import MODELS
-
确保 class GatherExcite 中表示输入通道数的参数名为 in_channels(timm 原始实现中该参数名为 channels)。MMYOLO 在构建插件时会自动以关键字参数 in_channels 传入输入通道数,因此参数名必须保持一致。
-
使用 @MODELS.register_module() 装饰器注册 class GatherExcite(nn.Module),使其能够在配置文件中通过 type='GatherExcite' 被构建:
-
修改 mmyolo/models/plugins/__init__.py 文件:添加 from .GE import GatherExcite,并将 'GatherExcite' 加入 __all__ 列表。
-
在 MMYOLO 项目根目录下打开终端,运行以下命令重新安装,使修改生效(若此前以 pip install -v -e . 的可编辑模式安装,则可跳过此步):
python setup.py install
-
修改对应的配置文件,将 plugins 中的参数 type 设置为 'GatherExcite',具体写法可参考《【YOLO改进】主干插入注意力机制模块CBAM(基于MMYOLO)》(CSDN 博客)。
修改后的 GE.py 如下:
import math
import torch
import torch.nn.functional as F
from timm.models.layers.create_act import create_act_layer, get_act_layer
from timm.models.layers.create_conv2d import create_conv2d
from timm.models.layers.helpers import make_divisible
from timm.models.layers.mlp import ConvMlp
from torch import nn as nn
from mmyolo.registry import MODELS
@MODELS.register_module()
class GatherExcite(nn.Module):
def __init__(
self, in_channels, feat_size=None, extra_params=False, extent=0, use_mlp=True,
rd_ratio=1./16, rd_channels=None, rd_divisor=1, add_maxpool=False,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, gate_layer='sigmoid'):
super(GatherExcite, self).__init__()
self.add_maxpool = add_maxpool
act_layer = get_act_layer(act_layer)
self.extent = extent
if extra_params:
self.gather = nn.Sequential()
if extent == 0:
assert feat_size is not None, 'spatial feature size must be specified for global extent w/ params'
self.gather.add_module(
'conv1', create_conv2d(in_channels, in_channels, kernel_size=feat_size, stride=1, depthwise=True))
if norm_layer:
self.gather.add_module(f'norm1', nn.BatchNorm2d(in_channels))
else:
assert extent % 2 == 0
num_conv = int(math.log2(extent))
for i in range(num_conv):
self.gather.add_module(
f'conv{i + 1}',
create_conv2d(in_channels, in_channels, kernel_size=3, stride=2, depthwise=True))
if norm_layer:
self.gather.add_module(f'norm{i + 1}', nn.BatchNorm2d(in_channels))
if i != num_conv - 1:
self.gather.add_module(f'act{i + 1}', act_layer(inplace=True))
else:
self.gather = None
if self.extent == 0:
self.gk