import torch
import torch.nn as nn
from functools import partial
import math
from timm.models.layers import trunc_normal_tf_
from timm.models.helpers import named_apply
__all__ = ['MSConv']
def gcd(a, b):
    while b:
        a, b = b, a % b
    return a
# Other layer types (e.g., nn.Linear) can be added here as needed.
def _init_weights(module, name, scheme=''):
    if isinstance(module, (nn.Conv2d, nn.Conv3d)):
        if scheme == 'normal':
            nn.init.normal_(module.weight, std=.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif scheme == 'trunc_normal':
            trunc_normal_tf_(module.weight, std=.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif scheme == 'xavier_normal':
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif scheme == 'kaiming_normal':
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        else:
            # EfficientNet-like fan-out initialization
            fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
            fan_out //= module.groups
            nn.init.normal_(module.weight, 0, math.sqrt(2.0 / fan_out))
            if module.bias is not None:
                nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d)):
        nn.init.constant_(module.weight, 1)
        nn.init.constant_(module.bias, 0)
    elif isinstance(module, nn.LayerNorm):
        nn.init.constant_(module.weight, 1)
        nn.init.constant_(module.bias, 0)
def act_layer(act, inplace=False, neg_slope=0.2, n_prelu=1):
    # Build an activation layer by name.
    act = act.lower()
    if act == 'relu':
        layer = nn.ReLU(inplace)
    elif act == 'relu6':
        layer = nn.ReLU6(inplace)
    elif act == 'leakyrelu':
        layer = nn.LeakyReLU(neg_slope, inplace)
    elif act == 'prelu':
        layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    elif act == 'gelu':
        layer = nn.GELU()
    elif act == 'hswish':
        layer = nn.Hardswish(inplace)
    else:
        raise NotImplementedError('activation layer [%s] is not found' % act)
    return layer
def channel_shuffle(x, groups):
    batchsize, num_channels, height, width = x.size()
    channels_per_group = num_channels // groups
    # reshape: (B, C, H, W) -> (B, groups, C // groups, H, W)
    x = x.view(batchsize, groups, channels_per_group, height, width)
    # swap the group and per-group channel dimensions, then flatten back
    x = torch.transpose(x, 1, 2).contiguous()
    x = x.view(batchsize, -1, height, width)
    return x
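# Illustrative usage sketch (the tensor sizes below are arbitrary): channel_shuffle
# keeps the tensor shape and only permutes channels across groups.
def _demo_channel_shuffle():
    x = torch.arange(2 * 6 * 1 * 1, dtype=torch.float32).view(2, 6, 1, 1)
    y = channel_shuffle(x, groups=2)
    assert y.shape == x.shape  # shuffling never changes the shape
    # With 6 channels and 2 groups, channels [0..5] come out interleaved as [0, 3, 1, 4, 2, 5].
    return y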
# Multi-scale depth-wise convolution (MSDC)
class MSDC(nn.Module):
    def __init__(self, in_channels, kernel_sizes, stride, activation='relu6', dw_parallel=True):
        super(MSDC, self).__init__()
        self.in_channels = in_channels
        self.kernel_sizes = kernel_sizes
        self.activation = activation
        self.dw_parallel = dw_parallel
        # One depth-wise branch per kernel size, each followed by BN and activation.
        self.dwconvs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(self.in_channels, self.in_channels, kernel_size, stride, kernel_size // 2,
                          groups=self.in_channels, bias=False),
                nn.BatchNorm2d(self.in_channels),
                act_layer(self.activation, inplace=True)
            )
            for kernel_size in self.kernel_sizes
        ])
        self.init_weights('normal')

    def init_weights(self, scheme=''):
        named_apply(partial(_init_weights, scheme=scheme), self)

    def forward(self, x):
        # Apply each depth-wise branch; with dw_parallel=False the branches are cascaded
        # by adding each branch output back onto the input of the next branch.
        outputs = []
        for dwconv in self.dwconvs:
            dw_out = dwconv(x)
            outputs.append(dw_out)
            if not self.dw_parallel:
                x = x + dw_out
        # Returns one output tensor per kernel size.
        return outputs
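# Illustrative usage sketch (channel count, spatial size, and kernel sizes are arbitrary):
# MSDC returns one tensor per kernel size, each with the same channel count as the input.
def _demo_msdc():
    msdc = MSDC(in_channels=16, kernel_sizes=[1, 3, 5], stride=1, activation='relu6')
    x = torch.randn(2, 16, 32, 32)
    outs = msdc(x)
    # Three parallel depth-wise branches -> three tensors of shape (2, 16, 32, 32).
    return [o.shape for o in outs]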
class MSCB(nn.Module):
    """
    Multi-scale convolution block (MSCB)
    """
    def __init__(self, in_channels, out_channels, shortcut=False, stride=1, kernel_sizes=[1, 3, 5],
                 expansion_factor=2, dw_parallel=True, activation='relu6'):
        super(MSCB, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.kernel_sizes = kernel_sizes
        self.expansion_factor = expansion_factor
        self.dw_parallel = dw_parallel
        self.add = shortcut
        self.activation = activation
        self.n_scales = len(self.kernel_sizes)
        # check stride value
        assert self.stride in [1, 2]
        # Residual connection is only possible when stride is 1 (spatial size preserved)
        self.use_skip_connection = (self.stride == 1)
        # Expand channels before the multi-scale depth-wise stage
        self.ex_channels = int(self.in_channels * self.expansion_factor)
        self.pconv1 = nn.Sequential(
            # pointwise convolution
            nn.Conv2d(self.in_channels, self.ex_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(self.ex_channels),
            act_layer(self.activation, inplace=True)
        )
        self.msdc = MSDC(self.ex_channels, self.kernel_sizes, self.stride, self.activation,
                         dw_parallel=self.dw_parallel)
        if self.add:
            # Branch outputs are summed, so the channel count stays at ex_channels
            self.combined_channels = self.ex_channels * 1
        else:
            # Branch outputs are concatenated along the channel dimension
            self.combined_channels = self.ex_channels * self.n_scales
        self.pconv2 = nn.Sequential(
            # pointwise convolution
            nn.Conv2d(self.combined_channels, self.out_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(self.out_channels),
        )
        if self.use_skip_connection and (self.in_channels != self.out_channels):
            self.conv1x1 = nn.Conv2d(self.in_channels, self.out_channels, 1, 1, 0, bias=False)
        self.init_weights('normal')

    def init_weights(self, scheme=''):
        named_apply(partial(_init_weights, scheme=scheme), self)

    def forward(self, x):
        pout1 = self.pconv1(x)
        msdc_outs = self.msdc(pout1)
        if self.add:
            dout = 0
            for dwout in msdc_outs:
                dout = dout + dwout
        else:
            dout = torch.cat(msdc_outs, dim=1)
        dout = channel_shuffle(dout, gcd(self.combined_channels, self.out_channels))
        out = self.pconv2(dout)
        if self.use_skip_connection:
            if self.in_channels != self.out_channels:
                x = self.conv1x1(x)
            return x + out
        else:
            return out
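# Illustrative usage sketch (channel counts and spatial size are arbitrary): with stride=1
# MSCB keeps the spatial resolution and adds a residual path, so it can replace a block
# between layers of equal resolution.
def _demo_mscb():
    mscb = MSCB(in_channels=32, out_channels=64, stride=1,
                kernel_sizes=[1, 3, 5], expansion_factor=2)
    x = torch.randn(2, 32, 40, 40)
    y = mscb(x)
    # stride=1 -> same spatial size; channels projected from 32 to 64 by pconv2.
    return y.shape  # torch.Size([2, 64, 40, 40])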
def autopad(k, p=None, d=1):  # kernel, padding, dilation
    """Pad to 'same' shape outputs."""
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel size
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p
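# Illustrative values: autopad(3) -> 1, autopad(5) -> 2, and with dilation
# autopad(3, d=2) -> 2, i.e. 'same' padding for stride 1.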
class MSConv(nn.Module):
    """Multi-scale drop-in replacement for the standard Conv block,
    with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""

    default_act = nn.SiLU()  # default activation

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        """Initialize the layer with the given arguments, including activation."""
        super().__init__()
        # Replace the plain Conv2d with an MSCB block.
        # MSCB expects a list of kernel sizes, so the provided k is used as the only scale here.
        # Its depth-wise stage always uses groups=in_channels, so grouped convolutions (g > 1)
        # fall back to a regular Conv2d. Note that p and d are ignored on the MSCB path,
        # which handles 'same' padding internally via kernel_size // 2.
        if g == 1:  # standard convolution case: use MSCB
            self.conv = nn.Sequential(
                MSCB(c1, c2, shortcut=False, stride=s,
                     kernel_sizes=[k],       # single scale, taken from the provided kernel size
                     expansion_factor=1,     # no channel expansion
                     dw_parallel=True,
                     activation='relu6'),
                nn.BatchNorm2d(c2)
            )
        else:  # grouped convolution case: fall back to a regular Conv2d
            self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
        # bn and act are shared by both branches, since forward() always applies them.
        # (On the MSCB path this stacks on the BatchNorm2d layers above, which is redundant
        # but harmless.)
        self.bn = nn.BatchNorm2d(c2)
        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()

    def forward(self, x):
        """Apply convolution, batch normalization and activation to the input tensor."""
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        """Apply convolution and activation without batch normalization (used after BN fusion)."""
        return self.act(self.conv(x))
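# Quick shape sanity check (a minimal usage sketch; the sizes below are arbitrary):
# both the MSCB path (g=1) and the grouped-conv fallback (g>1) should keep the
# spatial size for k=3, s=1 and project the channels to c2.
if __name__ == '__main__':
    x = torch.randn(1, 32, 64, 64)
    m1 = MSConv(32, 64, k=3, s=1)        # g=1 -> MSCB path
    m2 = MSConv(32, 64, k=3, s=1, g=32)  # g>1 -> plain grouped Conv2d fallback
    print(m1(x).shape, m2(x).shape)      # both: torch.Size([1, 64, 64, 64])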
How can MSConv be improved so that the module yields higher accuracy?
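One possible direction, sketched under assumptions rather than verified experimentally: MSConv currently restricts MSCB to a single kernel size ([k]) and expansion_factor=1, whereas MSCB itself defaults to multi-scale kernels such as [1, 3, 5] with channel expansion. Exposing those parameters lets the block aggregate context at several receptive fields. The variant below is only an illustration (the name MSConvMS and its default values are hypothetical); any accuracy gain, and the added compute cost, would need to be measured on the target dataset.

class MSConvMS(MSConv):
    """Hypothetical MSConv variant exposing MSCB's multi-scale kernels and channel expansion."""
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True,
                 kernel_sizes=(1, 3, 5), expansion_factor=2):
        super().__init__(c1, c2, k, s, p, g, d, act)
        if g == 1:
            # Rebuild the MSCB path with several kernel sizes and an expanded hidden width.
            self.conv = nn.Sequential(
                MSCB(c1, c2, shortcut=False, stride=s,
                     kernel_sizes=list(kernel_sizes),
                     expansion_factor=expansion_factor,
                     dw_parallel=True, activation='relu6'),
                nn.BatchNorm2d(c2)
            )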