import numpy as np
import torch
from torch import nn
from torch.nn import init
# SE-attention
# 方法出处 2018 CVPR 《Squeeze-and-Excitation Networks》
# 该方法用于捕获特征图之间的关系
class SEAttention(nn.Module):
# 模型层的初始化
def __init__(self, channel=512, reduction=16):
# 所有继承于nn.Module的模型都要写这句话
super(SEAttention, self).__init__()
# 这个AdaptiveAvgPool2d会将输入特征图的宽和高
# 自动的池化到我们在AdaptiveAvgPool2d参数中指定的大小
# 比如
# m = nn.AdaptiveAvgPool2d((5, 7))
# input = torch.randn(1, 64, 8, 9)
# output = m(input)
# print(output.size())
# 最后会输出[1,64,5,7]
# 如果指定的最后宽和高相等可以只写一个
# 比如
# nn.AdaptiveAvgPool2d((1,1))==nn.AdaptiveAvgPool2d(1)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
# 定义线性变换层
# 计算特征图不同通道之间的相关性
self.fc = nn.Sequential(
# 输入的是通道的维度
# 输出的是通道维度缩减之后的维度
# 其中缩减系数reduction由我们自己指定
nn.Linear(channel, channel // reduction, bias=False),
# 激活
# 这个inplace=True表示在前一层输出的结果之上直接进行计算
# 而不再重新开辟内存空间
nn.ReLU(inplace=True),
nn.Linear(channel // reduction, channel, bias=False),
nn.Sigmoid()
)
self.init_weights()
# 每层权重的初始化
# 权重初始化可以自己定义
# 但其实PyTorch会给我们搭建的模型自动进行初始化
def init_weights(self):
# 遍历当前模型的每一层
for m in self.modules():
# 如果是卷积层
if isinstance(m, nn.Conv2d):
# kaiming初始化
init.kaiming_normal_(m.weight, mode='fan_out')
# 偏置为0
if m.bias is not None:
init.constant_(m.bias, 0)
# 如果是正则化层
elif isinstance(m, nn.BatchNorm2d):
# 权重为1
init.constant_(m.weight, 1)
init.constant_(m.bias, 0)
# 如果是线性层
elif isinstance(m, nn.Linear):
# 以标准差为0.001的正态分布随机初始化
init.normal_(m.weight, std=0.001)
if m.bias is not None:
init.constant_(m.bias, 0)
# 前向传递建立计算图
def forward(self, x):
# 获取输入特征图的批大小,通道数
b, c, _, _ = x.size()
# 可适应池化
# 这个操作之后
# 相当于每一个h*w的特征图变成了一个数字
# 先将h*w的特征图按列相加
# 然后再按行相加
y = self.avg_pool(x).view(b, c)
# 计算不同通道之间的相关权重
# 将权重转换为[b,c,1,1]维度
y = self.fc(y).view(b, c, 1, 1)
# 给每一个通道的特征加权
# expand_as将y的维度变为和x相同的维度
# 相当于Numpy的广播机制
# 因为PyTorch底层依赖Numpy
# PyTorch自带广播机制
# 其实不写这个函数也行
return x * y.expand_as(x)
if __name__ == '__main__':
# 可以将input看作是某一个卷积层输出的特征图
# 维度是[50,512,7,7]
# 代表批大小是50
# 通道数512
# 宽,高7*7
input = torch.randn(50, 512, 7, 7)
# 计算各个通道特征之间的相关性
se = SEAttention(channel=512, reduction=8)
output = se(input)
print(output.shape)
2018 CVPR 《Squeeze-and-Excitation Networks》 PyTorch实现
最新推荐文章于 2024-04-19 04:15:00 发布