pytorch代码实现注意力机制之SK Attention

本文介绍了一种在卷积神经网络中引入的SK注意力机制,通过动态调整神经元的感受野大小,以适应不同尺度的输入信息。实验结果表明,SKNet在保持较低模型复杂性的同时,优于现有最先进的架构在ImageNet和CIFAR数据集上表现优秀。

SK注意力机制

简介:在标准卷积神经网络 (CNN) 中,每一层人工神经元的感受野被设计为共享相同的大小。在神经科学界众所周知,视觉皮层神经元的感受野大小受到刺激的调节,这在构建 CNN 时很少被考虑。我们在CNN中提出了一种动态选择机制,该机制允许每个神经元根据输入信息的多个尺度自适应地调整其感受野大小。设计了一个称为选择性内核 (SK) 单元的构建块,其中使用由这些分支中的信息引导的 softmax 注意力融合具有不同内核大小的多个分支。对这些分支的不同关注会产生不同大小的融合层神经元的有效感受野。多个 SK 单元堆叠到称为选择性内核网络 (SKNets) 的深度网络中。在ImageNet和CIFAR基准测试中,我们通过经验表明,SKNet以较低的模型复杂性优于现有的最先进的架构。详细分析表明,SKNet中的神经元可以捕获不同尺度的目标对象,这验证了神经元根据输入自适应调整其感受野大小的能力。

结构图

代码实现

import numpy as np
import torch
from torch import nn
from torch.nn import init
from collections import OrderedDict


class SKAttention(nn.Module):

    def __init__(self, channel=512, kernels=[1, 3, 5, 7], reduction=16, group=1, L=32):
        super().__init__()
        self.d = max(L, channel // reduction)
        self.convs = nn.ModuleList([])
        for k in kernels:
            self.convs.append(
                nn.Sequential(OrderedDict([
                    ('conv', nn.Conv2d(channel, channel, kernel_size=k, padding=k // 2, groups=group)),
                    ('bn', nn.BatchNorm2d(channel)),
                    ('relu', nn.ReLU())
                ]))
            )
        self.fc = nn.Linear(channel, self.d)
        self.fcs = nn.ModuleList([])
        for i in range(len(kernels)):
            self.fcs.append(nn.Linear(self.d, channel))
        self.softmax = nn.Softmax(dim=0)

    def forward(self, x):
        bs, c, _, _ = x.size()
        conv_outs = []
        ### split
        for conv in self.convs:
            conv_outs.append(conv(x))
        feats = torch.stack(conv_outs, 0)  # k,bs,channel,h,w

        ### fuse
        U = sum(conv_outs)  # bs,c,h,w

        ### reduction channel
        S = U.mean(-1).mean(-1)  # bs,c
        Z = self.fc(S)  # bs,d

        ### calculate attention weight
        weights = []
        for fc in self.fcs:
            weight = fc(Z)
            weights.append(weight.view(bs, c, 1, 1))  # bs,channel
        attention_weughts = torch.stack(weights, 0)  # k,bs,channel,1,1
        attention_weughts = self.softmax(attention_weughts)  # k,bs,channel,1,1

        ### fuse
        V = (attention_weughts * feats).sum(0)
        return V


if __name__ == '__main__':
    input = torch.randn(50, 512, 7, 7)
    se = SKAttention(channel=512, reduction=8)
    output = se(input)
    print(output.shape)
SK Attention是一种注意力机制,它在图像处理中被广泛应用。SK Attention借鉴了SENet的思想,通过动态计算每个卷积核得到通道的权重,然后动态地将各个卷积核的结果进行融合。这种注意力机制可以让网络更加关注待检测目标,从而提高检测效果。 以下是SK Attention的示例代码: ```python import torch import torch.nn as nn import torch.nn.functional as F class SKConv(nn.Module): def __init__(self, in_channels, out_channels, stride=1, M=2, r=16, L=32): super(SKConv, self).__init__() d = max(in_channels // r, L) self.M = M self.out_channels = out_channels self.conv = nn.ModuleList() for i in range(M): self.conv.append(nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1 + i, dilation=1 + i, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True) )) self.global_pool = nn.AdaptiveAvgPool2d(1) self.fc = nn.Sequential( nn.Linear(out_channels, d, bias=False), nn.ReLU(inplace=True), nn.Linear(d, out_channels * M, bias=False) ) self.softmax = nn.Softmax(dim=1) def forward(self, x): batch_size = x.size(0) output = [] for i, conv in enumerate(self.conv): output.append(conv(x)) U = sum(output) s = self.global_pool(U).view(batch_size, -1) z = self.fc(s).view(batch_size, self.M, self.out_channels) a_b = self.softmax(z) a_b = list(a_b.chunk(self.M, dim=1)) V = sum([a * b for a, b in zip(output, a_b)]) return V # 使用SK Attention sk_conv = SKConv(in_channels=64, out_channels=128) input = torch.randn(1, 64, 32, 32) output = sk_conv(input) print(output.shape) # 输出:torch.Size([1, 128, 32, 32]) ```
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

「已注销」

你的激励是我肝下去的动力~

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值