import torch
import torch.nn as nn
class simam_module(torch.nn.Module):
    """SimAM-style attention: reweights each activation by a sigmoid of its "energy".

    The module creates no learnable parameters. For every (sample, channel) pair
    it measures how far each spatial position deviates from that channel's spatial
    mean. It then scales the input by ``sigmoid(d / (4 * (var + e_lambda)) + 0.5)``, where:

    * ``d`` is the squared deviation at that position;
    * ``var`` is the sum of squared deviations divided by ``h * w - 1``.

    Args:
        channels: Accepted but not used by this module. Presumably it exists so the
            constructor matches other attention blocks that take a channel count;
            verify against callers before removing.
        e_lambda: Regularizer added to the variance estimate. It keeps the
            denominator away from zero for channels with (near-)constant activations.
    """

    def __init__(self, channels=None, e_lambda=1e-4):
        super().__init__()
        # NOTE: attribute name keeps its original spelling ("activaton") so any
        # external code referencing it keeps working.
        self.activaton = nn.Sigmoid()
        self.e_lambda = e_lambda

    def __repr__(self):
        return f"{type(self).__name__}(lambda={self.e_lambda:f})"

    @staticmethod
    def get_module_name():
        """Return the short registry name of this module."""
        return "simam"

    def forward(self, x):
        """Apply the attention gate to ``x``.

        Args:
            x: 4-D tensor unpacked as ``(b, c, h, w)``; the statistics are taken
                over the last two dimensions.

        Returns:
            Tensor of the same shape as ``x``: ``x * sigmoid(energy)``.
        """
        _, _, h, w = x.size()
        # Divisor of the variance estimate (h*w - 1). It is clamped to 1 because a
        # 1x1 map would otherwise give 0/0 = NaN. With the clamp, such a map has
        # zero deviation and gets y = 0.5, matching the formula's limit.
        n = max(w * h - 1, 1)
        # Squared deviation of every position from its channel's spatial mean:
        # (b, c, h, w) - (b, c, 1, 1) -> (b, c, h, w)
        x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)
        # (b, c, h, w) / (4 * ((b, c, 1, 1) / n + lambda)) + 0.5
        variance = x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n
        y = x_minus_mu_square / (4 * (variance + self.e_lambda)) + 0.5
        return x * self.activaton(y)
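# NOTE(review): the lines below look like stray residue that was pasted in with
# this code (date / count pairs), not Python. They are commented out because
# tokens such as `08-15` and `2万+` are not valid Python syntax.
# 08-15
# 1648
#
# 03-31
# 1163
#
# 07-25
# 5604
#
# 08-02
# 2734
#
# 07-25
# 2万+
#
# 07-22
# 3万+
#
# 03-20
# 3569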
