The SKAttention module
Paper link: https://arxiv.org/pdf/1903.06586.pdf
Adding the SKAttention module to MMYOLO

- Copy the open-source SK.py file into the mmyolo/models/plugins directory.
- Import the package MMYOLO uses for module registration: from mmyolo.registry import MODELS
- Make sure the input-dimension parameter of class SKAttention is named in_channels (MMYOLO passes the input dimension in automatically, so the parameter name must match).
- Register class SKAttention(nn.Module) with the @MODELS.register_module() decorator.
- Modify the mmyolo/models/plugins/__init__.py file (a minimal sketch follows this list).
- Run in the terminal:
  python setup.py install
- Modify the corresponding config file and set the plugins parameter "type" to "SKAttention" (a sketch follows this list). For a worked example, see 【YOLO改进】主干插入注意力机制模块CBAM(基于MMYOLO)-优快云博客.
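The __init__.py change only needs to import and export the new class so it is registered when the package loads. A minimal sketch (the existing exports in your checkout, such as CBAM, may differ):

# mmyolo/models/plugins/__init__.py
from .cbam import CBAM  # existing export; may differ in your version
from .SK import SKAttention

__all__ = ['CBAM', 'SKAttention']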
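For the config change, MMYOLO attaches plugins to backbone stages through the backbone's plugins field. A minimal sketch assuming a YOLOv5-s base config and insertion after the last stage only (the base config name and the stages choice are assumptions; adjust to your setup):

_base_ = './yolov5_s-v61_syncbn_8xb16-300e_coco.py'  # assumed base config

model = dict(
    backbone=dict(
        plugins=[
            dict(
                cfg=dict(type='SKAttention'),  # must match the registered class name
                stages=(False, False, False, True))  # assumed: last stage only
        ]))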
The modified SK.py:
from collections import OrderedDict
import torch
from torch import nn
from mmyolo.registry import MODELS
@MODELS.register_module()
class SKAttention(nn.Module):

    def __init__(self, in_channels=512, kernels=[1, 3, 5, 7], reduction=16, group=1, L=32):
        super().__init__()
        self.d = max(L, in_channels // reduction)
        # one conv branch per kernel size
        self.convs = nn.ModuleList([])
        for k in kernels:
            self.convs.append(
                nn.Sequential(OrderedDict([
                    ('conv', nn.Conv2d(in_channels, in_channels, kernel_size=k, padding=k // 2, groups=group)),
                    ('bn', nn.BatchNorm2d(in_channels)),
                    ('relu', nn.ReLU())
                ])))
        self.fc = nn.Linear(in_channels, self.d)
        # one linear layer per branch to produce its channel weights
        self.fcs = nn.ModuleList([])
        for i in range(len(kernels)):
            self.fcs.append(nn.Linear(self.d, in_channels))
        self.softmax = nn.Softmax(dim=0)  # softmax across the k branches

    def forward(self, x):
        bs, c, _, _ = x.size()
        conv_outs = []
        ### split: run each conv branch on the input
        for conv in self.convs:
            conv_outs.append(conv(x))
        feats = torch.stack(conv_outs, 0)  # k, bs, c, h, w
        ### fuse: element-wise sum over branches
        U = sum(conv_outs)  # bs, c, h, w
        ### reduce: global average pooling, then channel reduction
        S = U.mean(-1).mean(-1)  # bs, c
        Z = self.fc(S)  # bs, d
        ### select: compute per-branch channel attention weights
        weights = []
        for fc in self.fcs:
            weight = fc(Z)
            weights.append(weight.view(bs, c, 1, 1))  # bs, c, 1, 1
        attention_weights = torch.stack(weights, 0)  # k, bs, c, 1, 1
        attention_weights = self.softmax(attention_weights)  # k, bs, c, 1, 1
        ### fuse: weighted sum of the branch features
        V = (attention_weights * feats).sum(0)
        return V
if __name__ == '__main__':
    input = torch.randn(50, 512, 7, 7)
    se = SKAttention(in_channels=512)
    output = se(input)
    print(output.shape)  # torch.Size([50, 512, 7, 7])
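After reinstalling, a quick way to verify the registration is to build the module from the registry by its string name, the same way MMYOLO's config system does. A sanity-check sketch (assumes the install step succeeded and that importing mmyolo.models triggers plugin registration):

import torch
import mmyolo.models  # noqa: F401 -- importing should trigger module registration
from mmyolo.registry import MODELS

sk = MODELS.build(dict(type='SKAttention', in_channels=512))
x = torch.randn(2, 512, 7, 7)
print(sk(x).shape)  # expected: torch.Size([2, 512, 7, 7])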