一、SENet
- 对特征层进行全局平均池化(高度和宽度上),获得一个深度为C2的特征长条
- 通过两个全连接层,第一个全连接层输出缩减维度/r,第二个全连接层输出还原维度
- 最后通过一个Sigmoid函数,获取每个特征层的权值 0—1之间
- 对原特征层的每一层进行加权并返回
二、SENet的pytorch实现
class senet(nn.Module):
def __init__(self, channel, ratio=16):
super(senet, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1) #全局平均池化
self.fc = nn.Sequential(
nn.Linear(channel, channel // ratio, False),
nn.ReLU(),
nn.Linear(channel // ratio, channel, False),
nn.Sigmoid(), # 用于将每个元素压缩到0---1之间
)
def forward(self, x):
b, c, h, w = x.size()
# 全局平均池化 b,c,h,w----->b,c,1,1
avg = self.avg_pool(x).view([b, c])
# 将平均池化后的结果输入全连接 通过sigmoid得到每一个层的权重大小
fc = self.fc(avg).view([b, c, 1, 1])
return x * fc # 将特征层 和 每一层相乘加权
三、学到的一些其他知识
在继承nn.Sequential的类中,不需要显式地重写forward方法
因为当调用nn.Sequential的实例的forward方法时,它会按照模块的顺序依次调用每个模块的forward方法,将输入数据传递给每个模块,并返回最终的输出。
import torch
import torch.nn as nn
class MyNetwork(nn.Sequential):
def __init__(self):
super(MyNetwork, self).__init__(
nn.Linear(10, 20),
nn.ReLU(),
nn.Linear(20, 30),
nn.ReLU()
)
# 使用自定义网络类
model = MyNetwork()
input = torch.randn(1, 10)
output = model(input)